diff --git a/src/lox/AstPrinter.java b/src/lox/AstPrinter.java index 14d15e5..bfd8eb2 100644 --- a/src/lox/AstPrinter.java +++ b/src/lox/AstPrinter.java @@ -7,23 +7,28 @@ public class AstPrinter implements Expr.Visitor { @Override public String visitAssignExpr(Expr.Assign expr) { - return parethesize(expr.name.lexeme, expr.value); + return parenthesize(expr.name.lexeme, expr.value); } @Override public String visitBinaryExpr(Expr.Binary expr) { - return parethesize(expr.operator.lexeme, + return parenthesize(expr.operator.lexeme, expr.left, expr.right); } @Override public String visitCallExpr(Expr.Call expr) { - return parethesize(expr.paren.lexeme, expr.callee); + return parenthesize(expr.paren.lexeme, expr.callee); + } + + @Override + public String visitGetExpr(Expr.Get expr) { + return parenthesize(expr.name.lexeme, expr.object); } @Override public String visitGroupingExpr(Expr.Grouping expr) { - return parethesize("group", expr.expression); + return parenthesize("group", expr.expression); } @Override @@ -34,20 +39,31 @@ public class AstPrinter implements Expr.Visitor { @Override public String visitLogicalExpr(Expr.Logical expr) { - return parethesize(expr.operator.lexeme, expr.left, expr.right); + return parenthesize(expr.operator.lexeme, expr.left, expr.right); } + @Override + public String visitSetExpr(Expr.Set expr) { + return parenthesize(expr.name.lexeme, expr.object, expr.value); + } + + @Override + public String visitThisExpr(Expr.This expr) { + return parenthesize(expr.keyword.lexeme); + } + // https://craftinginterpreters.com/classes.html#invalid-uses-of-this + @Override public String visitUnaryExpr(Expr.Unary expr) { - return parethesize(expr.operator.lexeme, expr.right); + return parenthesize(expr.operator.lexeme, expr.right); } @Override public String visitVariableExpr(Expr.Variable expr) { - return parethesize(expr.name.lexeme); + return parenthesize(expr.name.lexeme); } - private String parethesize(String name, Expr... exprs) { + private String parenthesize(String name, Expr... exprs) { StringBuilder builder = new StringBuilder(); builder.append("(").append(name); diff --git a/src/lox/Expr.java b/src/lox/Expr.java index 5de709e..c4d7a0f 100644 --- a/src/lox/Expr.java +++ b/src/lox/Expr.java @@ -7,9 +7,12 @@ abstract class Expr { R visitAssignExpr(Assign expr); R visitBinaryExpr(Binary expr); R visitCallExpr(Call expr); + R visitGetExpr(Get expr); R visitGroupingExpr(Grouping expr); R visitLiteralExpr(Literal expr); R visitLogicalExpr(Logical expr); + R visitSetExpr(Set expr); + R visitThisExpr(This expr); R visitUnaryExpr(Unary expr); R visitVariableExpr(Variable expr); } @@ -59,6 +62,20 @@ abstract class Expr { final Token paren; final List arguments; } + static class Get extends Expr { + Get(Expr object, Token name) { + this.object = object; + this.name = name; + } + + @Override + R accept(Visitor visitor) { + return visitor.visitGetExpr(this); + } + + final Expr object; + final Token name; + } static class Grouping extends Expr { Grouping(Expr expression) { this.expression = expression; @@ -99,6 +116,34 @@ abstract class Expr { final Token operator; final Expr right; } + static class Set extends Expr { + Set(Expr object, Token name, Expr value) { + this.object = object; + this.name = name; + this.value = value; + } + + @Override + R accept(Visitor visitor) { + return visitor.visitSetExpr(this); + } + + final Expr object; + final Token name; + final Expr value; + } + static class This extends Expr { + This(Token keyword) { + this.keyword = keyword; + } + + @Override + R accept(Visitor visitor) { + return visitor.visitThisExpr(this); + } + + final Token keyword; + } static class Unary extends Expr { Unary(Token operator, Expr right) { this.operator = operator; diff --git a/src/lox/Interpreter.java b/src/lox/Interpreter.java index 945a30f..9a93f68 100644 --- a/src/lox/Interpreter.java +++ b/src/lox/Interpreter.java @@ -124,6 +124,17 @@ public class Interpreter implements Expr.Visitor, return function.call(this, arguments); } + @Override + public Object visitGetExpr(Expr.Get expr) { + Object object = evaluate(expr.object); + if (object instanceof LoxInstance) { + return ((LoxInstance) object).get(expr.name); + } + + throw new RuntimeError(expr.name, + "Only instances have properties."); + } + @Override public Object visitGroupingExpr(Expr.Grouping expr) { return evaluate(expr.expression); @@ -147,6 +158,25 @@ public class Interpreter implements Expr.Visitor, return evaluate(expr.right); } + @Override + public Object visitSetExpr(Expr.Set expr) { + Object object = evaluate(expr.object); + + if (!(object instanceof LoxInstance)) { + throw new RuntimeError(expr.name, + "Only instances have fields."); + } + + Object value = evaluate(expr.value); + ((LoxInstance)object).set(expr.name, value); + return value; + } + + @Override + public Object visitThisExpr(Expr.This expr) { + return lookUpVariable(expr.keyword, expr); + } + @Override public Object visitUnaryExpr(Expr.Unary expr) { Object right = evaluate(expr.right); @@ -246,6 +276,22 @@ public class Interpreter implements Expr.Visitor, return null; } + @Override + public Void visitClassStmt(Stmt.Class stmt) { + environment.define(stmt.name.lexeme, null); + + Map methods = new HashMap<>(); + for (Stmt.Function method : stmt.methods) { + LoxFunction function = new LoxFunction(method, environment, + method.name.lexeme.equals("init")); + methods.put(method.name.lexeme, function); + } + + LoxClass klass = new LoxClass(stmt.name.lexeme, methods); + environment.assign(stmt.name, klass); + return null; + } + @Override public Void visitExpressionStmt(Stmt.Expression stmt) { evaluate(stmt.expression); @@ -254,7 +300,7 @@ public class Interpreter implements Expr.Visitor, @Override public Void visitFunctionStmt(Stmt.Function stmt) { - LoxFunction function = new LoxFunction(stmt, environment); + LoxFunction function = new LoxFunction(stmt, environment, false); environment.define(stmt.name.lexeme, function); return null; } diff --git a/src/lox/LoxClass.java b/src/lox/LoxClass.java new file mode 100644 index 0000000..297f5ba --- /dev/null +++ b/src/lox/LoxClass.java @@ -0,0 +1,44 @@ +package lox; + +import java.util.List; +import java.util.Map; + +public class LoxClass implements LoxCallable { + final String name; + private final Map methods; + + public LoxClass(String name, Map methods) { + this.name = name; + this.methods = methods; + } + + LoxFunction findMethod(String name) { + if (methods.containsKey(name)) { + return methods.get(name); + } + + return null; + } + + @Override + public String toString() { + return name; + } + + @Override + public int arity() { + LoxFunction initializer = findMethod("init"); + if (initializer == null) return 0; + return initializer.arity(); + } + + @Override + public Object call(Interpreter interpreter, List arguments) { + LoxInstance instance = new LoxInstance(this); + LoxFunction initializer = findMethod("init"); + if (initializer != null) { + initializer.bind(instance).call(interpreter, arguments); + } + return instance; + } +} diff --git a/src/lox/LoxFunction.java b/src/lox/LoxFunction.java index 9c4ce77..c8fd415 100644 --- a/src/lox/LoxFunction.java +++ b/src/lox/LoxFunction.java @@ -6,9 +6,19 @@ public class LoxFunction implements LoxCallable { private final Stmt.Function declaration; private final Environment closure; - LoxFunction(Stmt.Function declaration, Environment closure) { + private final boolean isInitializer; + + LoxFunction(Stmt.Function declaration, Environment closure, boolean isInitializer) { this.closure = closure; this.declaration = declaration; + this.isInitializer = isInitializer; + } + + LoxFunction bind(LoxInstance instance) { + Environment environment = new Environment(closure); + environment.define("this", instance); + return new LoxFunction(declaration, environment, + isInitializer); } @Override @@ -27,8 +37,12 @@ public class LoxFunction implements LoxCallable { try { interpreter.executeBlock(declaration.body, environment); } catch (Return returnValue) { + if (isInitializer) return closure.getAt(0, "this"); + return returnValue.value; } + + if (isInitializer) return closure.getAt(0, "this"); return null; } diff --git a/src/lox/LoxInstance.java b/src/lox/LoxInstance.java new file mode 100644 index 0000000..c0370ae --- /dev/null +++ b/src/lox/LoxInstance.java @@ -0,0 +1,34 @@ +package lox; + +import java.util.HashMap; +import java.util.Map; + +public class LoxInstance { + private LoxClass klass; + private final Map fields = new HashMap<>(); + + LoxInstance(LoxClass klass) { + this.klass = klass; + } + + Object get(Token name) { + if (fields.containsKey(name.lexeme)) { + return fields.get(name.lexeme); + } + + LoxFunction method = klass.findMethod(name.lexeme); + if (method != null) return method.bind(this); + + throw new RuntimeError(name, + "Undefined property '" + name.lexeme + "'."); + } + + void set(Token name, Object value) { + fields.put(name.lexeme, value); + } + + @Override + public String toString() { + return klass.name + " instance"; + } +} diff --git a/src/lox/Parser.java b/src/lox/Parser.java index f5de1d0..79408e6 100644 --- a/src/lox/Parser.java +++ b/src/lox/Parser.java @@ -32,6 +32,7 @@ public class Parser { private Stmt declaration() { try { + if (match(CLASS)) return classDeclaration(); if (match(FUN)) return function("function"); if (match(VAR)) return varDeclaration(); @@ -42,6 +43,20 @@ public class Parser { } } + private Stmt classDeclaration() { + Token name = consume(IDENTIFIER, "Expect class name."); + consume(LEFT_BRACE, "Expect '{' before class body."); + + List methods = new ArrayList<>(); + while (!check(RIGHT_BRACE) && !isAtEnd()) { + methods.add(function("method")); + } + + consume(RIGHT_BRACE, "Expect '}' after class body."); + + return new Stmt.Class(name, methods); + } + private Stmt statement() { if (match(FOR)) return forStatement(); if (match(IF)) return ifStatement(); @@ -195,6 +210,9 @@ public class Parser { if (expr instanceof Expr.Variable) { Token name = ((Expr.Variable) expr).name; return new Expr.Assign(name, value); + } else if (expr instanceof Expr.Get) { + Expr.Get get = (Expr.Get)expr; + return new Expr.Set(get.object, get.name, value); } throw error(equals, "Invalid assignment target."); @@ -308,6 +326,10 @@ public class Parser { while (true) { if (match(LEFT_PAREN)) { expr = finishCall(expr); + } else if (match(DOT)) { + Token name = consume(IDENTIFIER, + "Expect property name after '.'."); + expr = new Expr.Get(expr, name); } else { break; } @@ -325,6 +347,8 @@ public class Parser { return new Expr.Literal(previous().literal); } + if (match(THIS)) return new Expr.This(previous()); + if (match(IDENTIFIER)) { return new Expr.Variable(previous()); } diff --git a/src/lox/Resolver.java b/src/lox/Resolver.java index e84b03e..8d49eff 100644 --- a/src/lox/Resolver.java +++ b/src/lox/Resolver.java @@ -16,9 +16,18 @@ public class Resolver implements Expr.Visitor, Stmt.Visitor { private enum FunctionType { NONE, - FUNCTION + FUNCTION, + INITIALIZER, + METHOD } + private enum ClassType { + NONE, + CLASS + } + + private ClassType currentClass = ClassType.NONE; + void resolve(List statements) { for (Stmt statement : statements) { resolve(statement); @@ -50,6 +59,12 @@ public class Resolver implements Expr.Visitor, Stmt.Visitor { return null; } + @Override + public Void visitGetExpr(Expr.Get expr) { + resolve(expr.object); + return null; + } + @Override public Void visitGroupingExpr(Expr.Grouping expr) { resolve(expr.expression); @@ -68,6 +83,25 @@ public class Resolver implements Expr.Visitor, Stmt.Visitor { return null; } + @Override + public Void visitSetExpr(Expr.Set expr) { + resolve(expr.value); + resolve(expr.object); + return null; + } + + @Override + public Void visitThisExpr(Expr.This expr) { + if (currentClass == ClassType.NONE) { + Lox.error(expr.keyword, + "Can't use 'this' outside of a class."); + return null; + } + + resolveLocal(expr, expr.keyword); + return null; + } + @Override public Void visitUnaryExpr(Expr.Unary expr) { resolve(expr.right); @@ -94,6 +128,31 @@ public class Resolver implements Expr.Visitor, Stmt.Visitor { return null; } + @Override + public Void visitClassStmt(Stmt.Class stmt) { + ClassType enclosingClass = currentClass; + currentClass = ClassType.CLASS; + + declare(stmt.name); + define(stmt.name); + + beginScope(); + scopes.peek().put("this", true); + + for (Stmt.Function method : stmt.methods) { + FunctionType declaration = FunctionType.METHOD; + if (method.name.lexeme.equals("init")) { + declaration = FunctionType.INITIALIZER; + } + resolveFunction(method, declaration); + } + + endScope(); + + currentClass = enclosingClass; + return null; + } + private void resolve(Stmt stmt) { stmt.accept(this); } @@ -185,6 +244,10 @@ public class Resolver implements Expr.Visitor, Stmt.Visitor { } if (stmt.value != null) { + if (currentFunction == FunctionType.INITIALIZER) { + Lox.error(stmt.keyword, + "Can't return a value from an initializer."); + } resolve(stmt.value); } diff --git a/src/lox/Stmt.java b/src/lox/Stmt.java index dbc11fc..9d96fb2 100644 --- a/src/lox/Stmt.java +++ b/src/lox/Stmt.java @@ -5,6 +5,7 @@ import java.util.List; abstract class Stmt { interface Visitor { R visitBlockStmt(Block stmt); + R visitClassStmt(Class stmt); R visitExpressionStmt(Expression stmt); R visitFunctionStmt(Function stmt); R visitIfStmt(If stmt); @@ -25,6 +26,20 @@ abstract class Stmt { final List statements; } + static class Class extends Stmt { + Class(Token name, List methods) { + this.name = name; + this.methods = methods; + } + + @Override + R accept(Visitor visitor) { + return visitor.visitClassStmt(this); + } + + final Token name; + final List methods; + } static class Expression extends Stmt { Expression(Expr expression) { this.expression = expression; diff --git a/src/tool/GenerateAst.java b/src/tool/GenerateAst.java index eb853ac..65e3089 100644 --- a/src/tool/GenerateAst.java +++ b/src/tool/GenerateAst.java @@ -17,15 +17,19 @@ public class GenerateAst { "Assign : Token name, Expr value", "Binary : Expr left, Token operator, Expr right", "Call : Expr callee, Token paren, List arguments", + "Get : Expr object, Token name", "Grouping : Expr expression", "Literal : Object value", "Logical : Expr left, Token operator, Expr right", + "Set : Expr object, Token name, Expr value", + "This : Token keyword", "Unary : Token operator, Expr right", "Variable : Token name" )); defineAst(outputDir, "Stmt", Arrays.asList( "Block : List statements", + "Class : Token name, List methods", "Expression : Expr expression", "Function : Token name, List params," + " List body",