Inheritance chapter

This commit is contained in:
Mariano Riefolo 2024-07-05 09:44:06 +02:00
parent e1f123e96f
commit 95a4868bc7
8 changed files with 133 additions and 17 deletions

View File

@ -47,6 +47,11 @@ public class AstPrinter implements Expr.Visitor<String> {
return parenthesize(expr.name.lexeme, expr.object, expr.value); return parenthesize(expr.name.lexeme, expr.object, expr.value);
} }
@Override
public String visitSuperExpr(Expr.Super expr) {
return parenthesize(expr.method.lexeme);
}
@Override @Override
public String visitThisExpr(Expr.This expr) { public String visitThisExpr(Expr.This expr) {
return parenthesize(expr.keyword.lexeme); return parenthesize(expr.keyword.lexeme);

View File

@ -12,6 +12,7 @@ abstract class Expr {
R visitLiteralExpr(Literal expr); R visitLiteralExpr(Literal expr);
R visitLogicalExpr(Logical expr); R visitLogicalExpr(Logical expr);
R visitSetExpr(Set expr); R visitSetExpr(Set expr);
R visitSuperExpr(Super expr);
R visitThisExpr(This expr); R visitThisExpr(This expr);
R visitUnaryExpr(Unary expr); R visitUnaryExpr(Unary expr);
R visitVariableExpr(Variable expr); R visitVariableExpr(Variable expr);
@ -132,6 +133,20 @@ abstract class Expr {
final Token name; final Token name;
final Expr value; final Expr value;
} }
static class Super extends Expr {
Super(Token keyword, Token method) {
this.keyword = keyword;
this.method = method;
}
@Override
<R> R accept(Visitor<R> visitor) {
return visitor.visitSuperExpr(this);
}
final Token keyword;
final Token method;
}
static class This extends Expr { static class This extends Expr {
This(Token keyword) { This(Token keyword) {
this.keyword = keyword; this.keyword = keyword;

View File

@ -80,7 +80,7 @@ public class Interpreter implements Expr.Visitor<Object>,
return (double) left + (double) right; return (double) left + (double) right;
} }
if (left instanceof String && right instanceof String) { if (left instanceof String && right instanceof String) {
return (String) left + (String) right; return left + (String) right;
} }
throw new RuntimeError(expr.operator, throw new RuntimeError(expr.operator,
@ -109,12 +109,11 @@ public class Interpreter implements Expr.Visitor<Object>,
arguments.add(evaluate(argument)); arguments.add(evaluate(argument));
} }
if (!(callee instanceof LoxCallable)) { if (!(callee instanceof LoxCallable function)) {
throw new RuntimeError(expr.paren, throw new RuntimeError(expr.paren,
"Can only call functions and classes."); "Can only call functions and classes.");
} }
LoxCallable function = (LoxCallable) callee;
if (arguments.size() != function.arity()) { if (arguments.size() != function.arity()) {
throw new RuntimeError(expr.paren, "Expected " + throw new RuntimeError(expr.paren, "Expected " +
function.arity() + " arguments but got " + function.arity() + " arguments but got " +
@ -172,6 +171,25 @@ public class Interpreter implements Expr.Visitor<Object>,
return value; return value;
} }
@Override
public Object visitSuperExpr(Expr.Super expr) {
int distance = locals.get(expr);
LoxClass superclass = (LoxClass)environment.getAt(
distance, "super");
LoxInstance object = (LoxInstance) environment.getAt(
distance - 1, "this");
LoxFunction method = superclass.findMethod(expr.method.lexeme);
if (method == null) {
throw new RuntimeError(expr.method,
"Undefined property '" + expr.method.lexeme + "'.");
}
return method.bind(object);
}
@Override @Override
public Object visitThisExpr(Expr.This expr) { public Object visitThisExpr(Expr.This expr) {
return lookUpVariable(expr.keyword, expr); return lookUpVariable(expr.keyword, expr);
@ -278,8 +296,22 @@ public class Interpreter implements Expr.Visitor<Object>,
@Override @Override
public Void visitClassStmt(Stmt.Class stmt) { public Void visitClassStmt(Stmt.Class stmt) {
Object superclass = null;
if (stmt.superclass != null) {
superclass = evaluate(stmt.superclass);
if (!(superclass instanceof LoxClass)) {
throw new RuntimeError(stmt.superclass.name,
"Superclass must be a class.");
}
}
environment.define(stmt.name.lexeme, null); environment.define(stmt.name.lexeme, null);
if (stmt.superclass != null) {
environment = new Environment(environment);
environment.define("super", superclass);
}
Map<String, LoxFunction> methods = new HashMap<>(); Map<String, LoxFunction> methods = new HashMap<>();
for (Stmt.Function method : stmt.methods) { for (Stmt.Function method : stmt.methods) {
LoxFunction function = new LoxFunction(method, environment, LoxFunction function = new LoxFunction(method, environment,
@ -287,7 +319,13 @@ public class Interpreter implements Expr.Visitor<Object>,
methods.put(method.name.lexeme, function); methods.put(method.name.lexeme, function);
} }
LoxClass klass = new LoxClass(stmt.name.lexeme, methods); LoxClass klass = new LoxClass(stmt.name.lexeme,
(LoxClass)superclass, methods);
if (superclass != null) {
environment = environment.enclosing;
}
environment.assign(stmt.name, klass); environment.assign(stmt.name, klass);
return null; return null;
} }

View File

@ -5,10 +5,13 @@ import java.util.Map;
public class LoxClass implements LoxCallable { public class LoxClass implements LoxCallable {
final String name; final String name;
final LoxClass superclass;
private final Map<String, LoxFunction> methods; private final Map<String, LoxFunction> methods;
public LoxClass(String name, Map<String, LoxFunction> methods) { public LoxClass(String name, LoxClass superclass,
Map<String, LoxFunction> methods) {
this.name = name; this.name = name;
this.superclass = superclass;
this.methods = methods; this.methods = methods;
} }
@ -17,6 +20,10 @@ public class LoxClass implements LoxCallable {
return methods.get(name); return methods.get(name);
} }
if (superclass != null) {
return superclass.findMethod(name);
}
return null; return null;
} }

View File

@ -45,6 +45,13 @@ public class Parser {
private Stmt classDeclaration() { private Stmt classDeclaration() {
Token name = consume(IDENTIFIER, "Expect class name."); Token name = consume(IDENTIFIER, "Expect class name.");
Expr.Variable superclass = null;
if (match(LESS)) {
consume(IDENTIFIER, "Expect superclass name.");
superclass = new Expr.Variable(previous());
}
consume(LEFT_BRACE, "Expect '{' before class body."); consume(LEFT_BRACE, "Expect '{' before class body.");
List<Stmt.Function> methods = new ArrayList<>(); List<Stmt.Function> methods = new ArrayList<>();
@ -54,7 +61,7 @@ public class Parser {
consume(RIGHT_BRACE, "Expect '}' after class body."); consume(RIGHT_BRACE, "Expect '}' after class body.");
return new Stmt.Class(name, methods); return new Stmt.Class(name, superclass, methods);
} }
private Stmt statement() { private Stmt statement() {
@ -210,8 +217,7 @@ public class Parser {
if (expr instanceof Expr.Variable) { if (expr instanceof Expr.Variable) {
Token name = ((Expr.Variable) expr).name; Token name = ((Expr.Variable) expr).name;
return new Expr.Assign(name, value); return new Expr.Assign(name, value);
} else if (expr instanceof Expr.Get) { } else if (expr instanceof Expr.Get get) {
Expr.Get get = (Expr.Get)expr;
return new Expr.Set(get.object, get.name, value); return new Expr.Set(get.object, get.name, value);
} }
@ -347,6 +353,14 @@ public class Parser {
return new Expr.Literal(previous().literal); return new Expr.Literal(previous().literal);
} }
if (match(SUPER)) {
Token keyword = previous();
consume(DOT, "Expect '.' after 'super'.");
Token method = consume(IDENTIFIER,
"Expect superclass method name.");
return new Expr.Super(keyword, method);
}
if (match(THIS)) return new Expr.This(previous()); if (match(THIS)) return new Expr.This(previous());
if (match(IDENTIFIER)) { if (match(IDENTIFIER)) {

View File

@ -23,7 +23,8 @@ public class Resolver implements Expr.Visitor<Void>, Stmt.Visitor<Void> {
private enum ClassType { private enum ClassType {
NONE, NONE,
CLASS CLASS,
SUBCLASS
} }
private ClassType currentClass = ClassType.NONE; private ClassType currentClass = ClassType.NONE;
@ -90,6 +91,20 @@ public class Resolver implements Expr.Visitor<Void>, Stmt.Visitor<Void> {
return null; return null;
} }
@Override
public Void visitSuperExpr(Expr.Super expr) {
if (currentClass == ClassType.NONE) {
Lox.error(expr.keyword,
"Can't use 'super' outside of a class.");
} else if (currentClass != ClassType.SUBCLASS) {
Lox.error(expr.keyword,
"Can't use 'super' in a class with no superclass.");
}
resolveLocal(expr, expr.keyword);
return null;
}
@Override @Override
public Void visitThisExpr(Expr.This expr) { public Void visitThisExpr(Expr.This expr) {
if (currentClass == ClassType.NONE) { if (currentClass == ClassType.NONE) {
@ -136,6 +151,22 @@ public class Resolver implements Expr.Visitor<Void>, Stmt.Visitor<Void> {
declare(stmt.name); declare(stmt.name);
define(stmt.name); define(stmt.name);
if (stmt.superclass != null &&
stmt.name.lexeme.equals(stmt.superclass.name.lexeme)) {
Lox.error(stmt.superclass.name,
"A class can't inherit from itself.");
}
if (stmt.superclass != null) {
currentClass = ClassType.SUBCLASS;
resolve(stmt.superclass);
}
if (stmt.superclass != null) {
beginScope();
scopes.peek().put("super", true);
}
beginScope(); beginScope();
scopes.peek().put("this", true); scopes.peek().put("this", true);
@ -149,6 +180,8 @@ public class Resolver implements Expr.Visitor<Void>, Stmt.Visitor<Void> {
endScope(); endScope();
if (stmt.superclass != null) endScope();
currentClass = enclosingClass; currentClass = enclosingClass;
return null; return null;
} }
@ -176,7 +209,7 @@ public class Resolver implements Expr.Visitor<Void>, Stmt.Visitor<Void> {
} }
private void beginScope() { private void beginScope() {
scopes.push(new HashMap<String, Boolean>()); scopes.push(new HashMap<>());
} }
private void endScope() { private void endScope() {

View File

@ -27,8 +27,9 @@ abstract class Stmt {
final List<Stmt> statements; final List<Stmt> statements;
} }
static class Class extends Stmt { static class Class extends Stmt {
Class(Token name, List<Stmt.Function> methods) { Class(Token name, Expr.Variable superclass, List<Stmt.Function> methods) {
this.name = name; this.name = name;
this.superclass = superclass;
this.methods = methods; this.methods = methods;
} }
@ -38,6 +39,7 @@ abstract class Stmt {
} }
final Token name; final Token name;
final Expr.Variable superclass;
final List<Stmt.Function> methods; final List<Stmt.Function> methods;
} }
static class Expression extends Stmt { static class Expression extends Stmt {

View File

@ -1,13 +1,13 @@
package tool; package tool;
import java.io.FileNotFoundException; import java.io.IOException;
import java.io.PrintWriter; import java.io.PrintWriter;
import java.io.UnsupportedEncodingException; import java.nio.charset.StandardCharsets;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
public class GenerateAst { public class GenerateAst {
public static void main(String[] args) throws FileNotFoundException, UnsupportedEncodingException { public static void main(String[] args) throws IOException {
if (args.length != 1) { if (args.length != 1) {
System.err.println("Usage: generate_ast <output directory>"); System.err.println("Usage: generate_ast <output directory>");
System.exit(64); System.exit(64);
@ -22,6 +22,7 @@ public class GenerateAst {
"Literal : Object value", "Literal : Object value",
"Logical : Expr left, Token operator, Expr right", "Logical : Expr left, Token operator, Expr right",
"Set : Expr object, Token name, Expr value", "Set : Expr object, Token name, Expr value",
"Super : Token keyword, Token method",
"This : Token keyword", "This : Token keyword",
"Unary : Token operator, Expr right", "Unary : Token operator, Expr right",
"Variable : Token name" "Variable : Token name"
@ -29,7 +30,8 @@ public class GenerateAst {
defineAst(outputDir, "Stmt", Arrays.asList( defineAst(outputDir, "Stmt", Arrays.asList(
"Block : List<Stmt> statements", "Block : List<Stmt> statements",
"Class : Token name, List<Stmt.Function> methods", "Class : Token name, Expr.Variable superclass," +
" List<Stmt.Function> methods",
"Expression : Expr expression", "Expression : Expr expression",
"Function : Token name, List<Token> params," + "Function : Token name, List<Token> params," +
" List<Stmt> body", " List<Stmt> body",
@ -43,9 +45,9 @@ public class GenerateAst {
} }
private static void defineAst( private static void defineAst(
String outputDir, String baseName, List<String> types) throws FileNotFoundException, UnsupportedEncodingException { String outputDir, String baseName, List<String> types) throws IOException {
String path = outputDir + "/" + baseName + ".java"; String path = outputDir + "/" + baseName + ".java";
PrintWriter writer = new PrintWriter(path, "UTF-8"); PrintWriter writer = new PrintWriter(path, StandardCharsets.UTF_8);
writer.println("package lox;"); writer.println("package lox;");
writer.println(); writer.println();