Classes chapter

This commit is contained in:
Mariano Riefolo 2024-07-04 14:45:34 +02:00
parent c7e62694f8
commit e1f123e96f
10 changed files with 316 additions and 11 deletions

View File

@ -7,23 +7,28 @@ public class AstPrinter implements Expr.Visitor<String> {
@Override @Override
public String visitAssignExpr(Expr.Assign expr) { public String visitAssignExpr(Expr.Assign expr) {
return parethesize(expr.name.lexeme, expr.value); return parenthesize(expr.name.lexeme, expr.value);
} }
@Override @Override
public String visitBinaryExpr(Expr.Binary expr) { public String visitBinaryExpr(Expr.Binary expr) {
return parethesize(expr.operator.lexeme, return parenthesize(expr.operator.lexeme,
expr.left, expr.right); expr.left, expr.right);
} }
@Override @Override
public String visitCallExpr(Expr.Call expr) { public String visitCallExpr(Expr.Call expr) {
return parethesize(expr.paren.lexeme, expr.callee); return parenthesize(expr.paren.lexeme, expr.callee);
}
@Override
public String visitGetExpr(Expr.Get expr) {
return parenthesize(expr.name.lexeme, expr.object);
} }
@Override @Override
public String visitGroupingExpr(Expr.Grouping expr) { public String visitGroupingExpr(Expr.Grouping expr) {
return parethesize("group", expr.expression); return parenthesize("group", expr.expression);
} }
@Override @Override
@ -34,20 +39,31 @@ public class AstPrinter implements Expr.Visitor<String> {
@Override @Override
public String visitLogicalExpr(Expr.Logical expr) { public String visitLogicalExpr(Expr.Logical expr) {
return parethesize(expr.operator.lexeme, expr.left, expr.right); return parenthesize(expr.operator.lexeme, expr.left, expr.right);
} }
@Override
public String visitSetExpr(Expr.Set expr) {
return parenthesize(expr.name.lexeme, expr.object, expr.value);
}
@Override
public String visitThisExpr(Expr.This expr) {
return parenthesize(expr.keyword.lexeme);
}
// https://craftinginterpreters.com/classes.html#invalid-uses-of-this
@Override @Override
public String visitUnaryExpr(Expr.Unary expr) { public String visitUnaryExpr(Expr.Unary expr) {
return parethesize(expr.operator.lexeme, expr.right); return parenthesize(expr.operator.lexeme, expr.right);
} }
@Override @Override
public String visitVariableExpr(Expr.Variable expr) { public String visitVariableExpr(Expr.Variable expr) {
return parethesize(expr.name.lexeme); return parenthesize(expr.name.lexeme);
} }
private String parethesize(String name, Expr... exprs) { private String parenthesize(String name, Expr... exprs) {
StringBuilder builder = new StringBuilder(); StringBuilder builder = new StringBuilder();
builder.append("(").append(name); builder.append("(").append(name);

View File

@ -7,9 +7,12 @@ abstract class Expr {
R visitAssignExpr(Assign expr); R visitAssignExpr(Assign expr);
R visitBinaryExpr(Binary expr); R visitBinaryExpr(Binary expr);
R visitCallExpr(Call expr); R visitCallExpr(Call expr);
R visitGetExpr(Get expr);
R visitGroupingExpr(Grouping expr); R visitGroupingExpr(Grouping expr);
R visitLiteralExpr(Literal expr); R visitLiteralExpr(Literal expr);
R visitLogicalExpr(Logical expr); R visitLogicalExpr(Logical expr);
R visitSetExpr(Set expr);
R visitThisExpr(This expr);
R visitUnaryExpr(Unary expr); R visitUnaryExpr(Unary expr);
R visitVariableExpr(Variable expr); R visitVariableExpr(Variable expr);
} }
@ -59,6 +62,20 @@ abstract class Expr {
final Token paren; final Token paren;
final List<Expr> arguments; final List<Expr> arguments;
} }
static class Get extends Expr {
Get(Expr object, Token name) {
this.object = object;
this.name = name;
}
@Override
<R> R accept(Visitor<R> visitor) {
return visitor.visitGetExpr(this);
}
final Expr object;
final Token name;
}
static class Grouping extends Expr { static class Grouping extends Expr {
Grouping(Expr expression) { Grouping(Expr expression) {
this.expression = expression; this.expression = expression;
@ -99,6 +116,34 @@ abstract class Expr {
final Token operator; final Token operator;
final Expr right; final Expr right;
} }
static class Set extends Expr {
Set(Expr object, Token name, Expr value) {
this.object = object;
this.name = name;
this.value = value;
}
@Override
<R> R accept(Visitor<R> visitor) {
return visitor.visitSetExpr(this);
}
final Expr object;
final Token name;
final Expr value;
}
static class This extends Expr {
This(Token keyword) {
this.keyword = keyword;
}
@Override
<R> R accept(Visitor<R> visitor) {
return visitor.visitThisExpr(this);
}
final Token keyword;
}
static class Unary extends Expr { static class Unary extends Expr {
Unary(Token operator, Expr right) { Unary(Token operator, Expr right) {
this.operator = operator; this.operator = operator;

View File

@ -124,6 +124,17 @@ public class Interpreter implements Expr.Visitor<Object>,
return function.call(this, arguments); return function.call(this, arguments);
} }
@Override
public Object visitGetExpr(Expr.Get expr) {
Object object = evaluate(expr.object);
if (object instanceof LoxInstance) {
return ((LoxInstance) object).get(expr.name);
}
throw new RuntimeError(expr.name,
"Only instances have properties.");
}
@Override @Override
public Object visitGroupingExpr(Expr.Grouping expr) { public Object visitGroupingExpr(Expr.Grouping expr) {
return evaluate(expr.expression); return evaluate(expr.expression);
@ -147,6 +158,25 @@ public class Interpreter implements Expr.Visitor<Object>,
return evaluate(expr.right); return evaluate(expr.right);
} }
@Override
public Object visitSetExpr(Expr.Set expr) {
Object object = evaluate(expr.object);
if (!(object instanceof LoxInstance)) {
throw new RuntimeError(expr.name,
"Only instances have fields.");
}
Object value = evaluate(expr.value);
((LoxInstance)object).set(expr.name, value);
return value;
}
@Override
public Object visitThisExpr(Expr.This expr) {
return lookUpVariable(expr.keyword, expr);
}
@Override @Override
public Object visitUnaryExpr(Expr.Unary expr) { public Object visitUnaryExpr(Expr.Unary expr) {
Object right = evaluate(expr.right); Object right = evaluate(expr.right);
@ -246,6 +276,22 @@ public class Interpreter implements Expr.Visitor<Object>,
return null; return null;
} }
@Override
public Void visitClassStmt(Stmt.Class stmt) {
environment.define(stmt.name.lexeme, null);
Map<String, LoxFunction> methods = new HashMap<>();
for (Stmt.Function method : stmt.methods) {
LoxFunction function = new LoxFunction(method, environment,
method.name.lexeme.equals("init"));
methods.put(method.name.lexeme, function);
}
LoxClass klass = new LoxClass(stmt.name.lexeme, methods);
environment.assign(stmt.name, klass);
return null;
}
@Override @Override
public Void visitExpressionStmt(Stmt.Expression stmt) { public Void visitExpressionStmt(Stmt.Expression stmt) {
evaluate(stmt.expression); evaluate(stmt.expression);
@ -254,7 +300,7 @@ public class Interpreter implements Expr.Visitor<Object>,
@Override @Override
public Void visitFunctionStmt(Stmt.Function stmt) { public Void visitFunctionStmt(Stmt.Function stmt) {
LoxFunction function = new LoxFunction(stmt, environment); LoxFunction function = new LoxFunction(stmt, environment, false);
environment.define(stmt.name.lexeme, function); environment.define(stmt.name.lexeme, function);
return null; return null;
} }

44
src/lox/LoxClass.java Normal file
View File

@ -0,0 +1,44 @@
package lox;
import java.util.List;
import java.util.Map;
public class LoxClass implements LoxCallable {
final String name;
private final Map<String, LoxFunction> methods;
public LoxClass(String name, Map<String, LoxFunction> methods) {
this.name = name;
this.methods = methods;
}
LoxFunction findMethod(String name) {
if (methods.containsKey(name)) {
return methods.get(name);
}
return null;
}
@Override
public String toString() {
return name;
}
@Override
public int arity() {
LoxFunction initializer = findMethod("init");
if (initializer == null) return 0;
return initializer.arity();
}
@Override
public Object call(Interpreter interpreter, List<Object> arguments) {
LoxInstance instance = new LoxInstance(this);
LoxFunction initializer = findMethod("init");
if (initializer != null) {
initializer.bind(instance).call(interpreter, arguments);
}
return instance;
}
}

View File

@ -6,9 +6,19 @@ public class LoxFunction implements LoxCallable {
private final Stmt.Function declaration; private final Stmt.Function declaration;
private final Environment closure; private final Environment closure;
LoxFunction(Stmt.Function declaration, Environment closure) { private final boolean isInitializer;
LoxFunction(Stmt.Function declaration, Environment closure, boolean isInitializer) {
this.closure = closure; this.closure = closure;
this.declaration = declaration; this.declaration = declaration;
this.isInitializer = isInitializer;
}
LoxFunction bind(LoxInstance instance) {
Environment environment = new Environment(closure);
environment.define("this", instance);
return new LoxFunction(declaration, environment,
isInitializer);
} }
@Override @Override
@ -27,8 +37,12 @@ public class LoxFunction implements LoxCallable {
try { try {
interpreter.executeBlock(declaration.body, environment); interpreter.executeBlock(declaration.body, environment);
} catch (Return returnValue) { } catch (Return returnValue) {
if (isInitializer) return closure.getAt(0, "this");
return returnValue.value; return returnValue.value;
} }
if (isInitializer) return closure.getAt(0, "this");
return null; return null;
} }

34
src/lox/LoxInstance.java Normal file
View File

@ -0,0 +1,34 @@
package lox;
import java.util.HashMap;
import java.util.Map;
public class LoxInstance {
private LoxClass klass;
private final Map<String, Object> fields = new HashMap<>();
LoxInstance(LoxClass klass) {
this.klass = klass;
}
Object get(Token name) {
if (fields.containsKey(name.lexeme)) {
return fields.get(name.lexeme);
}
LoxFunction method = klass.findMethod(name.lexeme);
if (method != null) return method.bind(this);
throw new RuntimeError(name,
"Undefined property '" + name.lexeme + "'.");
}
void set(Token name, Object value) {
fields.put(name.lexeme, value);
}
@Override
public String toString() {
return klass.name + " instance";
}
}

View File

@ -32,6 +32,7 @@ public class Parser {
private Stmt declaration() { private Stmt declaration() {
try { try {
if (match(CLASS)) return classDeclaration();
if (match(FUN)) return function("function"); if (match(FUN)) return function("function");
if (match(VAR)) return varDeclaration(); if (match(VAR)) return varDeclaration();
@ -42,6 +43,20 @@ public class Parser {
} }
} }
private Stmt classDeclaration() {
Token name = consume(IDENTIFIER, "Expect class name.");
consume(LEFT_BRACE, "Expect '{' before class body.");
List<Stmt.Function> methods = new ArrayList<>();
while (!check(RIGHT_BRACE) && !isAtEnd()) {
methods.add(function("method"));
}
consume(RIGHT_BRACE, "Expect '}' after class body.");
return new Stmt.Class(name, methods);
}
private Stmt statement() { private Stmt statement() {
if (match(FOR)) return forStatement(); if (match(FOR)) return forStatement();
if (match(IF)) return ifStatement(); if (match(IF)) return ifStatement();
@ -195,6 +210,9 @@ public class Parser {
if (expr instanceof Expr.Variable) { if (expr instanceof Expr.Variable) {
Token name = ((Expr.Variable) expr).name; Token name = ((Expr.Variable) expr).name;
return new Expr.Assign(name, value); return new Expr.Assign(name, value);
} else if (expr instanceof Expr.Get) {
Expr.Get get = (Expr.Get)expr;
return new Expr.Set(get.object, get.name, value);
} }
throw error(equals, "Invalid assignment target."); throw error(equals, "Invalid assignment target.");
@ -308,6 +326,10 @@ public class Parser {
while (true) { while (true) {
if (match(LEFT_PAREN)) { if (match(LEFT_PAREN)) {
expr = finishCall(expr); expr = finishCall(expr);
} else if (match(DOT)) {
Token name = consume(IDENTIFIER,
"Expect property name after '.'.");
expr = new Expr.Get(expr, name);
} else { } else {
break; break;
} }
@ -325,6 +347,8 @@ public class Parser {
return new Expr.Literal(previous().literal); return new Expr.Literal(previous().literal);
} }
if (match(THIS)) return new Expr.This(previous());
if (match(IDENTIFIER)) { if (match(IDENTIFIER)) {
return new Expr.Variable(previous()); return new Expr.Variable(previous());
} }

View File

@ -16,9 +16,18 @@ public class Resolver implements Expr.Visitor<Void>, Stmt.Visitor<Void> {
private enum FunctionType { private enum FunctionType {
NONE, NONE,
FUNCTION FUNCTION,
INITIALIZER,
METHOD
} }
private enum ClassType {
NONE,
CLASS
}
private ClassType currentClass = ClassType.NONE;
void resolve(List<Stmt> statements) { void resolve(List<Stmt> statements) {
for (Stmt statement : statements) { for (Stmt statement : statements) {
resolve(statement); resolve(statement);
@ -50,6 +59,12 @@ public class Resolver implements Expr.Visitor<Void>, Stmt.Visitor<Void> {
return null; return null;
} }
@Override
public Void visitGetExpr(Expr.Get expr) {
resolve(expr.object);
return null;
}
@Override @Override
public Void visitGroupingExpr(Expr.Grouping expr) { public Void visitGroupingExpr(Expr.Grouping expr) {
resolve(expr.expression); resolve(expr.expression);
@ -68,6 +83,25 @@ public class Resolver implements Expr.Visitor<Void>, Stmt.Visitor<Void> {
return null; return null;
} }
@Override
public Void visitSetExpr(Expr.Set expr) {
resolve(expr.value);
resolve(expr.object);
return null;
}
@Override
public Void visitThisExpr(Expr.This expr) {
if (currentClass == ClassType.NONE) {
Lox.error(expr.keyword,
"Can't use 'this' outside of a class.");
return null;
}
resolveLocal(expr, expr.keyword);
return null;
}
@Override @Override
public Void visitUnaryExpr(Expr.Unary expr) { public Void visitUnaryExpr(Expr.Unary expr) {
resolve(expr.right); resolve(expr.right);
@ -94,6 +128,31 @@ public class Resolver implements Expr.Visitor<Void>, Stmt.Visitor<Void> {
return null; return null;
} }
@Override
public Void visitClassStmt(Stmt.Class stmt) {
ClassType enclosingClass = currentClass;
currentClass = ClassType.CLASS;
declare(stmt.name);
define(stmt.name);
beginScope();
scopes.peek().put("this", true);
for (Stmt.Function method : stmt.methods) {
FunctionType declaration = FunctionType.METHOD;
if (method.name.lexeme.equals("init")) {
declaration = FunctionType.INITIALIZER;
}
resolveFunction(method, declaration);
}
endScope();
currentClass = enclosingClass;
return null;
}
private void resolve(Stmt stmt) { private void resolve(Stmt stmt) {
stmt.accept(this); stmt.accept(this);
} }
@ -185,6 +244,10 @@ public class Resolver implements Expr.Visitor<Void>, Stmt.Visitor<Void> {
} }
if (stmt.value != null) { if (stmt.value != null) {
if (currentFunction == FunctionType.INITIALIZER) {
Lox.error(stmt.keyword,
"Can't return a value from an initializer.");
}
resolve(stmt.value); resolve(stmt.value);
} }

View File

@ -5,6 +5,7 @@ import java.util.List;
abstract class Stmt { abstract class Stmt {
interface Visitor<R> { interface Visitor<R> {
R visitBlockStmt(Block stmt); R visitBlockStmt(Block stmt);
R visitClassStmt(Class stmt);
R visitExpressionStmt(Expression stmt); R visitExpressionStmt(Expression stmt);
R visitFunctionStmt(Function stmt); R visitFunctionStmt(Function stmt);
R visitIfStmt(If stmt); R visitIfStmt(If stmt);
@ -25,6 +26,20 @@ abstract class Stmt {
final List<Stmt> statements; final List<Stmt> statements;
} }
static class Class extends Stmt {
Class(Token name, List<Stmt.Function> methods) {
this.name = name;
this.methods = methods;
}
@Override
<R> R accept(Visitor<R> visitor) {
return visitor.visitClassStmt(this);
}
final Token name;
final List<Stmt.Function> methods;
}
static class Expression extends Stmt { static class Expression extends Stmt {
Expression(Expr expression) { Expression(Expr expression) {
this.expression = expression; this.expression = expression;

View File

@ -17,15 +17,19 @@ public class GenerateAst {
"Assign : Token name, Expr value", "Assign : Token name, Expr value",
"Binary : Expr left, Token operator, Expr right", "Binary : Expr left, Token operator, Expr right",
"Call : Expr callee, Token paren, List<Expr> arguments", "Call : Expr callee, Token paren, List<Expr> arguments",
"Get : Expr object, Token name",
"Grouping : Expr expression", "Grouping : Expr expression",
"Literal : Object value", "Literal : Object value",
"Logical : Expr left, Token operator, Expr right", "Logical : Expr left, Token operator, Expr right",
"Set : Expr object, Token name, Expr value",
"This : Token keyword",
"Unary : Token operator, Expr right", "Unary : Token operator, Expr right",
"Variable : Token name" "Variable : Token name"
)); ));
defineAst(outputDir, "Stmt", Arrays.asList( defineAst(outputDir, "Stmt", Arrays.asList(
"Block : List<Stmt> statements", "Block : List<Stmt> statements",
"Class : Token name, List<Stmt.Function> methods",
"Expression : Expr expression", "Expression : Expr expression",
"Function : Token name, List<Token> params," + "Function : Token name, List<Token> params," +
" List<Stmt> body", " List<Stmt> body",