Skip to content

Implement more supernodes for incrementing/decrementing #221

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@
*.txt
*.yml
*.zip
*.log
*.iml
*.dist
.classpath
.factorypath
.idea
.project
.pydevproject
.settings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,16 @@
import trufflesom.interpreter.nodes.ExpressionNode;
import trufflesom.interpreter.nodes.FieldNode;
import trufflesom.interpreter.nodes.FieldNode.FieldReadNode;
import trufflesom.interpreter.nodes.FieldNode.UninitFieldIncNode;
import trufflesom.interpreter.nodes.FieldNodeFactory.FieldWriteNodeGen;
import trufflesom.interpreter.nodes.ReturnNonLocalNode;
import trufflesom.interpreter.nodes.ReturnNonLocalNode.CatchNonLocalReturnNode;
import trufflesom.interpreter.nodes.literals.BlockNode;
import trufflesom.interpreter.supernodes.IntIncrementNode;
import trufflesom.interpreter.supernodes.inc.IncExpWithValueNode;
import trufflesom.interpreter.supernodes.LocalVariableSquareNode;
import trufflesom.interpreter.supernodes.NonLocalVariableSquareNode;
import trufflesom.interpreter.supernodes.inc.UninitIncFieldWithExpNode;
import trufflesom.primitives.Primitives;
import trufflesom.primitives.arithmetic.AdditionPrim;
import trufflesom.vmobjects.SClass;
import trufflesom.vmobjects.SInvokable;
import trufflesom.vmobjects.SInvokable.SMethod;
Expand Down Expand Up @@ -405,6 +406,10 @@ public ExpressionNode getLocalWriteNode(final Variable variable,
final ExpressionNode valExpr, final long coord) {
int ctxLevel = getContextLevel(variable);

if (valExpr instanceof IncExpWithValueNode inc && inc.doesAccessVariable(variable)) {
return inc.createIncVarNode((Local) variable, ctxLevel);
}

if (valExpr instanceof LocalVariableSquareNode l) {
return variable.getReadSquareWriteNode(ctxLevel, coord, l.getLocal(), 0);
}
Expand Down Expand Up @@ -450,8 +455,9 @@ public FieldReadNode getObjectFieldRead(final SSymbol fieldName,
return null;
}

return new FieldReadNode(getSelfRead(coord),
holderGenc.getFieldIndex(fieldName)).initialize(coord);
byte fieldIndex = holderGenc.getFieldIndex(fieldName);
ExpressionNode selfNode = getSelfRead(coord);
return new FieldReadNode(selfNode, fieldIndex).initialize(coord);
}

public FieldNode getObjectFieldWrite(final SSymbol fieldName, final ExpressionNode exp,
Expand All @@ -460,11 +466,22 @@ public FieldNode getObjectFieldWrite(final SSymbol fieldName, final ExpressionNo
return null;
}

int fieldIndex = holderGenc.getFieldIndex(fieldName);
byte fieldIndex = holderGenc.getFieldIndex(fieldName);
ExpressionNode self = getSelfRead(coord);
if (exp instanceof IntIncrementNode
&& ((IntIncrementNode) exp).doesAccessField(fieldIndex)) {
return new UninitFieldIncNode(self, fieldIndex, coord);
if (exp instanceof IncExpWithValueNode incNode && incNode.doesAccessField(fieldIndex)) {
return incNode.createIncFieldNode(self, fieldIndex, coord);
}

if (exp instanceof AdditionPrim add) {
ExpressionNode rcvr = add.getReceiver();
ExpressionNode arg = add.getArgument();

if (rcvr instanceof FieldReadNode fr && fieldIndex == fr.getFieldIndex()) {
return new UninitIncFieldWithExpNode(self, arg, true, fieldIndex, coord);
}
if (arg instanceof FieldReadNode fr && fieldIndex == fr.getFieldIndex()) {
return new UninitIncFieldWithExpNode(self, rcvr, false, fieldIndex, coord);
}
}

return FieldWriteNodeGen.create(fieldIndex, self, exp).initialize(coord);
Expand Down
13 changes: 7 additions & 6 deletions src/trufflesom/src/trufflesom/compiler/ParserAst.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@
import trufflesom.interpreter.nodes.literals.GenericLiteralNode;
import trufflesom.interpreter.nodes.literals.IntegerLiteralNode;
import trufflesom.interpreter.nodes.literals.LiteralNode;
import trufflesom.interpreter.supernodes.IntIncrementNodeGen;
import trufflesom.interpreter.supernodes.LocalFieldStringEqualsNode;
import trufflesom.interpreter.supernodes.LocalVariableSquareNodeGen;
import trufflesom.interpreter.supernodes.NonLocalFieldStringEqualsNode;
import trufflesom.interpreter.supernodes.NonLocalVariableSquareNodeGen;
import trufflesom.interpreter.supernodes.StringEqualsNodeGen;
import trufflesom.interpreter.supernodes.inc.IncExpWithValueNodeGen;
import trufflesom.primitives.Primitives;
import trufflesom.vm.Globals;
import trufflesom.vm.NotYetImplementedException;
Expand Down Expand Up @@ -325,6 +325,12 @@ protected ExpressionNode binaryMessage(final MethodGenerationContext mgenc,
rcvr.getContextLevel(), rcvr.getLocal()).initialize(coordWithL);
}
}
} else if (msg == SymbolTable.symPlus && operand instanceof IntegerLiteralNode lit) {
long litValue = lit.executeLong(null);
return IncExpWithValueNodeGen.create(litValue, false, receiver).initialize(coordWithL);
} else if (msg == SymbolTable.symMinus && operand instanceof IntegerLiteralNode lit) {
long litValue = lit.executeLong(null);
return IncExpWithValueNodeGen.create(-litValue, true, receiver).initialize(coordWithL);
}

ExpressionNode inlined =
Expand All @@ -334,11 +340,6 @@ protected ExpressionNode binaryMessage(final MethodGenerationContext mgenc,
return inlined;
}

if (msg == SymbolTable.symPlus && operand instanceof IntegerLiteralNode lit) {
if (lit.executeLong(null) == 1) {
return IntIncrementNodeGen.create(receiver);
}
}
return MessageSendNode.create(msg, args, coordWithL);
}

Expand Down
62 changes: 58 additions & 4 deletions src/trufflesom/src/trufflesom/compiler/Variable.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
import static trufflesom.compiler.bc.BytecodeGenerator.emitPUSHLOCAL;
import static trufflesom.vm.SymbolTable.strBlockSelf;
import static trufflesom.vm.SymbolTable.strSelf;
import static trufflesom.vm.SymbolTable.symBlockSelf;
import static trufflesom.vm.SymbolTable.symSelf;
import static trufflesom.vm.SymbolTable.symbolFor;

import java.util.Objects;

Expand All @@ -23,16 +20,22 @@
import trufflesom.interpreter.nodes.ArgumentReadNode.NonLocalArgumentReadNode;
import trufflesom.interpreter.nodes.ArgumentReadNode.NonLocalArgumentWriteNode;
import trufflesom.interpreter.nodes.ExpressionNode;
import trufflesom.interpreter.nodes.LocalVariableNode.LocalVariableReadNode;
import trufflesom.interpreter.nodes.LocalVariableNodeFactory.LocalVariableReadNodeGen;
import trufflesom.interpreter.nodes.LocalVariableNodeFactory.LocalVariableWriteNodeGen;
import trufflesom.interpreter.nodes.NonLocalVariableNode.NonLocalVariableReadNode;
import trufflesom.interpreter.nodes.NonLocalVariableNodeFactory.NonLocalVariableReadNodeGen;
import trufflesom.interpreter.nodes.NonLocalVariableNodeFactory.NonLocalVariableWriteNodeGen;
import trufflesom.interpreter.supernodes.LocalVariableReadSquareWriteNodeGen;
import trufflesom.interpreter.supernodes.LocalVariableSquareNodeGen;
import trufflesom.interpreter.supernodes.NonLocalVariableReadSquareWriteNodeGen;
import trufflesom.interpreter.supernodes.NonLocalVariableSquareNodeGen;
import trufflesom.interpreter.supernodes.inc.IncLocalVarWithExpNodeGen;
import trufflesom.interpreter.supernodes.inc.IncLocalVarWithValueNodeGen;
import trufflesom.interpreter.supernodes.inc.IncNonLocalVarWithExpNodeGen;
import trufflesom.interpreter.supernodes.inc.IncNonLocalVarWithValueNodeGen;
import trufflesom.primitives.arithmetic.AdditionPrim;
import trufflesom.vm.NotYetImplementedException;
import trufflesom.vmobjects.SSymbol;


public abstract class Variable {
Expand Down Expand Up @@ -68,6 +71,8 @@ public abstract ExpressionNode getWriteNode(
public abstract ExpressionNode getReadSquareWriteNode(int writeContextLevel, long coord,
Local readLocal, int readContextLevel);

public abstract ExpressionNode getIncNode(int contextLevel, long incValue, long coord);

protected abstract void emitPop(BytecodeMethodGenContext mgenc);

protected abstract void emitPush(BytecodeMethodGenContext mgenc);
Expand Down Expand Up @@ -156,6 +161,12 @@ public ExpressionNode getReadSquareWriteNode(final int writeContextLevel, final
throw new NotYetImplementedException();
}

@Override
public ExpressionNode getIncNode(final int contextLevel, final long incValue,
final long coord) {
throw new NotYetImplementedException();
}

@Override
public void emitPop(final BytecodeMethodGenContext mgenc) {
emitPOPARGUMENT(mgenc, (byte) index, (byte) mgenc.getContextLevel(this));
Expand Down Expand Up @@ -215,6 +226,16 @@ public ExpressionNode getReadSquareWriteNode(final int writeContextLevel, final
return LocalVariableReadSquareWriteNodeGen.create(this, readLocal).initialize(coord);
}

@Override
public ExpressionNode getIncNode(final int contextLevel, final long incValue,
final long coord) {
if (contextLevel > 0) {
return IncNonLocalVarWithValueNodeGen.create(contextLevel, this, incValue)
.initialize(coord);
}
return IncLocalVarWithValueNodeGen.create(this, incValue).initialize(coord);
}

public final int getIndex() {
return slotIndex;
}
Expand All @@ -231,6 +252,31 @@ public Local splitToMergeIntoOuterScope(final int newSlotIndex) {

public ExpressionNode getWriteNode(final int contextLevel,
final ExpressionNode valueExpr, final long coordinate) {
if (valueExpr instanceof AdditionPrim add) {
ExpressionNode rcvr = add.getReceiver();
ExpressionNode arg = add.getArgument();

if (contextLevel > 0) {
if (rcvr instanceof NonLocalVariableReadNode nl && nl.getLocal() == this) {
return IncNonLocalVarWithExpNodeGen.create(contextLevel, this, arg)
.initialize(coord);
}

if (arg instanceof NonLocalVariableReadNode nl && nl.getLocal() == this) {
return IncNonLocalVarWithExpNodeGen.create(contextLevel, this, rcvr)
.initialize(coord);
}
} else {
if (rcvr instanceof LocalVariableReadNode l && l.getLocal() == this) {
return IncLocalVarWithExpNodeGen.create(this, arg).initialize(coord);
}

if (arg instanceof LocalVariableReadNode l && l.getLocal() == this) {
return IncLocalVarWithExpNodeGen.create(this, rcvr).initialize(coord);
}
}
}

if (contextLevel > 0) {
return NonLocalVariableWriteNodeGen.create(contextLevel, this, valueExpr)
.initialize(coordinate);
Expand Down Expand Up @@ -286,6 +332,14 @@ public ExpressionNode getReadSquareWriteNode(final int readContextLevel, final l
"There shouldn't be any language-level square nodes for internal slots. ");
}

@Override
public ExpressionNode getIncNode(final int contextLevel, final long incValue,
final long coord) {
throw new UnsupportedOperationException(
"There shouldn't be any language-level inc nodes for internal slots. "
+ "They are used directly by other nodes.");
}

@Override
public Internal split() {
return new Internal(name, coord, slotIndex);
Expand Down
88 changes: 0 additions & 88 deletions src/trufflesom/src/trufflesom/interpreter/nodes/FieldNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -246,94 +246,6 @@ public static ExpressionNode createForMethod(final int fieldIdx, final Argument
}
}

public static final class UninitFieldIncNode extends FieldNode {

@Child private ExpressionNode self;
private final int fieldIndex;

public UninitFieldIncNode(final ExpressionNode self, final int fieldIndex,
final long coord) {
this.self = self;
this.fieldIndex = fieldIndex;
this.sourceCoord = coord;
}

public int getFieldIndex() {
return fieldIndex;
}

@Override
public ExpressionNode getSelf() {
return self;
}

@Override
public Object doPreEvaluated(final VirtualFrame frame, final Object[] arguments) {
CompilerDirectives.transferToInterpreter();
throw new UnsupportedOperationException();
}

@Override
public Object executeGeneric(final VirtualFrame frame) {
CompilerDirectives.transferToInterpreterAndInvalidate();
SObject obj = (SObject) self.executeGeneric(frame);

Object val = obj.getField(fieldIndex);
if (!(val instanceof Long)) {
throw new NotYetImplementedException();
}

long longVal = 0;
try {
longVal = Math.addExact((Long) val, 1);
obj.setField(fieldIndex, longVal);
} catch (ArithmeticException e) {
throw new NotYetImplementedException();
}

IncrementLongFieldNode node = FieldAccessorNode.createIncrement(fieldIndex, obj);
IncFieldNode incNode = new IncFieldNode(self, node, sourceCoord);
replace(incNode);
node.notifyAsInserted();

return longVal;
}
}

private static final class IncFieldNode extends FieldNode {
@Child private ExpressionNode self;
@Child private IncrementLongFieldNode inc;

IncFieldNode(final ExpressionNode self, final IncrementLongFieldNode inc,
final long coord) {
initialize(coord);
this.self = self;
this.inc = inc;
}

@Override
public ExpressionNode getSelf() {
return self;
}

@Override
public Object doPreEvaluated(final VirtualFrame frame, final Object[] arguments) {
CompilerDirectives.transferToInterpreter();
throw new UnsupportedOperationException();
}

@Override
public Object executeGeneric(final VirtualFrame frame) {
return executeLong(frame);
}

@Override
public long executeLong(final VirtualFrame frame) {
SObject obj = (SObject) self.executeGeneric(frame);
return inc.increment(obj);
}
}

public static final class WriteAndReturnSelf extends ExpressionNode
implements PreevaluatedExpression {
@Child ExpressionNode write;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@
import com.oracle.truffle.api.nodes.LoopNode;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.nodes.RootNode;
import com.oracle.truffle.api.profiles.ValueProfile;

import trufflesom.bdt.inlining.ScopeAdaptationVisitor;
import trufflesom.bdt.inlining.nodes.ScopeReference;
Expand Down Expand Up @@ -866,7 +865,7 @@ public Object executeGeneric(final VirtualFrame frame) {
break;
}

((IncrementLongFieldNode) node).increment(obj);
((IncrementLongFieldNode) node).increment(obj, 1);
bytecodeIndex += Bytecodes.LEN_TWO_ARGS;
break;
}
Expand All @@ -890,7 +889,7 @@ public Object executeGeneric(final VirtualFrame frame) {
break;
}

long value = ((IncrementLongFieldNode) node).increment(obj);
long value = ((IncrementLongFieldNode) node).increment(obj, 1);
stackPointer += 1;
stack[stackPointer] = value;
bytecodeIndex += Bytecodes.LEN_TWO_ARGS;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,24 +432,24 @@ private boolean hasExpectedLayout(final SObject obj)
return layout == obj.getObjectLayout();
}

public long increment(final SObject obj) {
public long increment(final SObject obj, final long incValue) {
try {
if (hasExpectedLayout(obj)) {
return storage.increment(obj);
return storage.increment(obj, incValue);
} else {
ensureNext(obj);
return nextInCache.increment(obj);
return nextInCache.increment(obj, incValue);
}
} catch (InvalidAssumptionException e) {
CompilerDirectives.transferToInterpreterAndInvalidate();
ensureNext(obj);
return dropAndIncrementNext(obj);
return dropAndIncrementNext(obj, incValue);
}
}

@InliningCutoff
private long dropAndIncrementNext(final SObject obj) {
return replace(SOMNode.unwrapIfNeeded(nextInCache)).increment(obj);
private long dropAndIncrementNext(final SObject obj, final long incValue) {
return replace(SOMNode.unwrapIfNeeded(nextInCache)).increment(obj, incValue);
}

@InliningCutoff
Expand Down
Loading
Loading