diff --git a/src/trufflesom/src/trufflesom/compiler/MethodGenerationContext.java b/src/trufflesom/src/trufflesom/compiler/MethodGenerationContext.java index eafbf5c0c..6eadcea58 100644 --- a/src/trufflesom/src/trufflesom/compiler/MethodGenerationContext.java +++ b/src/trufflesom/src/trufflesom/compiler/MethodGenerationContext.java @@ -55,7 +55,7 @@ import trufflesom.interpreter.nodes.ReturnNonLocalNode; import trufflesom.interpreter.nodes.ReturnNonLocalNode.CatchNonLocalReturnNode; import trufflesom.interpreter.nodes.literals.BlockNode; -import trufflesom.interpreter.nodes.specialized.IntIncrementNode; +import trufflesom.interpreter.supernodes.IntIncrementNode; import trufflesom.primitives.Primitives; import trufflesom.vmobjects.SClass; import trufflesom.vmobjects.SInvokable; diff --git a/src/trufflesom/src/trufflesom/compiler/ParserAst.java b/src/trufflesom/src/trufflesom/compiler/ParserAst.java index 6e503e0d8..f9cd9590a 100644 --- a/src/trufflesom/src/trufflesom/compiler/ParserAst.java +++ b/src/trufflesom/src/trufflesom/compiler/ParserAst.java @@ -27,8 +27,11 @@ import trufflesom.bdt.basic.ProgramDefinitionError; import trufflesom.bdt.inlining.InlinableNodes; import trufflesom.bdt.tools.structure.StructuralProbe; +import trufflesom.interpreter.nodes.ArgumentReadNode.LocalArgumentReadNode; +import trufflesom.interpreter.nodes.ArgumentReadNode.NonLocalArgumentReadNode; import trufflesom.interpreter.nodes.ExpressionNode; import trufflesom.interpreter.nodes.FieldNode; +import trufflesom.interpreter.nodes.FieldNode.FieldReadNode; import trufflesom.interpreter.nodes.GlobalNode; import trufflesom.interpreter.nodes.MessageSendNode; import trufflesom.interpreter.nodes.SequenceNode; @@ -38,9 +41,14 @@ import trufflesom.interpreter.nodes.literals.GenericLiteralNode; import trufflesom.interpreter.nodes.literals.IntegerLiteralNode; import trufflesom.interpreter.nodes.literals.LiteralNode; -import trufflesom.interpreter.nodes.specialized.IntIncrementNodeGen; +import trufflesom.interpreter.supernodes.IntIncrementNodeGen; +import trufflesom.interpreter.supernodes.LocalFieldStringEqualsNode; +import trufflesom.interpreter.supernodes.NonLocalFieldStringEqualsNode; +import trufflesom.interpreter.supernodes.StringEqualsNodeGen; import trufflesom.primitives.Primitives; import trufflesom.vm.Globals; +import trufflesom.vm.NotYetImplementedException; +import trufflesom.vm.SymbolTable; import trufflesom.vmobjects.SArray; import trufflesom.vmobjects.SClass; import trufflesom.vmobjects.SInvokable; @@ -255,6 +263,50 @@ protected ExpressionNode binaryMessage(final MethodGenerationContext mgenc, mgenc.getHolder().getSuperClass(), msg, args, coordWithL); } + String binSelector = msg.getString(); + + if (binSelector.equals("=")) { + if (operand instanceof GenericLiteralNode) { + Object literal = operand.executeGeneric(null); + if (literal instanceof String s) { + if (receiver instanceof FieldReadNode fieldRead) { + ExpressionNode self = fieldRead.getSelf(); + if (self instanceof LocalArgumentReadNode localSelf) { + return new LocalFieldStringEqualsNode(fieldRead.getFieldIndex(), + localSelf.getArg(), s).initialize(coordWithL); + } else if (self instanceof NonLocalArgumentReadNode arg) { + return new NonLocalFieldStringEqualsNode(fieldRead.getFieldIndex(), arg.getArg(), + arg.getContextLevel(), s).initialize(coordWithL); + } else { + throw new NotYetImplementedException(); + } + } + + return StringEqualsNodeGen.create(s, receiver).initialize(coordWithL); + } + } + + if (receiver instanceof GenericLiteralNode) { + Object literal = receiver.executeGeneric(null); + if (literal instanceof String s) { + if (operand instanceof FieldReadNode fieldRead) { + ExpressionNode self = fieldRead.getSelf(); + if (self instanceof LocalArgumentReadNode localSelf) { + return new LocalFieldStringEqualsNode(fieldRead.getFieldIndex(), + localSelf.getArg(), s).initialize(coordWithL); + } else if (self instanceof NonLocalArgumentReadNode arg) { + return new NonLocalFieldStringEqualsNode(fieldRead.getFieldIndex(), arg.getArg(), + arg.getContextLevel(), s).initialize(coordWithL); + } else { + throw new NotYetImplementedException(); + } + } + + return StringEqualsNodeGen.create(s, operand).initialize(coordWithL); + } + } + } + ExpressionNode inlined = inlinableNodes.inline(msg, args, mgenc, coordWithL); if (inlined != null) { @@ -262,8 +314,7 @@ protected ExpressionNode binaryMessage(final MethodGenerationContext mgenc, return inlined; } - if (msg.getString().equals("+") && operand instanceof IntegerLiteralNode) { - IntegerLiteralNode lit = (IntegerLiteralNode) operand; + if (msg == SymbolTable.symPlus && operand instanceof IntegerLiteralNode lit) { if (lit.executeLong(null) == 1) { return IntIncrementNodeGen.create(receiver); } diff --git a/src/trufflesom/src/trufflesom/compiler/Variable.java b/src/trufflesom/src/trufflesom/compiler/Variable.java index 2029f5993..5826a4f57 100644 --- a/src/trufflesom/src/trufflesom/compiler/Variable.java +++ b/src/trufflesom/src/trufflesom/compiler/Variable.java @@ -1,6 +1,5 @@ package trufflesom.compiler; -import static com.oracle.truffle.api.CompilerDirectives.transferToInterpreterAndInvalidate; import static trufflesom.compiler.bc.BytecodeGenerator.emitPOPARGUMENT; import static trufflesom.compiler.bc.BytecodeGenerator.emitPOPLOCAL; import static trufflesom.compiler.bc.BytecodeGenerator.emitPUSHARGUMENT; @@ -116,8 +115,6 @@ public Local splitToMergeIntoOuterScope(final int newSlotIndex) { @Override public ExpressionNode getReadNode(final int contextLevel, final long coordinate) { - transferToInterpreterAndInvalidate(); - if (contextLevel == 0) { return new LocalArgumentReadNode(this).initialize(coordinate); } else { @@ -128,8 +125,6 @@ public ExpressionNode getReadNode(final int contextLevel, final long coordinate) @Override public ExpressionNode getWriteNode(final int contextLevel, final ExpressionNode valueExpr, final long coordinate) { - transferToInterpreterAndInvalidate(); - if (contextLevel == 0) { return new LocalArgumentWriteNode(this, valueExpr).initialize(coordinate); } else { @@ -172,7 +167,6 @@ public void init(final FrameDescriptor desc) { @Override public ExpressionNode getReadNode(final int contextLevel, final long coordinate) { - transferToInterpreterAndInvalidate(); if (contextLevel > 0) { return NonLocalVariableReadNodeGen.create(contextLevel, this).initialize(coordinate); } @@ -196,7 +190,6 @@ public Local splitToMergeIntoOuterScope(final int newSlotIndex) { @Override public ExpressionNode getWriteNode(final int contextLevel, final ExpressionNode valueExpr, final long coordinate) { - transferToInterpreterAndInvalidate(); if (contextLevel > 0) { return NonLocalVariableWriteNodeGen.create(contextLevel, this, valueExpr) .initialize(coordinate); diff --git a/src/trufflesom/src/trufflesom/interpreter/nodes/ArgumentReadNode.java b/src/trufflesom/src/trufflesom/interpreter/nodes/ArgumentReadNode.java index 77ffa8704..140cfb63f 100644 --- a/src/trufflesom/src/trufflesom/interpreter/nodes/ArgumentReadNode.java +++ b/src/trufflesom/src/trufflesom/interpreter/nodes/ArgumentReadNode.java @@ -39,6 +39,10 @@ public void replaceAfterScopeChange(final ScopeAdaptationVisitor inliner) { inliner.updateRead(arg, this, 0); } + public Argument getArg() { + return arg; + } + @Override public SSymbol getInvocationIdentifier() { return arg.name; @@ -118,6 +122,10 @@ public void replaceAfterScopeChange(final ScopeAdaptationVisitor inliner) { inliner.updateRead(arg, this, contextLevel); } + public Argument getArg() { + return arg; + } + @Override public SSymbol getInvocationIdentifier() { return arg.name; diff --git a/src/trufflesom/src/trufflesom/interpreter/nodes/specialized/IntIncrementNode.java b/src/trufflesom/src/trufflesom/interpreter/supernodes/IntIncrementNode.java similarity index 95% rename from src/trufflesom/src/trufflesom/interpreter/nodes/specialized/IntIncrementNode.java rename to src/trufflesom/src/trufflesom/interpreter/supernodes/IntIncrementNode.java index 88f1335a3..ff0cc2163 100644 --- a/src/trufflesom/src/trufflesom/interpreter/nodes/specialized/IntIncrementNode.java +++ b/src/trufflesom/src/trufflesom/interpreter/supernodes/IntIncrementNode.java @@ -1,4 +1,4 @@ -package trufflesom.interpreter.nodes.specialized; +package trufflesom.interpreter.supernodes; import com.oracle.truffle.api.dsl.NodeChild; import com.oracle.truffle.api.dsl.Specialization; diff --git a/src/trufflesom/src/trufflesom/interpreter/supernodes/LocalFieldStringEqualsNode.java b/src/trufflesom/src/trufflesom/interpreter/supernodes/LocalFieldStringEqualsNode.java new file mode 100644 index 000000000..367acf7dc --- /dev/null +++ b/src/trufflesom/src/trufflesom/interpreter/supernodes/LocalFieldStringEqualsNode.java @@ -0,0 +1,167 @@ +package trufflesom.interpreter.supernodes; + +import com.oracle.truffle.api.CompilerDirectives; +import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; +import com.oracle.truffle.api.frame.VirtualFrame; +import com.oracle.truffle.api.nodes.Node; +import com.oracle.truffle.api.nodes.UnexpectedResultException; + +import trufflesom.bdt.inlining.ScopeAdaptationVisitor; +import trufflesom.bdt.inlining.ScopeAdaptationVisitor.ScopeElement; +import trufflesom.compiler.Variable.Argument; +import trufflesom.interpreter.bc.RespecializeException; +import trufflesom.interpreter.nodes.ArgumentReadNode.LocalArgumentReadNode; +import trufflesom.interpreter.nodes.ExpressionNode; +import trufflesom.interpreter.nodes.FieldNode.FieldReadNode; +import trufflesom.interpreter.nodes.GenericMessageSendNode; +import trufflesom.interpreter.nodes.MessageSendNode; +import trufflesom.interpreter.nodes.bc.BytecodeLoopNode; +import trufflesom.interpreter.nodes.literals.GenericLiteralNode; +import trufflesom.interpreter.objectstorage.FieldAccessorNode; +import trufflesom.interpreter.objectstorage.FieldAccessorNode.AbstractReadFieldNode; +import trufflesom.interpreter.objectstorage.ObjectLayout; +import trufflesom.interpreter.objectstorage.StorageLocation; +import trufflesom.vm.SymbolTable; +import trufflesom.vm.VmSettings; +import trufflesom.vm.constants.Nil; +import trufflesom.vmobjects.SObject; + + +public final class LocalFieldStringEqualsNode extends ExpressionNode { + + private final int fieldIdx; + private final String value; + protected final Argument arg; + + @Child private AbstractReadFieldNode readFieldNode; + + @CompilationFinal private int state; + + public LocalFieldStringEqualsNode(final int fieldIdx, final Argument arg, + final String value) { + this.fieldIdx = fieldIdx; + this.arg = arg; + this.value = value; + + this.state = 0; + } + + @Override + public Object executeGeneric(final VirtualFrame frame) { + try { + SObject rcvr = (SObject) frame.getArguments()[0]; + return executeEvaluated(frame, rcvr); + } catch (UnexpectedResultException e) { + return e.getResult(); + } + } + + @Override + public Object doPreEvaluated(final VirtualFrame frame, final Object[] args) { + try { + return executeEvaluated(frame, (SObject) args[0]); + } catch (UnexpectedResultException e) { + return e.getResult(); + } + } + + public boolean executeEvaluated(final VirtualFrame frame, final SObject rcvr) + throws UnexpectedResultException { + int currentState = state; + + if (state == 0) { + // uninitialized + CompilerDirectives.transferToInterpreterAndInvalidate(); + final ObjectLayout layout = rcvr.getObjectLayout(); + StorageLocation location = layout.getStorageLocation(fieldIdx); + + readFieldNode = + insert(location.getReadNode(fieldIdx, layout, + FieldAccessorNode.createRead(fieldIdx))); + } + + Object result = readFieldNode.read(rcvr); + + if ((state & 0b1) != 0) { + // we saw a string before + if (result instanceof String) { + return ((String) result).equals(value); + } + } + + if ((state & 0b10) != 0) { + // we saw a nil before + if (result == Nil.nilObject) { + return false; + } + } + + CompilerDirectives.transferToInterpreterAndInvalidate(); + return specialize(frame, result, currentState); + } + + @Override + public boolean executeBoolean(final VirtualFrame frame) throws UnexpectedResultException { + SObject rcvr = (SObject) frame.getArguments()[0]; + + return executeEvaluated(frame, rcvr); + } + + private boolean specialize(final VirtualFrame frame, final Object result, + final int currentState) throws UnexpectedResultException { + if (result instanceof String) { + state = currentState | 0b1; + return value.equals(result); + } + + if (result == Nil.nilObject) { + state = currentState | 0b10; + return false; + } + + Object sendResult = + makeGenericSend(result).doPreEvaluated(frame, new Object[] {result, value}); + if (sendResult instanceof Boolean) { + return (Boolean) sendResult; + } + throw new UnexpectedResultException(sendResult); + } + + public GenericMessageSendNode makeGenericSend( + @SuppressWarnings("unused") final Object receiver) { + GenericMessageSendNode send = + MessageSendNode.createGeneric(SymbolTable.symbolFor("="), + new ExpressionNode[] {new FieldReadNode(new LocalArgumentReadNode(arg), fieldIdx), + new GenericLiteralNode(value)}, + sourceCoord); + + if (VmSettings.UseAstInterp) { + replace(send); + send.notifyDispatchInserted(); + return send; + } + + assert getParent() instanceof BytecodeLoopNode : "This node was expected to be a direct child of a `BytecodeLoopNode`."; + throw new RespecializeException(send); + } + + @Override + public void replaceAfterScopeChange(final ScopeAdaptationVisitor inliner) { + ScopeElement se = inliner.getAdaptedVar(arg); + if (se.var != arg || se.contextLevel < 0) { + Node newNode; + if (se.contextLevel == 0) { + newNode = + new LocalFieldStringEqualsNode(fieldIdx, (Argument) se.var, value).initialize( + fieldIdx); + } else { + newNode = new NonLocalFieldStringEqualsNode(fieldIdx, (Argument) se.var, + se.contextLevel, value).initialize(fieldIdx); + } + + replace(newNode); + } else { + assert 0 == se.contextLevel; + } + } +} diff --git a/src/trufflesom/src/trufflesom/interpreter/supernodes/NonLocalFieldStringEqualsNode.java b/src/trufflesom/src/trufflesom/interpreter/supernodes/NonLocalFieldStringEqualsNode.java new file mode 100644 index 000000000..acaba87ef --- /dev/null +++ b/src/trufflesom/src/trufflesom/interpreter/supernodes/NonLocalFieldStringEqualsNode.java @@ -0,0 +1,160 @@ +package trufflesom.interpreter.supernodes; + +import com.oracle.truffle.api.CompilerDirectives; +import com.oracle.truffle.api.CompilerDirectives.CompilationFinal; +import com.oracle.truffle.api.frame.VirtualFrame; +import com.oracle.truffle.api.nodes.Node; +import com.oracle.truffle.api.nodes.UnexpectedResultException; + +import trufflesom.bdt.inlining.ScopeAdaptationVisitor; +import trufflesom.bdt.inlining.ScopeAdaptationVisitor.ScopeElement; +import trufflesom.compiler.Variable.Argument; +import trufflesom.interpreter.bc.RespecializeException; +import trufflesom.interpreter.nodes.ArgumentReadNode.NonLocalArgumentReadNode; +import trufflesom.interpreter.nodes.ContextualNode; +import trufflesom.interpreter.nodes.ExpressionNode; +import trufflesom.interpreter.nodes.FieldNode.FieldReadNode; +import trufflesom.interpreter.nodes.GenericMessageSendNode; +import trufflesom.interpreter.nodes.MessageSendNode; +import trufflesom.interpreter.nodes.bc.BytecodeLoopNode; +import trufflesom.interpreter.nodes.literals.GenericLiteralNode; +import trufflesom.interpreter.objectstorage.FieldAccessorNode; +import trufflesom.interpreter.objectstorage.FieldAccessorNode.AbstractReadFieldNode; +import trufflesom.interpreter.objectstorage.ObjectLayout; +import trufflesom.interpreter.objectstorage.StorageLocation; +import trufflesom.vm.SymbolTable; +import trufflesom.vm.VmSettings; +import trufflesom.vm.constants.Nil; +import trufflesom.vmobjects.SObject; + + +public class NonLocalFieldStringEqualsNode extends ContextualNode { + + private final int fieldIdx; + private final String value; + protected final Argument arg; + + @Child private AbstractReadFieldNode readFieldNode; + + @CompilationFinal private int state; + + public NonLocalFieldStringEqualsNode(final int fieldIdx, final Argument arg, + final int contextLevel, final String value) { + super(contextLevel); + this.fieldIdx = fieldIdx; + this.arg = arg; + this.value = value; + + this.state = 0; + } + + @Override + public Object executeGeneric(final VirtualFrame frame) { + try { + SObject rcvr = (SObject) determineContext(frame).getArguments()[0]; + return executeEvaluated(frame, rcvr); + } catch (UnexpectedResultException e) { + return e.getResult(); + } + } + + public boolean executeEvaluated(final VirtualFrame frame, final SObject rcvr) + throws UnexpectedResultException { + int currentState = state; + + if (state == 0) { + // uninitialized + CompilerDirectives.transferToInterpreterAndInvalidate(); + final ObjectLayout layout = rcvr.getObjectLayout(); + StorageLocation location = layout.getStorageLocation(fieldIdx); + + readFieldNode = + insert(location.getReadNode(fieldIdx, layout, + FieldAccessorNode.createRead(fieldIdx))); + } + + Object result = readFieldNode.read(rcvr); + + if ((state & 0b1) != 0) { + // we saw a string before + if (result instanceof String) { + return ((String) result).equals(value); + } + } + + if ((state & 0b10) != 0) { + // we saw a nil before + if (result == Nil.nilObject) { + return false; + } + } + + CompilerDirectives.transferToInterpreterAndInvalidate(); + return specialize(frame, result, currentState); + } + + @Override + public boolean executeBoolean(final VirtualFrame frame) throws UnexpectedResultException { + SObject rcvr = (SObject) determineContext(frame).getArguments()[0]; + + return executeEvaluated(frame, rcvr); + } + + private boolean specialize(final VirtualFrame frame, final Object result, + final int currentState) throws UnexpectedResultException { + if (result instanceof String) { + state = currentState | 0b1; + return value.equals(result); + } + + if (result == Nil.nilObject) { + state = currentState | 0b10; + } + + Object sendResult = + makeGenericSend(result).doPreEvaluated(frame, new Object[] {result, value}); + if (sendResult instanceof Boolean) { + return (Boolean) sendResult; + } + throw new UnexpectedResultException(sendResult); + } + + public final GenericMessageSendNode makeGenericSend( + @SuppressWarnings("unused") final Object receiver) { + GenericMessageSendNode send = + MessageSendNode.createGeneric(SymbolTable.symbolFor("="), + new ExpressionNode[] { + new FieldReadNode(new NonLocalArgumentReadNode(arg, contextLevel), fieldIdx), + new GenericLiteralNode(value)}, + sourceCoord); + + if (VmSettings.UseAstInterp) { + replace(send); + send.notifyDispatchInserted(); + return send; + } + + assert getParent() instanceof BytecodeLoopNode : "This node was expected to be a direct child of a `BytecodeLoopNode`."; + throw new RespecializeException(send); + } + + @Override + public void replaceAfterScopeChange(final ScopeAdaptationVisitor inliner) { + ScopeElement se = inliner.getAdaptedVar(arg); + if (se.var != arg || se.contextLevel < contextLevel) { + Node newNode; + if (se.contextLevel == 0) { + newNode = + new LocalFieldStringEqualsNode(fieldIdx, (Argument) se.var, value).initialize( + fieldIdx); + } else { + newNode = new NonLocalFieldStringEqualsNode(fieldIdx, (Argument) se.var, + se.contextLevel, value).initialize(fieldIdx); + } + + replace(newNode); + } else { + assert contextLevel == se.contextLevel; + } + } +} diff --git a/src/trufflesom/src/trufflesom/interpreter/supernodes/StringEqualsNode.java b/src/trufflesom/src/trufflesom/interpreter/supernodes/StringEqualsNode.java new file mode 100644 index 000000000..11f8b3fcf --- /dev/null +++ b/src/trufflesom/src/trufflesom/interpreter/supernodes/StringEqualsNode.java @@ -0,0 +1,73 @@ +package trufflesom.interpreter.supernodes; + +import com.oracle.truffle.api.CompilerDirectives; +import com.oracle.truffle.api.dsl.Fallback; +import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.frame.VirtualFrame; + +import trufflesom.interpreter.bc.RespecializeException; +import trufflesom.interpreter.nodes.ExpressionNode; +import trufflesom.interpreter.nodes.GenericMessageSendNode; +import trufflesom.interpreter.nodes.MessageSendNode; +import trufflesom.interpreter.nodes.bc.BytecodeLoopNode; +import trufflesom.interpreter.nodes.literals.GenericLiteralNode; +import trufflesom.interpreter.nodes.nary.UnaryExpressionNode; +import trufflesom.vm.SymbolTable; +import trufflesom.vm.VmSettings; +import trufflesom.vm.constants.Nil; +import trufflesom.vmobjects.SSymbol; + + +public abstract class StringEqualsNode extends UnaryExpressionNode { + private final String value; + + protected static final Object nil = Nil.nilObject; + + protected StringEqualsNode(final String value) { + this.value = value; + } + + @Override + public abstract ExpressionNode getReceiver(); + + @Specialization + public final boolean doString(final String rcvr) { + return value.equals(rcvr); + } + + @Specialization(guards = "rcvr == nil") + public static final boolean doNil(@SuppressWarnings("unused") final Object rcvr) { + return false; + } + + @Fallback + public final Object makeGenericSend(final VirtualFrame frame, + final Object receiver) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + return makeGenericSend(SymbolTable.symbolFor("=")).doPreEvaluated(frame, + new Object[] {receiver, value}); + } + + @Override + protected GenericMessageSendNode makeGenericSend(final SSymbol selector) { + CompilerDirectives.transferToInterpreterAndInvalidate(); + ExpressionNode[] children; + if (VmSettings.UseAstInterp) { + children = new ExpressionNode[] {getReceiver(), new GenericLiteralNode(value)}; + } else { + children = null; + } + + GenericMessageSendNode send = + MessageSendNode.createGeneric(selector, children, sourceCoord); + + if (VmSettings.UseAstInterp) { + replace(send); + send.notifyDispatchInserted(); + return send; + } + + assert getParent() instanceof BytecodeLoopNode : "This node was expected to be a direct child of a `BytecodeLoopNode`."; + throw new RespecializeException(send); + } +} diff --git a/tests/trufflesom/supernodes/StringEqualsTests.java b/tests/trufflesom/supernodes/StringEqualsTests.java new file mode 100644 index 000000000..d0f234b2e --- /dev/null +++ b/tests/trufflesom/supernodes/StringEqualsTests.java @@ -0,0 +1,74 @@ +package trufflesom.supernodes; + +import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.junit.Assert.assertThat; + +import org.junit.Test; + +import trufflesom.interpreter.nodes.ExpressionNode; +import trufflesom.interpreter.nodes.SequenceNode; +import trufflesom.interpreter.nodes.literals.BlockNode; +import trufflesom.interpreter.supernodes.LocalFieldStringEqualsNode; +import trufflesom.interpreter.supernodes.NonLocalFieldStringEqualsNode; +import trufflesom.interpreter.supernodes.StringEqualsNode; +import trufflesom.tests.AstTestSetup; + + +public class StringEqualsTests extends AstTestSetup { + + @SuppressWarnings("unchecked") + private T assertThatMainNodeIs(final String test, final Class expectedNode) { + addField("field"); + SequenceNode seq = (SequenceNode) parseMethod( + "test: arg = ( | var | \n" + test + " )"); + + ExpressionNode testExpr = read(seq, "expressions", 0); + assertThat(testExpr, instanceOf(expectedNode)); + return (T) testExpr; + } + + @SuppressWarnings("unchecked") + private T assertInBlock(final String test, final Class expectedNode) { + addField("field"); + SequenceNode seq = (SequenceNode) parseMethod( + "test: arg = ( | var | \n" + test + " )"); + + BlockNode block = (BlockNode) read(seq, "expressions", 0); + ExpressionNode testExpr = + read(block.getMethod().getInvokable(), "body", ExpressionNode.class); + assertThat(testExpr, instanceOf(expectedNode)); + return (T) testExpr; + } + + @Test + public void testStringEqual() { + assertThatMainNodeIs("field = 'str'", LocalFieldStringEqualsNode.class); + assertThatMainNodeIs("arg = 'str'", StringEqualsNode.class); + assertThatMainNodeIs("var = 'str'", StringEqualsNode.class); + assertThatMainNodeIs("('s' + 'dd') = 'str'", StringEqualsNode.class); + + assertThatMainNodeIs("'str' = field", LocalFieldStringEqualsNode.class); + assertThatMainNodeIs("'str' = arg", StringEqualsNode.class); + assertThatMainNodeIs("'str' = var", StringEqualsNode.class); + assertThatMainNodeIs("'str' = ('s' + 'dd')", StringEqualsNode.class); + } + + @Test + public void testStringEqualInBlock() { + assertInBlock("[ field = 'str' ] ", NonLocalFieldStringEqualsNode.class); + assertInBlock("[ arg = 'str' ]", StringEqualsNode.class); + assertInBlock("[ var = 'str' ]", StringEqualsNode.class); + + assertInBlock("[:a | a = 'str' ]", StringEqualsNode.class); + assertInBlock("[ | v | v = 'str' ]", StringEqualsNode.class); + + assertInBlock("[ ('s' + 'dd') = 'str' ]", StringEqualsNode.class); + + assertInBlock("[ 'str' = field ]", NonLocalFieldStringEqualsNode.class); + assertInBlock("[ 'str' = arg ]", StringEqualsNode.class); + assertInBlock("[ 'str' = var ]", StringEqualsNode.class); + assertInBlock("[:a | 'str' = a ]", StringEqualsNode.class); + assertInBlock("[ | v| 'str' = v ]", StringEqualsNode.class); + assertInBlock("[ 'str' = ('s' + 'dd') ]", StringEqualsNode.class); + } +} diff --git a/tests/trufflesom/tests/AstInliningTests.java b/tests/trufflesom/tests/AstInliningTests.java index 38c360015..ec1627ba2 100644 --- a/tests/trufflesom/tests/AstInliningTests.java +++ b/tests/trufflesom/tests/AstInliningTests.java @@ -33,9 +33,9 @@ import trufflesom.interpreter.nodes.specialized.IfInlinedLiteralNode; import trufflesom.interpreter.nodes.specialized.IfTrueIfFalseInlinedLiteralsNode.FalseIfElseLiteralNode; import trufflesom.interpreter.nodes.specialized.IfTrueIfFalseInlinedLiteralsNode.TrueIfElseLiteralNode; -import trufflesom.interpreter.nodes.specialized.IntIncrementNode; import trufflesom.interpreter.nodes.specialized.IntToDoInlinedLiteralsNode; import trufflesom.interpreter.nodes.specialized.whileloops.WhileInlinedLiteralsNode; +import trufflesom.interpreter.supernodes.IntIncrementNode; import trufflesom.primitives.arithmetic.SubtractionPrim; import trufflesom.primitives.arrays.DoPrim;