diff --git a/src/lox/AstPrinter.java b/src/lox/AstPrinter.java index bfd8eb2..8d643c4 100644 --- a/src/lox/AstPrinter.java +++ b/src/lox/AstPrinter.java @@ -47,6 +47,11 @@ public class AstPrinter implements Expr.Visitor { return parenthesize(expr.name.lexeme, expr.object, expr.value); } + @Override + public String visitSuperExpr(Expr.Super expr) { + return parenthesize(expr.method.lexeme); + } + @Override public String visitThisExpr(Expr.This expr) { return parenthesize(expr.keyword.lexeme); diff --git a/src/lox/Expr.java b/src/lox/Expr.java index c4d7a0f..31eaaf2 100644 --- a/src/lox/Expr.java +++ b/src/lox/Expr.java @@ -12,6 +12,7 @@ abstract class Expr { R visitLiteralExpr(Literal expr); R visitLogicalExpr(Logical expr); R visitSetExpr(Set expr); + R visitSuperExpr(Super expr); R visitThisExpr(This expr); R visitUnaryExpr(Unary expr); R visitVariableExpr(Variable expr); @@ -132,6 +133,20 @@ abstract class Expr { final Token name; final Expr value; } + static class Super extends Expr { + Super(Token keyword, Token method) { + this.keyword = keyword; + this.method = method; + } + + @Override + R accept(Visitor visitor) { + return visitor.visitSuperExpr(this); + } + + final Token keyword; + final Token method; + } static class This extends Expr { This(Token keyword) { this.keyword = keyword; diff --git a/src/lox/Interpreter.java b/src/lox/Interpreter.java index 9a93f68..7715b11 100644 --- a/src/lox/Interpreter.java +++ b/src/lox/Interpreter.java @@ -80,7 +80,7 @@ public class Interpreter implements Expr.Visitor, return (double) left + (double) right; } if (left instanceof String && right instanceof String) { - return (String) left + (String) right; + return left + (String) right; } throw new RuntimeError(expr.operator, @@ -109,12 +109,11 @@ public class Interpreter implements Expr.Visitor, arguments.add(evaluate(argument)); } - if (!(callee instanceof LoxCallable)) { + if (!(callee instanceof LoxCallable function)) { throw new RuntimeError(expr.paren, "Can only call functions and classes."); } - LoxCallable function = (LoxCallable) callee; if (arguments.size() != function.arity()) { throw new RuntimeError(expr.paren, "Expected " + function.arity() + " arguments but got " + @@ -172,6 +171,25 @@ public class Interpreter implements Expr.Visitor, return value; } + @Override + public Object visitSuperExpr(Expr.Super expr) { + int distance = locals.get(expr); + LoxClass superclass = (LoxClass)environment.getAt( + distance, "super"); + + LoxInstance object = (LoxInstance) environment.getAt( + distance - 1, "this"); + + LoxFunction method = superclass.findMethod(expr.method.lexeme); + + if (method == null) { + throw new RuntimeError(expr.method, + "Undefined property '" + expr.method.lexeme + "'."); + } + + return method.bind(object); + } + @Override public Object visitThisExpr(Expr.This expr) { return lookUpVariable(expr.keyword, expr); @@ -278,8 +296,22 @@ public class Interpreter implements Expr.Visitor, @Override public Void visitClassStmt(Stmt.Class stmt) { + Object superclass = null; + if (stmt.superclass != null) { + superclass = evaluate(stmt.superclass); + if (!(superclass instanceof LoxClass)) { + throw new RuntimeError(stmt.superclass.name, + "Superclass must be a class."); + } + } + environment.define(stmt.name.lexeme, null); + if (stmt.superclass != null) { + environment = new Environment(environment); + environment.define("super", superclass); + } + Map methods = new HashMap<>(); for (Stmt.Function method : stmt.methods) { LoxFunction function = new LoxFunction(method, environment, @@ -287,7 +319,13 @@ public class Interpreter implements Expr.Visitor, methods.put(method.name.lexeme, function); } - LoxClass klass = new LoxClass(stmt.name.lexeme, methods); + LoxClass klass = new LoxClass(stmt.name.lexeme, + (LoxClass)superclass, methods); + + if (superclass != null) { + environment = environment.enclosing; + } + environment.assign(stmt.name, klass); return null; } diff --git a/src/lox/LoxClass.java b/src/lox/LoxClass.java index 297f5ba..e12af9b 100644 --- a/src/lox/LoxClass.java +++ b/src/lox/LoxClass.java @@ -5,10 +5,13 @@ import java.util.Map; public class LoxClass implements LoxCallable { final String name; + final LoxClass superclass; private final Map methods; - public LoxClass(String name, Map methods) { + public LoxClass(String name, LoxClass superclass, + Map methods) { this.name = name; + this.superclass = superclass; this.methods = methods; } @@ -17,6 +20,10 @@ public class LoxClass implements LoxCallable { return methods.get(name); } + if (superclass != null) { + return superclass.findMethod(name); + } + return null; } diff --git a/src/lox/Parser.java b/src/lox/Parser.java index 79408e6..31a2b99 100644 --- a/src/lox/Parser.java +++ b/src/lox/Parser.java @@ -45,6 +45,13 @@ public class Parser { private Stmt classDeclaration() { Token name = consume(IDENTIFIER, "Expect class name."); + + Expr.Variable superclass = null; + if (match(LESS)) { + consume(IDENTIFIER, "Expect superclass name."); + superclass = new Expr.Variable(previous()); + } + consume(LEFT_BRACE, "Expect '{' before class body."); List methods = new ArrayList<>(); @@ -54,7 +61,7 @@ public class Parser { consume(RIGHT_BRACE, "Expect '}' after class body."); - return new Stmt.Class(name, methods); + return new Stmt.Class(name, superclass, methods); } private Stmt statement() { @@ -210,8 +217,7 @@ public class Parser { if (expr instanceof Expr.Variable) { Token name = ((Expr.Variable) expr).name; return new Expr.Assign(name, value); - } else if (expr instanceof Expr.Get) { - Expr.Get get = (Expr.Get)expr; + } else if (expr instanceof Expr.Get get) { return new Expr.Set(get.object, get.name, value); } @@ -347,6 +353,14 @@ public class Parser { return new Expr.Literal(previous().literal); } + if (match(SUPER)) { + Token keyword = previous(); + consume(DOT, "Expect '.' after 'super'."); + Token method = consume(IDENTIFIER, + "Expect superclass method name."); + return new Expr.Super(keyword, method); + } + if (match(THIS)) return new Expr.This(previous()); if (match(IDENTIFIER)) { diff --git a/src/lox/Resolver.java b/src/lox/Resolver.java index 8d49eff..1a16bc2 100644 --- a/src/lox/Resolver.java +++ b/src/lox/Resolver.java @@ -23,7 +23,8 @@ public class Resolver implements Expr.Visitor, Stmt.Visitor { private enum ClassType { NONE, - CLASS + CLASS, + SUBCLASS } private ClassType currentClass = ClassType.NONE; @@ -90,6 +91,20 @@ public class Resolver implements Expr.Visitor, Stmt.Visitor { return null; } + @Override + public Void visitSuperExpr(Expr.Super expr) { + if (currentClass == ClassType.NONE) { + Lox.error(expr.keyword, + "Can't use 'super' outside of a class."); + } else if (currentClass != ClassType.SUBCLASS) { + Lox.error(expr.keyword, + "Can't use 'super' in a class with no superclass."); + } + + resolveLocal(expr, expr.keyword); + return null; + } + @Override public Void visitThisExpr(Expr.This expr) { if (currentClass == ClassType.NONE) { @@ -136,6 +151,22 @@ public class Resolver implements Expr.Visitor, Stmt.Visitor { declare(stmt.name); define(stmt.name); + if (stmt.superclass != null && + stmt.name.lexeme.equals(stmt.superclass.name.lexeme)) { + Lox.error(stmt.superclass.name, + "A class can't inherit from itself."); + } + + if (stmt.superclass != null) { + currentClass = ClassType.SUBCLASS; + resolve(stmt.superclass); + } + + if (stmt.superclass != null) { + beginScope(); + scopes.peek().put("super", true); + } + beginScope(); scopes.peek().put("this", true); @@ -149,6 +180,8 @@ public class Resolver implements Expr.Visitor, Stmt.Visitor { endScope(); + if (stmt.superclass != null) endScope(); + currentClass = enclosingClass; return null; } @@ -176,7 +209,7 @@ public class Resolver implements Expr.Visitor, Stmt.Visitor { } private void beginScope() { - scopes.push(new HashMap()); + scopes.push(new HashMap<>()); } private void endScope() { diff --git a/src/lox/Stmt.java b/src/lox/Stmt.java index 9d96fb2..f761132 100644 --- a/src/lox/Stmt.java +++ b/src/lox/Stmt.java @@ -27,8 +27,9 @@ abstract class Stmt { final List statements; } static class Class extends Stmt { - Class(Token name, List methods) { + Class(Token name, Expr.Variable superclass, List methods) { this.name = name; + this.superclass = superclass; this.methods = methods; } @@ -38,6 +39,7 @@ abstract class Stmt { } final Token name; + final Expr.Variable superclass; final List methods; } static class Expression extends Stmt { diff --git a/src/tool/GenerateAst.java b/src/tool/GenerateAst.java index 65e3089..c4c647a 100644 --- a/src/tool/GenerateAst.java +++ b/src/tool/GenerateAst.java @@ -1,13 +1,13 @@ package tool; -import java.io.FileNotFoundException; +import java.io.IOException; import java.io.PrintWriter; -import java.io.UnsupportedEncodingException; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.List; public class GenerateAst { - public static void main(String[] args) throws FileNotFoundException, UnsupportedEncodingException { + public static void main(String[] args) throws IOException { if (args.length != 1) { System.err.println("Usage: generate_ast "); System.exit(64); @@ -22,6 +22,7 @@ public class GenerateAst { "Literal : Object value", "Logical : Expr left, Token operator, Expr right", "Set : Expr object, Token name, Expr value", + "Super : Token keyword, Token method", "This : Token keyword", "Unary : Token operator, Expr right", "Variable : Token name" @@ -29,7 +30,8 @@ public class GenerateAst { defineAst(outputDir, "Stmt", Arrays.asList( "Block : List statements", - "Class : Token name, List methods", + "Class : Token name, Expr.Variable superclass," + + " List methods", "Expression : Expr expression", "Function : Token name, List params," + " List body", @@ -43,9 +45,9 @@ public class GenerateAst { } private static void defineAst( - String outputDir, String baseName, List types) throws FileNotFoundException, UnsupportedEncodingException { + String outputDir, String baseName, List types) throws IOException { String path = outputDir + "/" + baseName + ".java"; - PrintWriter writer = new PrintWriter(path, "UTF-8"); + PrintWriter writer = new PrintWriter(path, StandardCharsets.UTF_8); writer.println("package lox;"); writer.println();