diff --git a/src/lox/AstPrinter.java b/src/lox/AstPrinter.java index d22c950..14d15e5 100644 --- a/src/lox/AstPrinter.java +++ b/src/lox/AstPrinter.java @@ -16,6 +16,11 @@ public class AstPrinter implements Expr.Visitor { expr.left, expr.right); } + @Override + public String visitCallExpr(Expr.Call expr) { + return parethesize(expr.paren.lexeme, expr.callee); + } + @Override public String visitGroupingExpr(Expr.Grouping expr) { return parethesize("group", expr.expression); diff --git a/src/lox/Expr.java b/src/lox/Expr.java index fb1aedb..5de709e 100644 --- a/src/lox/Expr.java +++ b/src/lox/Expr.java @@ -6,6 +6,7 @@ abstract class Expr { interface Visitor { R visitAssignExpr(Assign expr); R visitBinaryExpr(Binary expr); + R visitCallExpr(Call expr); R visitGroupingExpr(Grouping expr); R visitLiteralExpr(Literal expr); R visitLogicalExpr(Logical expr); @@ -42,6 +43,22 @@ abstract class Expr { final Token operator; final Expr right; } + static class Call extends Expr { + Call(Expr callee, Token paren, List arguments) { + this.callee = callee; + this.paren = paren; + this.arguments = arguments; + } + + @Override + R accept(Visitor visitor) { + return visitor.visitCallExpr(this); + } + + final Expr callee; + final Token paren; + final List arguments; + } static class Grouping extends Expr { Grouping(Expr expression) { this.expression = expression; diff --git a/src/lox/Interpreter.java b/src/lox/Interpreter.java index 304a600..85bc28e 100644 --- a/src/lox/Interpreter.java +++ b/src/lox/Interpreter.java @@ -1,10 +1,31 @@ package lox; +import java.util.ArrayList; import java.util.List; public class Interpreter implements Expr.Visitor, Stmt.Visitor { - private Environment environment = new Environment(); + final Environment globals = new Environment(); + private Environment environment = globals; + + Interpreter() { + globals.define("clock", new LoxCallable() { + @Override + public int arity() { + return 0; + } + + @Override + public Object call(Interpreter interpreter, List arguments) { + return (double) System.currentTimeMillis() / 1000.0; + } + + @Override + public String toString() { + return ""; + } + }); + } void interpret(List statements) { try { @@ -69,6 +90,30 @@ public class Interpreter implements Expr.Visitor, } } + @Override + public Object visitCallExpr(Expr.Call expr) { + Object callee = evaluate(expr.callee); + + List arguments = new ArrayList<>(); + for (Expr argument : expr.arguments) { + arguments.add(evaluate(argument)); + } + + if (!(callee instanceof LoxCallable)) { + throw new RuntimeError(expr.paren, + "Can only call functions and classes."); + } + + LoxCallable function = (LoxCallable) callee; + if (arguments.size() != function.arity()) { + throw new RuntimeError(expr.paren, "Expected " + + function.arity() + " arguments but got " + + arguments.size() + "."); + } + + return function.call(this, arguments); + } + @Override public Object visitGroupingExpr(Expr.Grouping expr) { return evaluate(expr.expression); @@ -184,6 +229,13 @@ public class Interpreter implements Expr.Visitor, return null; } + @Override + public Void visitFunctionStmt(Stmt.Function stmt) { + LoxFunction function = new LoxFunction(stmt, environment); + environment.define(stmt.name.lexeme, function); + return null; + } + @Override public Void visitIfStmt(Stmt.If stmt) { if (isTruthy(evaluate(stmt.condition))) { @@ -201,6 +253,14 @@ public class Interpreter implements Expr.Visitor, return null; } + @Override + public Void visitReturnStmt(Stmt.Return stmt) { + Object value = null; + if (stmt.value != null) value = evaluate(stmt.value); + + throw new Return(value); + } + @Override public Void visitVarStmt(Stmt.Var stmt) { Object value = null; diff --git a/src/lox/LoxCallable.java b/src/lox/LoxCallable.java new file mode 100644 index 0000000..aee3a3d --- /dev/null +++ b/src/lox/LoxCallable.java @@ -0,0 +1,9 @@ +package lox; + +import java.util.List; + +interface LoxCallable { + int arity(); + + Object call(Interpreter interpreter, List arguments); +} diff --git a/src/lox/LoxFunction.java b/src/lox/LoxFunction.java new file mode 100644 index 0000000..9c4ce77 --- /dev/null +++ b/src/lox/LoxFunction.java @@ -0,0 +1,39 @@ +package lox; + +import java.util.List; + +public class LoxFunction implements LoxCallable { + private final Stmt.Function declaration; + private final Environment closure; + + LoxFunction(Stmt.Function declaration, Environment closure) { + this.closure = closure; + this.declaration = declaration; + } + + @Override + public int arity() { + return declaration.params.size(); + } + + @Override + public Object call(Interpreter interpreter, List arguments) { + Environment environment = new Environment(closure); + for (int i = 0; i < declaration.params.size(); i++) { + environment.define(declaration.params.get(i).lexeme, + arguments.get(i)); + } + + try { + interpreter.executeBlock(declaration.body, environment); + } catch (Return returnValue) { + return returnValue.value; + } + return null; + } + + @Override + public String toString() { + return ""; + } +} diff --git a/src/lox/Parser.java b/src/lox/Parser.java index 9fc58ea..f5de1d0 100644 --- a/src/lox/Parser.java +++ b/src/lox/Parser.java @@ -32,6 +32,7 @@ public class Parser { private Stmt declaration() { try { + if (match(FUN)) return function("function"); if (match(VAR)) return varDeclaration(); return statement(); @@ -45,6 +46,7 @@ public class Parser { if (match(FOR)) return forStatement(); if (match(IF)) return ifStatement(); if (match(PRINT)) return printStatement(); + if (match(RETURN)) return returnStatement(); if (match(WHILE)) return whileStatement(); if (match(LEFT_BRACE)) return new Stmt.Block(block()); @@ -113,6 +115,17 @@ public class Parser { return new Stmt.Print(value); } + private Stmt returnStatement() { + Token keyword = previous(); + Expr value = null; + if (!check(SEMICOLON)) { + value = expression(); + } + + consume(SEMICOLON, "Expect ';' after return value."); + return new Stmt.Return(keyword, value); + } + private Stmt varDeclaration() { Token name = consume(IDENTIFIER, "Expect variable name."); @@ -140,6 +153,27 @@ public class Parser { return new Stmt.Expression(expr); } + private Stmt.Function function(String kind) { + Token name = consume(IDENTIFIER, "Expect " + kind + " name."); + consume(LEFT_PAREN, "Expect '(' after " + kind + " name."); + List parameters = new ArrayList<>(); + if (!check(RIGHT_PAREN)) { + do { + if (parameters.size() >= 255) { + error(peek(), "Can't have more than 255 parameters."); + } + + parameters.add( + consume(IDENTIFIER, "Expect parameter name.")); + } while (match(COMMA)); + } + consume(RIGHT_PAREN, "Expect ')' after parameters."); + + consume(LEFT_BRACE, "Expect '{' before " + kind + " body."); + List body = block(); + return new Stmt.Function(name, parameters, body); + } + private List block() { List statements = new ArrayList<>(); @@ -248,7 +282,38 @@ public class Parser { return new Expr.Unary(operator, right); } - return primary(); + return call(); + } + + private Expr finishCall(Expr callee) { + List arguments = new ArrayList<>(); + if (!check(RIGHT_PAREN)) { + do { + if (arguments.size() >= 255) { + error(peek(), "Can't have more than 255 arguments."); + } + arguments.add(expression()); + } while (match(COMMA)); + } + + Token paren = consume(RIGHT_PAREN, + "Expect ')' after arguments."); + + return new Expr.Call(callee, paren, arguments); + } + + private Expr call() { + Expr expr = primary(); + + while (true) { + if (match(LEFT_PAREN)) { + expr = finishCall(expr); + } else { + break; + } + } + + return expr; } private Expr primary() { diff --git a/src/lox/Return.java b/src/lox/Return.java new file mode 100644 index 0000000..e5570a0 --- /dev/null +++ b/src/lox/Return.java @@ -0,0 +1,10 @@ +package lox; + +public class Return extends RuntimeException { + final Object value; + + Return(Object value) { + super(null, null, false, false); + this.value = value; + } +} diff --git a/src/lox/Stmt.java b/src/lox/Stmt.java index e23ec6c..dbc11fc 100644 --- a/src/lox/Stmt.java +++ b/src/lox/Stmt.java @@ -6,8 +6,10 @@ abstract class Stmt { interface Visitor { R visitBlockStmt(Block stmt); R visitExpressionStmt(Expression stmt); + R visitFunctionStmt(Function stmt); R visitIfStmt(If stmt); R visitPrintStmt(Print stmt); + R visitReturnStmt(Return stmt); R visitVarStmt(Var stmt); R visitWhileStmt(While stmt); } @@ -35,6 +37,22 @@ abstract class Stmt { final Expr expression; } + static class Function extends Stmt { + Function(Token name, List params, List body) { + this.name = name; + this.params = params; + this.body = body; + } + + @Override + R accept(Visitor visitor) { + return visitor.visitFunctionStmt(this); + } + + final Token name; + final List params; + final List body; + } static class If extends Stmt { If(Expr condition, Stmt thenBranch, Stmt elseBranch) { this.condition = condition; @@ -63,6 +81,20 @@ abstract class Stmt { final Expr expression; } + static class Return extends Stmt { + Return(Token keyword, Expr value) { + this.keyword = keyword; + this.value = value; + } + + @Override + R accept(Visitor visitor) { + return visitor.visitReturnStmt(this); + } + + final Token keyword; + final Expr value; + } static class Var extends Stmt { Var(Token name, Expr initializer) { this.name = name; diff --git a/src/tool/GenerateAst.java b/src/tool/GenerateAst.java index 7142b39..eb853ac 100644 --- a/src/tool/GenerateAst.java +++ b/src/tool/GenerateAst.java @@ -16,6 +16,7 @@ public class GenerateAst { defineAst(outputDir, "Expr", Arrays.asList( "Assign : Token name, Expr value", "Binary : Expr left, Token operator, Expr right", + "Call : Expr callee, Token paren, List arguments", "Grouping : Expr expression", "Literal : Object value", "Logical : Expr left, Token operator, Expr right", @@ -26,9 +27,12 @@ public class GenerateAst { defineAst(outputDir, "Stmt", Arrays.asList( "Block : List statements", "Expression : Expr expression", + "Function : Token name, List params," + + " List body", "If : Expr condition, Stmt thenBranch," + " Stmt elseBranch", "Print : Expr expression", + "Return : Token keyword, Expr value", "Var : Token name, Expr initializer", "While : Expr condition, Stmt body" ));