diff --git a/scrapscript.py b/scrapscript.py index 8cceeb14..f44491ae 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -253,6 +253,8 @@ def read_var(self, first_char: str) -> Token: while self.has_input() and is_identifier_char(c := self.peek_char()): self.read_char() buf += c + if buf == "guard": + return self.make_token(Operator, "guard") return self.make_token(Name, buf) def read_bytes(self) -> Token: @@ -303,6 +305,7 @@ def xp(n: float) -> Prec: "::": lp(2000), "@": rp(1001), "": rp(1000), + "guard": rp(5.5), ">>": lp(14), "<<": lp(14), "^": rp(13), @@ -342,7 +345,7 @@ def xp(n: float) -> Prec: HIGHEST_PREC: float = max(max(p.pl, p.pr) for p in PS.values()) -OPER_CHARS = set("".join(PS.keys())) +OPER_CHARS = set(c for c in "".join(PS.keys()) if not c.isalpha()) assert " " not in OPER_CHARS @@ -364,6 +367,17 @@ def parse_assign(tokens: typing.List[Token], p: float = 0) -> "Assign": return assign +def build_match_case(expr: "Object") -> "MatchCase": + if not isinstance(expr, Function): + raise ParseError(f"expected function in match expression {expr!r}") + pattern, body = expr.arg, expr.body + guard = None + if isinstance(pattern, Binop) and pattern.op == BinopKind.GUARD: + guard = pattern.right + pattern = pattern.left + return MatchCase(pattern, guard, body) + + def parse(tokens: typing.List[Token], p: float = 0) -> "Object": if not tokens: raise UnexpectedEOFError("unexpected end of input") @@ -401,15 +415,11 @@ def parse(tokens: typing.List[Token], p: float = 0) -> "Object": l = Spread() elif token == Operator("|"): expr = parse(tokens, PS["|"].pr) # TODO: make this work for larger arities - if not isinstance(expr, Function): - raise ParseError(f"expected function in match expression {expr!r}") - cases = [MatchCase(expr.arg, expr.body)] + cases = [build_match_case(expr)] while tokens and tokens[0] == Operator("|"): tokens.pop(0) expr = parse(tokens, PS["|"].pr) # TODO: make this work for larger arities - if not isinstance(expr, Function): - raise ParseError(f"expected function in match expression {expr!r}") - cases.append(MatchCase(expr.arg, expr.body)) + cases.append(build_match_case(expr)) l = MatchFunction(cases) elif isinstance(token, LeftParen): if isinstance(tokens[0], RightParen): @@ -679,6 +689,7 @@ class BinopKind(enum.Enum): HASTYPE = auto() PIPE = auto() REVERSE_PIPE = auto() + GUARD = auto() @classmethod def from_str(cls, x: str) -> "BinopKind": @@ -705,6 +716,7 @@ def from_str(cls, x: str) -> "BinopKind": ":": cls.HASTYPE, "|>": cls.PIPE, "<|": cls.REVERSE_PIPE, + "guard": cls.GUARD, }[x] @classmethod @@ -731,6 +743,7 @@ def to_str(cls, binop_kind: "BinopKind") -> str: cls.HASTYPE: ":", cls.PIPE: "|>", cls.REVERSE_PIPE: "<|", + cls.GUARD: "guard", }[binop_kind] @@ -868,6 +881,7 @@ def __str__(self) -> str: @dataclass(eq=True, frozen=True, unsafe_hash=True) class MatchCase(Object): pattern: Object + guard: Optional[Object] body: Object def __str__(self) -> str: @@ -1259,6 +1273,8 @@ def eval_exp(env: Env, exp: Object) -> Object: m = match(arg, case.pattern) if m is None: continue + if case.guard is not None and eval_exp({**env, **m}, case.guard) != Symbol("true"): + continue return eval_exp({**callee.env, **m}, case.body) raise MatchError("no matching cases") else: @@ -2189,7 +2205,7 @@ def test_parse_match_no_cases_raises_parse_error(self) -> None: def test_parse_match_one_case(self) -> None: self.assertEqual( parse([Operator("|"), IntLit(1), Operator("->"), IntLit(2)]), - MatchFunction([MatchCase(Int(1), Int(2))]), + MatchFunction([MatchCase(Int(1), None, Int(2))]), ) def test_parse_match_two_cases(self) -> None: @@ -2208,8 +2224,8 @@ def test_parse_match_two_cases(self) -> None: ), MatchFunction( [ - MatchCase(Int(1), Int(2)), - MatchCase(Int(2), Int(3)), + MatchCase(Int(1), None, Int(2)), + MatchCase(Int(2), None, Int(3)), ] ), ) @@ -2328,6 +2344,29 @@ def test_parse_record_with_trailing_comma_raises_parse_error(self) -> None: def test_parse_symbol_returns_symbol(self) -> None: self.assertEqual(parse([SymbolToken("abc")]), Symbol("abc")) + def test_parse_guard(self) -> None: + self.assertEqual( + parse(tokenize("| x guard y -> x")), + MatchFunction([MatchCase(Var("x"), Var("y"), Var("x"))]), + ) + + def test_parse_guard_exp(self) -> None: + self.assertEqual( + parse(tokenize("| x guard x==1 -> x")), + MatchFunction([MatchCase(Var("x"), Binop(BinopKind.EQUAL, Var("x"), Int(1)), Var("x"))]), + ) + + def test_parse_multiple_guards(self) -> None: + self.assertEqual( + parse(tokenize("| x guard y -> x | a guard b -> 1")), + MatchFunction( + [ + MatchCase(Var("x"), Var("y"), Var("x")), + MatchCase(Var("a"), Var("b"), Int(1)), + ] + ), + ) + class MatchTests(unittest.TestCase): def test_match_with_equal_ints_returns_empty_dict(self) -> None: @@ -2504,7 +2543,8 @@ def test_parse_match_with_left_apply(self) -> None: ) ast = parse(tokens) self.assertEqual( - ast, MatchFunction([MatchCase(Var("a"), Apply(Var("b"), Var("c"))), MatchCase(Var("d"), Var("e"))]) + ast, + MatchFunction([MatchCase(Var("a"), None, Apply(Var("b"), Var("c"))), MatchCase(Var("d"), None, Var("e"))]), ) def test_parse_match_with_right_apply(self) -> None: @@ -2518,9 +2558,10 @@ def test_parse_match_with_right_apply(self) -> None: ast, MatchFunction( [ - MatchCase(Int(1), Int(19)), + MatchCase(Int(1), None, Int(19)), MatchCase( Var("a"), + None, Apply( Function(Var("x"), Binop(BinopKind.ADD, Var("x"), Int(1))), Var("a"), @@ -2875,26 +2916,29 @@ def test_match_no_cases_raises_match_error(self) -> None: eval_exp({}, exp) def test_match_int_with_equal_int_matches(self) -> None: - exp = Apply(MatchFunction([MatchCase(pattern=Int(1), body=Int(2))]), Int(1)) + exp = Apply(MatchFunction([MatchCase(pattern=Int(1), guard=None, body=Int(2))]), Int(1)) self.assertEqual(eval_exp({}, exp), Int(2)) def test_match_int_with_inequal_int_raises_match_error(self) -> None: - exp = Apply(MatchFunction([MatchCase(pattern=Int(1), body=Int(2))]), Int(3)) + exp = Apply(MatchFunction([MatchCase(pattern=Int(1), guard=None, body=Int(2))]), Int(3)) with self.assertRaisesRegex(MatchError, "no matching cases"): eval_exp({}, exp) def test_match_string_with_equal_string_matches(self) -> None: - exp = Apply(MatchFunction([MatchCase(pattern=String("a"), body=String("b"))]), String("a")) + exp = Apply(MatchFunction([MatchCase(pattern=String("a"), guard=None, body=String("b"))]), String("a")) self.assertEqual(eval_exp({}, exp), String("b")) def test_match_string_with_inequal_string_raises_match_error(self) -> None: - exp = Apply(MatchFunction([MatchCase(pattern=String("a"), body=String("b"))]), String("c")) + exp = Apply(MatchFunction([MatchCase(pattern=String("a"), guard=None, body=String("b"))]), String("c")) with self.assertRaisesRegex(MatchError, "no matching cases"): eval_exp({}, exp) def test_match_falls_through_to_next(self) -> None: exp = Apply( - MatchFunction([MatchCase(pattern=Int(3), body=Int(4)), MatchCase(pattern=Int(1), body=Int(2))]), Int(1) + MatchFunction( + [MatchCase(pattern=Int(3), guard=None, body=Int(4)), MatchCase(pattern=Int(1), guard=None, body=Int(2))] + ), + Int(1), ) self.assertEqual(eval_exp({}, exp), Int(2)) @@ -2943,7 +2987,7 @@ def test_eval_apply_quote_returns_ast(self) -> None: self.assertIs(eval_exp({}, exp), ast) def test_eval_apply_closure_with_match_function_has_access_to_closure_vars(self) -> None: - ast = Apply(Closure({"x": Int(1)}, MatchFunction([MatchCase(Var("y"), Var("x"))])), Int(2)) + ast = Apply(Closure({"x": Int(1)}, MatchFunction([MatchCase(Var("y"), None, Var("x"))])), Int(2)) self.assertEqual(eval_exp({}, ast), Int(1)) def test_eval_less_returns_bool(self) -> None: @@ -3207,6 +3251,98 @@ def test_match_var_binds_var(self) -> None: Int(3), ) + def test_match_guard_closure_var(self) -> None: + self.assertEqual( + self._run( + """ + id 1 + . id = + | x guard cond -> "one" + | x -> "idk" + . cond = 2 + """ + ), + String("idk"), + ) + + def test_match_record_guard_pass(self) -> None: + self.assertEqual( + self._run( + """ + id {cond=#true} + . id = + | {cond=cond} guard cond -> "yes" + | x -> "no" + """ + ), + String("yes"), + ) + + def test_match_record_guard_fail(self) -> None: + self.assertEqual( + self._run( + """ + id {cond=#false} + . id = + | {cond=cond} guard cond -> "yes" + | x -> "no" + """ + ), + String("no"), + ) + + def test_match_list_guard_pass(self) -> None: + self.assertEqual( + self._run( + """ + id [#true] + . id = + | [cond] guard cond -> "yes" + | x -> "no" + """ + ), + String("yes"), + ) + + def test_match_list_guard_fail(self) -> None: + self.assertEqual( + self._run( + """ + id [#false] + . id = + | [cond] guard cond -> "yes" + | x -> "no" + """ + ), + String("no"), + ) + + def test_match_guard_pass(self) -> None: + self.assertEqual( + self._run( + """ + id 1 + . id = + | x guard x==1 -> "one" + | x -> "idk" + """ + ), + String("one"), + ) + + def test_match_guard_fail(self) -> None: + self.assertEqual( + self._run( + """ + id 2 + . id = + | x guard x==1 -> "one" + | x -> "idk" + """ + ), + String("idk"), + ) + def test_match_var_binds_first_arm(self) -> None: self.assertEqual( self._run( @@ -3554,38 +3690,39 @@ def test_match_function(self) -> None: self.assertEqual(free_in(exp), {"x", "y"}) def test_match_case_int(self) -> None: - exp = MatchCase(Int(1), Var("x")) + exp = MatchCase(Int(1), None, Var("x")) self.assertEqual(free_in(exp), {"x"}) def test_match_case_var(self) -> None: - exp = MatchCase(Var("x"), Binop(BinopKind.ADD, Var("x"), Var("y"))) + exp = MatchCase(Var("x"), None, Binop(BinopKind.ADD, Var("x"), Var("y"))) self.assertEqual(free_in(exp), {"y"}) def test_match_case_list(self) -> None: - exp = MatchCase(List([Var("x")]), Binop(BinopKind.ADD, Var("x"), Var("y"))) + exp = MatchCase(List([Var("x")]), None, Binop(BinopKind.ADD, Var("x"), Var("y"))) self.assertEqual(free_in(exp), {"y"}) def test_match_case_list_spread(self) -> None: - exp = MatchCase(List([Spread()]), Binop(BinopKind.ADD, Var("xs"), Var("y"))) + exp = MatchCase(List([Spread()]), None, Binop(BinopKind.ADD, Var("xs"), Var("y"))) self.assertEqual(free_in(exp), {"xs", "y"}) def test_match_case_list_spread_name(self) -> None: - exp = MatchCase(List([Spread("xs")]), Binop(BinopKind.ADD, Var("xs"), Var("y"))) + exp = MatchCase(List([Spread("xs")]), None, Binop(BinopKind.ADD, Var("xs"), Var("y"))) self.assertEqual(free_in(exp), {"y"}) def test_match_case_record(self) -> None: exp = MatchCase( Record({"x": Int(1), "y": Var("y"), "a": Var("z")}), + None, Binop(BinopKind.ADD, Binop(BinopKind.ADD, Var("x"), Var("y")), Var("z")), ) self.assertEqual(free_in(exp), {"x"}) def test_match_case_record_spread(self) -> None: - exp = MatchCase(Record({"...": Spread()}), Binop(BinopKind.ADD, Var("x"), Var("y"))) + exp = MatchCase(Record({"...": Spread()}), None, Binop(BinopKind.ADD, Var("x"), Var("y"))) self.assertEqual(free_in(exp), {"x", "y"}) def test_match_case_record_spread_name(self) -> None: - exp = MatchCase(Record({"...": Spread("x")}), Binop(BinopKind.ADD, Var("x"), Var("y"))) + exp = MatchCase(Record({"...": Spread("x")}), None, Binop(BinopKind.ADD, Var("x"), Var("y"))) self.assertEqual(free_in(exp), {"y"}) def test_apply(self) -> None: @@ -4310,12 +4447,14 @@ def test_pretty_print_envobject(self) -> None: self.assertEqual(str(obj), "EnvObject(keys=dict_keys(['x']))") def test_pretty_print_matchcase(self) -> None: - obj = MatchCase(pattern=Int(1), body=Int(2)) - self.assertEqual(str(obj), "MatchCase(pattern=Int(value=1), body=Int(value=2))") + obj = MatchCase(pattern=Int(1), guard=None, body=Int(2)) + self.assertEqual(str(obj), "MatchCase(pattern=Int(value=1), guard=None, body=Int(value=2))") def test_pretty_print_matchfunction(self) -> None: - obj = MatchFunction([MatchCase(Var("y"), Var("x"))]) - self.assertEqual(str(obj), "MatchFunction(cases=[MatchCase(pattern=Var(name='y'), body=Var(name='x'))])") + obj = MatchFunction([MatchCase(Var("y"), None, Var("x"))]) + self.assertEqual( + str(obj), "MatchFunction(cases=[MatchCase(pattern=Var(name='y'), guard=None, body=Var(name='x'))])" + ) def test_pretty_print_relocation(self) -> None: obj = Relocation("relocate")