Skip to content

Add guard expressions to pattern matching #98

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: trunk
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 168 additions & 29 deletions scrapscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,8 @@ def read_var(self, first_char: str) -> Token:
while self.has_input() and is_identifier_char(c := self.peek_char()):
self.read_char()
buf += c
if buf == "guard":
return self.make_token(Operator, "guard")
return self.make_token(Name, buf)

def read_bytes(self) -> Token:
Expand Down Expand Up @@ -303,6 +305,7 @@ def xp(n: float) -> Prec:
"::": lp(2000),
"@": rp(1001),
"": rp(1000),
"guard": rp(5.5),
">>": lp(14),
"<<": lp(14),
"^": rp(13),
Expand Down Expand Up @@ -342,7 +345,7 @@ def xp(n: float) -> Prec:
HIGHEST_PREC: float = max(max(p.pl, p.pr) for p in PS.values())


OPER_CHARS = set("".join(PS.keys()))
OPER_CHARS = set(c for c in "".join(PS.keys()) if not c.isalpha())
assert " " not in OPER_CHARS


Expand All @@ -364,6 +367,17 @@ def parse_assign(tokens: typing.List[Token], p: float = 0) -> "Assign":
return assign


def build_match_case(expr: "Object") -> "MatchCase":
if not isinstance(expr, Function):
raise ParseError(f"expected function in match expression {expr!r}")
pattern, body = expr.arg, expr.body
guard = None
if isinstance(pattern, Binop) and pattern.op == BinopKind.GUARD:
guard = pattern.right
pattern = pattern.left
return MatchCase(pattern, guard, body)


def parse(tokens: typing.List[Token], p: float = 0) -> "Object":
if not tokens:
raise UnexpectedEOFError("unexpected end of input")
Expand Down Expand Up @@ -401,15 +415,11 @@ def parse(tokens: typing.List[Token], p: float = 0) -> "Object":
l = Spread()
elif token == Operator("|"):
expr = parse(tokens, PS["|"].pr) # TODO: make this work for larger arities
if not isinstance(expr, Function):
raise ParseError(f"expected function in match expression {expr!r}")
cases = [MatchCase(expr.arg, expr.body)]
cases = [build_match_case(expr)]
while tokens and tokens[0] == Operator("|"):
tokens.pop(0)
expr = parse(tokens, PS["|"].pr) # TODO: make this work for larger arities
if not isinstance(expr, Function):
raise ParseError(f"expected function in match expression {expr!r}")
cases.append(MatchCase(expr.arg, expr.body))
cases.append(build_match_case(expr))
l = MatchFunction(cases)
elif isinstance(token, LeftParen):
if isinstance(tokens[0], RightParen):
Expand Down Expand Up @@ -679,6 +689,7 @@ class BinopKind(enum.Enum):
HASTYPE = auto()
PIPE = auto()
REVERSE_PIPE = auto()
GUARD = auto()

@classmethod
def from_str(cls, x: str) -> "BinopKind":
Expand All @@ -705,6 +716,7 @@ def from_str(cls, x: str) -> "BinopKind":
":": cls.HASTYPE,
"|>": cls.PIPE,
"<|": cls.REVERSE_PIPE,
"guard": cls.GUARD,
}[x]

@classmethod
Expand All @@ -731,6 +743,7 @@ def to_str(cls, binop_kind: "BinopKind") -> str:
cls.HASTYPE: ":",
cls.PIPE: "|>",
cls.REVERSE_PIPE: "<|",
cls.GUARD: "guard",
}[binop_kind]


Expand Down Expand Up @@ -868,6 +881,7 @@ def __str__(self) -> str:
@dataclass(eq=True, frozen=True, unsafe_hash=True)
class MatchCase(Object):
pattern: Object
guard: Optional[Object]
body: Object

def __str__(self) -> str:
Expand Down Expand Up @@ -1259,6 +1273,8 @@ def eval_exp(env: Env, exp: Object) -> Object:
m = match(arg, case.pattern)
if m is None:
continue
if case.guard is not None and eval_exp({**env, **m}, case.guard) != Symbol("true"):
continue
return eval_exp({**callee.env, **m}, case.body)
raise MatchError("no matching cases")
else:
Expand Down Expand Up @@ -2189,7 +2205,7 @@ def test_parse_match_no_cases_raises_parse_error(self) -> None:
def test_parse_match_one_case(self) -> None:
self.assertEqual(
parse([Operator("|"), IntLit(1), Operator("->"), IntLit(2)]),
MatchFunction([MatchCase(Int(1), Int(2))]),
MatchFunction([MatchCase(Int(1), None, Int(2))]),
)

def test_parse_match_two_cases(self) -> None:
Expand All @@ -2208,8 +2224,8 @@ def test_parse_match_two_cases(self) -> None:
),
MatchFunction(
[
MatchCase(Int(1), Int(2)),
MatchCase(Int(2), Int(3)),
MatchCase(Int(1), None, Int(2)),
MatchCase(Int(2), None, Int(3)),
]
),
)
Expand Down Expand Up @@ -2328,6 +2344,29 @@ def test_parse_record_with_trailing_comma_raises_parse_error(self) -> None:
def test_parse_symbol_returns_symbol(self) -> None:
self.assertEqual(parse([SymbolToken("abc")]), Symbol("abc"))

def test_parse_guard(self) -> None:
self.assertEqual(
parse(tokenize("| x guard y -> x")),
MatchFunction([MatchCase(Var("x"), Var("y"), Var("x"))]),
)

def test_parse_guard_exp(self) -> None:
self.assertEqual(
parse(tokenize("| x guard x==1 -> x")),
MatchFunction([MatchCase(Var("x"), Binop(BinopKind.EQUAL, Var("x"), Int(1)), Var("x"))]),
)

def test_parse_multiple_guards(self) -> None:
self.assertEqual(
parse(tokenize("| x guard y -> x | a guard b -> 1")),
MatchFunction(
[
MatchCase(Var("x"), Var("y"), Var("x")),
MatchCase(Var("a"), Var("b"), Int(1)),
]
),
)


class MatchTests(unittest.TestCase):
def test_match_with_equal_ints_returns_empty_dict(self) -> None:
Expand Down Expand Up @@ -2504,7 +2543,8 @@ def test_parse_match_with_left_apply(self) -> None:
)
ast = parse(tokens)
self.assertEqual(
ast, MatchFunction([MatchCase(Var("a"), Apply(Var("b"), Var("c"))), MatchCase(Var("d"), Var("e"))])
ast,
MatchFunction([MatchCase(Var("a"), None, Apply(Var("b"), Var("c"))), MatchCase(Var("d"), None, Var("e"))]),
)

def test_parse_match_with_right_apply(self) -> None:
Expand All @@ -2518,9 +2558,10 @@ def test_parse_match_with_right_apply(self) -> None:
ast,
MatchFunction(
[
MatchCase(Int(1), Int(19)),
MatchCase(Int(1), None, Int(19)),
MatchCase(
Var("a"),
None,
Apply(
Function(Var("x"), Binop(BinopKind.ADD, Var("x"), Int(1))),
Var("a"),
Expand Down Expand Up @@ -2875,26 +2916,29 @@ def test_match_no_cases_raises_match_error(self) -> None:
eval_exp({}, exp)

def test_match_int_with_equal_int_matches(self) -> None:
exp = Apply(MatchFunction([MatchCase(pattern=Int(1), body=Int(2))]), Int(1))
exp = Apply(MatchFunction([MatchCase(pattern=Int(1), guard=None, body=Int(2))]), Int(1))
self.assertEqual(eval_exp({}, exp), Int(2))

def test_match_int_with_inequal_int_raises_match_error(self) -> None:
exp = Apply(MatchFunction([MatchCase(pattern=Int(1), body=Int(2))]), Int(3))
exp = Apply(MatchFunction([MatchCase(pattern=Int(1), guard=None, body=Int(2))]), Int(3))
with self.assertRaisesRegex(MatchError, "no matching cases"):
eval_exp({}, exp)

def test_match_string_with_equal_string_matches(self) -> None:
exp = Apply(MatchFunction([MatchCase(pattern=String("a"), body=String("b"))]), String("a"))
exp = Apply(MatchFunction([MatchCase(pattern=String("a"), guard=None, body=String("b"))]), String("a"))
self.assertEqual(eval_exp({}, exp), String("b"))

def test_match_string_with_inequal_string_raises_match_error(self) -> None:
exp = Apply(MatchFunction([MatchCase(pattern=String("a"), body=String("b"))]), String("c"))
exp = Apply(MatchFunction([MatchCase(pattern=String("a"), guard=None, body=String("b"))]), String("c"))
with self.assertRaisesRegex(MatchError, "no matching cases"):
eval_exp({}, exp)

def test_match_falls_through_to_next(self) -> None:
exp = Apply(
MatchFunction([MatchCase(pattern=Int(3), body=Int(4)), MatchCase(pattern=Int(1), body=Int(2))]), Int(1)
MatchFunction(
[MatchCase(pattern=Int(3), guard=None, body=Int(4)), MatchCase(pattern=Int(1), guard=None, body=Int(2))]
),
Int(1),
)
self.assertEqual(eval_exp({}, exp), Int(2))

Expand Down Expand Up @@ -2943,7 +2987,7 @@ def test_eval_apply_quote_returns_ast(self) -> None:
self.assertIs(eval_exp({}, exp), ast)

def test_eval_apply_closure_with_match_function_has_access_to_closure_vars(self) -> None:
ast = Apply(Closure({"x": Int(1)}, MatchFunction([MatchCase(Var("y"), Var("x"))])), Int(2))
ast = Apply(Closure({"x": Int(1)}, MatchFunction([MatchCase(Var("y"), None, Var("x"))])), Int(2))
self.assertEqual(eval_exp({}, ast), Int(1))

def test_eval_less_returns_bool(self) -> None:
Expand Down Expand Up @@ -3207,6 +3251,98 @@ def test_match_var_binds_var(self) -> None:
Int(3),
)

def test_match_guard_closure_var(self) -> None:
self.assertEqual(
self._run(
"""
id 1
. id =
| x guard cond -> "one"
| x -> "idk"
. cond = 2
"""
),
String("idk"),
)

def test_match_record_guard_pass(self) -> None:
self.assertEqual(
self._run(
"""
id {cond=#true}
. id =
| {cond=cond} guard cond -> "yes"
| x -> "no"
"""
),
String("yes"),
)

def test_match_record_guard_fail(self) -> None:
self.assertEqual(
self._run(
"""
id {cond=#false}
. id =
| {cond=cond} guard cond -> "yes"
| x -> "no"
"""
),
String("no"),
)

def test_match_list_guard_pass(self) -> None:
self.assertEqual(
self._run(
"""
id [#true]
. id =
| [cond] guard cond -> "yes"
| x -> "no"
"""
),
String("yes"),
)

def test_match_list_guard_fail(self) -> None:
self.assertEqual(
self._run(
"""
id [#false]
. id =
| [cond] guard cond -> "yes"
| x -> "no"
"""
),
String("no"),
)

def test_match_guard_pass(self) -> None:
self.assertEqual(
self._run(
"""
id 1
. id =
| x guard x==1 -> "one"
| x -> "idk"
"""
),
String("one"),
)

def test_match_guard_fail(self) -> None:
self.assertEqual(
self._run(
"""
id 2
. id =
| x guard x==1 -> "one"
| x -> "idk"
"""
),
String("idk"),
)

def test_match_var_binds_first_arm(self) -> None:
self.assertEqual(
self._run(
Expand Down Expand Up @@ -3554,38 +3690,39 @@ def test_match_function(self) -> None:
self.assertEqual(free_in(exp), {"x", "y"})

def test_match_case_int(self) -> None:
exp = MatchCase(Int(1), Var("x"))
exp = MatchCase(Int(1), None, Var("x"))
self.assertEqual(free_in(exp), {"x"})

def test_match_case_var(self) -> None:
exp = MatchCase(Var("x"), Binop(BinopKind.ADD, Var("x"), Var("y")))
exp = MatchCase(Var("x"), None, Binop(BinopKind.ADD, Var("x"), Var("y")))
self.assertEqual(free_in(exp), {"y"})

def test_match_case_list(self) -> None:
exp = MatchCase(List([Var("x")]), Binop(BinopKind.ADD, Var("x"), Var("y")))
exp = MatchCase(List([Var("x")]), None, Binop(BinopKind.ADD, Var("x"), Var("y")))
self.assertEqual(free_in(exp), {"y"})

def test_match_case_list_spread(self) -> None:
exp = MatchCase(List([Spread()]), Binop(BinopKind.ADD, Var("xs"), Var("y")))
exp = MatchCase(List([Spread()]), None, Binop(BinopKind.ADD, Var("xs"), Var("y")))
self.assertEqual(free_in(exp), {"xs", "y"})

def test_match_case_list_spread_name(self) -> None:
exp = MatchCase(List([Spread("xs")]), Binop(BinopKind.ADD, Var("xs"), Var("y")))
exp = MatchCase(List([Spread("xs")]), None, Binop(BinopKind.ADD, Var("xs"), Var("y")))
self.assertEqual(free_in(exp), {"y"})

def test_match_case_record(self) -> None:
exp = MatchCase(
Record({"x": Int(1), "y": Var("y"), "a": Var("z")}),
None,
Binop(BinopKind.ADD, Binop(BinopKind.ADD, Var("x"), Var("y")), Var("z")),
)
self.assertEqual(free_in(exp), {"x"})

def test_match_case_record_spread(self) -> None:
exp = MatchCase(Record({"...": Spread()}), Binop(BinopKind.ADD, Var("x"), Var("y")))
exp = MatchCase(Record({"...": Spread()}), None, Binop(BinopKind.ADD, Var("x"), Var("y")))
self.assertEqual(free_in(exp), {"x", "y"})

def test_match_case_record_spread_name(self) -> None:
exp = MatchCase(Record({"...": Spread("x")}), Binop(BinopKind.ADD, Var("x"), Var("y")))
exp = MatchCase(Record({"...": Spread("x")}), None, Binop(BinopKind.ADD, Var("x"), Var("y")))
self.assertEqual(free_in(exp), {"y"})

def test_apply(self) -> None:
Expand Down Expand Up @@ -4310,12 +4447,14 @@ def test_pretty_print_envobject(self) -> None:
self.assertEqual(str(obj), "EnvObject(keys=dict_keys(['x']))")

def test_pretty_print_matchcase(self) -> None:
obj = MatchCase(pattern=Int(1), body=Int(2))
self.assertEqual(str(obj), "MatchCase(pattern=Int(value=1), body=Int(value=2))")
obj = MatchCase(pattern=Int(1), guard=None, body=Int(2))
self.assertEqual(str(obj), "MatchCase(pattern=Int(value=1), guard=None, body=Int(value=2))")

def test_pretty_print_matchfunction(self) -> None:
obj = MatchFunction([MatchCase(Var("y"), Var("x"))])
self.assertEqual(str(obj), "MatchFunction(cases=[MatchCase(pattern=Var(name='y'), body=Var(name='x'))])")
obj = MatchFunction([MatchCase(Var("y"), None, Var("x"))])
self.assertEqual(
str(obj), "MatchFunction(cases=[MatchCase(pattern=Var(name='y'), guard=None, body=Var(name='x'))])"
)

def test_pretty_print_relocation(self) -> None:
obj = Relocation("relocate")
Expand Down