From 880db7db262ddbfc628194b29610a5706090d61e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 3 Sep 2024 16:21:51 +0100 Subject: [PATCH 1/8] Single-arg version of range, test_range.py sums in loop --- guppylang/prelude/builtins.py | 41 +++++++++++++++++++++++++++++++-- tests/integration/test_range.py | 28 ++++++++++++++++++++++ 2 files changed, 67 insertions(+), 2 deletions(-) create mode 100644 tests/integration/test_range.py diff --git a/guppylang/prelude/builtins.py b/guppylang/prelude/builtins.py index d3303323..7e027926 100644 --- a/guppylang/prelude/builtins.py +++ b/guppylang/prelude/builtins.py @@ -873,8 +873,45 @@ def print(x): ... def property(x): ... -@guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) -def range(x): ... +@guppy.struct(builtins) +class Range: + stop: int + + @guppy(builtins) + def __iter__(self: "Range") -> "RangeIter": + return RangeIter(0, self.stop) + + +@guppy.struct(builtins) +class RangeIter: + next: int + stop: int + + @guppy(builtins) + def __iter__(self: "RangeIter") -> "RangeIter": + return self + + @guppy(builtins) + def __hasnext__(self: "RangeIter") -> tuple[bool, "RangeIter"]: + return (self.next < self.stop, self) + + @guppy(builtins) + def __next__(self: "RangeIter") -> tuple[int, "RangeIter"]: + # Fine not to check bounds while we can only be called from inside a `for` loop. + # if self.start >= self.stop: + # raise StopIteration + return (self.next, RangeIter(self.next + 1, self.stop)) + + @guppy(builtins) + def __end__(self: "RangeIter") -> None: + pass + + +@guppy(builtins) +def range(stop: int) -> Range: + """Limited version of python range(). + Only a single argument (stop/limit) is supported.""" + return Range(stop) @guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) diff --git a/tests/integration/test_range.py b/tests/integration/test_range.py new file mode 100644 index 00000000..39a3997f --- /dev/null +++ b/tests/integration/test_range.py @@ -0,0 +1,28 @@ +from guppylang.decorator import guppy +from guppylang.prelude.builtins import nat, range +from guppylang.module import GuppyModule +from tests.util import compile_guppy + +def test_range(validate, run_int_fn): + module = GuppyModule("test_aug_assign_loop") + + @guppy(module) + def main() -> int: + total = 0 + xs = range(5) + for x in xs: + total += 100 + x + return total + + @guppy(module) + def negative() -> int: + total = 0 + xs = range(-3) + for x in xs: + total += 100 + x + return total + + compiled = module.compile() + validate(compiled) + run_int_fn(compiled, expected=510) + run_int_fn(compiled, expected=0, fn_name="negative") \ No newline at end of file From b4d42025a53999ab15bc387717280c916ce6e222 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 4 Sep 2024 14:58:40 +0100 Subject: [PATCH 2/8] Fix ipython: getsourcelines, etc., works for guppy builtins --- guppylang/definition/struct.py | 37 +++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/guppylang/definition/struct.py b/guppylang/definition/struct.py index 2d20c37a..14473936 100644 --- a/guppylang/definition/struct.py +++ b/guppylang/definition/struct.py @@ -231,24 +231,29 @@ def parse_py_class(cls: type) -> ast.ClassDef: # - https://github.com/wandb/weave/pull/1864 if is_running_ipython(): defn = find_ipython_def(cls.__name__) - if defn is None: - raise ValueError(f"Couldn't find source for class `{cls.__name__}`") - annotate_location(defn.node, defn.cell_source, f"<{defn.cell_name}>", 1) - if not isinstance(defn.node, ast.ClassDef): - raise GuppyError("Expected a class definition", defn.node) - return defn.node + if defn is not None: + annotate_location(defn.node, defn.cell_source, f"<{defn.cell_name}>", 1) + if not isinstance(defn.node, ast.ClassDef): + raise GuppyError("Expected a class definition", defn.node) + return defn.node + # inspect.getsourcelines works for classes defined in the guppy stdlib/builtins + try: + source_lines, line_offset = inspect.getsourcelines(cls) + except OSError as e: + # Not a guppy builtin + raise ValueError(f"Couldn't find source for class `{cls.__name__}`") from e else: source_lines, line_offset = inspect.getsourcelines(cls) - source = "".join(source_lines) # Lines already have trailing \n's - source = textwrap.dedent(source) - cls_ast = ast.parse(source).body[0] - file = inspect.getsourcefile(cls) - if file is None: - raise GuppyError("Couldn't determine source file for class") - annotate_location(cls_ast, source, file, line_offset) - if not isinstance(cls_ast, ast.ClassDef): - raise GuppyError("Expected a class definition", cls_ast) - return cls_ast + source = "".join(source_lines) # Lines already have trailing \n's + source = textwrap.dedent(source) + cls_ast = ast.parse(source).body[0] + file = inspect.getsourcefile(cls) + if file is None: + raise GuppyError("Couldn't determine source file for class") + annotate_location(cls_ast, source, file, line_offset) + if not isinstance(cls_ast, ast.ClassDef): + raise GuppyError("Expected a class definition", cls_ast) + return cls_ast def try_parse_generic_base(node: ast.expr) -> list[ast.expr] | None: From b663e8c5144a21c6e09aea2a7b1c144af7e3c4ab Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 4 Sep 2024 15:02:19 +0100 Subject: [PATCH 3/8] Fix test_func_name, stdlib is there too --- tests/integration/test_basic.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/integration/test_basic.py b/tests/integration/test_basic.py index 01e90b76..e3920159 100644 --- a/tests/integration/test_basic.py +++ b/tests/integration/test_basic.py @@ -68,11 +68,12 @@ def test_func_def_name(): def func_name() -> None: return - [def_op] = [ - data.op for n, data in func_name.nodes() if isinstance(data.op, ops.FuncDefn) + # Note that while we don't have Hugr linking, and are compiling the stdlib into every Hugr, + # the entire stdlib (even unused parts) will show up here too. + func_defn_names = [ + data.op.name for n, data in func_name.nodes() if isinstance(data.op, ops.FuncDefn) ] - assert isinstance(def_op, ops.FuncDefn) - assert def_op.name == "func_name" + assert "func_name" in func_defn_names def test_func_decl_name(): From 616795e704389599224e8bfdde308a6ecac3824a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 4 Sep 2024 15:19:38 +0100 Subject: [PATCH 4/8] XFAIL the various broken tests --- tests/integration/test_array.py | 2 ++ tests/integration/test_basic.py | 11 ++++++----- tests/integration/test_extern.py | 3 +++ 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/integration/test_array.py b/tests/integration/test_array.py index 49ac3fc7..acd79411 100644 --- a/tests/integration/test_array.py +++ b/tests/integration/test_array.py @@ -1,3 +1,4 @@ +import pytest from hugr import ops from hugr.std.int import IntVal @@ -7,6 +8,7 @@ from tests.util import compile_guppy +@pytest.mark.xfail(reason="hugr-includes-whole-stdlib") def test_len(validate): module = GuppyModule("test") diff --git a/tests/integration/test_basic.py b/tests/integration/test_basic.py index e3920159..e9f8fbf3 100644 --- a/tests/integration/test_basic.py +++ b/tests/integration/test_basic.py @@ -1,3 +1,4 @@ +import pytest from hugr import ops from guppylang.decorator import guppy @@ -63,17 +64,17 @@ def foo(x: bool) -> bool: validate(foo) +@pytest.mark.xfail(reason="hugr-includes-whole-stdlib") def test_func_def_name(): @compile_guppy def func_name() -> None: return - # Note that while we don't have Hugr linking, and are compiling the stdlib into every Hugr, - # the entire stdlib (even unused parts) will show up here too. - func_defn_names = [ - data.op.name for n, data in func_name.nodes() if isinstance(data.op, ops.FuncDefn) + [def_op] = [ + data.op for n, data in func_name.nodes() if isinstance(data.op, ops.FuncDefn) ] - assert "func_name" in func_defn_names + assert isinstance(def_op, ops.FuncDefn) + assert def_op.name == "func_name" def test_func_decl_name(): diff --git a/tests/integration/test_extern.py b/tests/integration/test_extern.py index c9591d5a..3bd2985c 100644 --- a/tests/integration/test_extern.py +++ b/tests/integration/test_extern.py @@ -1,9 +1,11 @@ +import pytest from hugr import ops, val from guppylang.decorator import guppy from guppylang.module import GuppyModule +@pytest.mark.xfail(reason="hugr-includes-whole-stdlib") def test_extern_float(validate): module = GuppyModule("module") @@ -21,6 +23,7 @@ def main() -> float: assert c.val.val["symbol"] == "ext" +@pytest.mark.xfail(reason="hugr-includes-whole-stdlib") def test_extern_alt_symbol(validate): module = GuppyModule("module") From 3b20af87f46230b1ee7ec8d42d6b6f962e80ba0b Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 4 Sep 2024 15:23:40 +0100 Subject: [PATCH 5/8] typing.overload to reduce errors on @guppy(builtin) --- guppylang/decorator.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/guppylang/decorator.py b/guppylang/decorator.py index 3d5fd6ef..24b40f76 100644 --- a/guppylang/decorator.py +++ b/guppylang/decorator.py @@ -4,7 +4,7 @@ from dataclasses import dataclass, field from pathlib import Path from types import ModuleType -from typing import Any, TypeVar +from typing import Any, TypeVar, overload from hugr import Hugr, ops from hugr import tys as ht @@ -61,6 +61,12 @@ class _Guppy: def __init__(self) -> None: self._modules = {} + @overload + def __call__(self, arg: PyFunc) -> RawFunctionDef: ... + + @overload + def __call__(self, arg: GuppyModule) -> FuncDefDecorator: ... + @pretty_errors def __call__(self, arg: PyFunc | GuppyModule) -> FuncDefDecorator | RawFunctionDef: """Decorator to annotate Python functions as Guppy code. From e994d9e223c462a058b9b1636d06b0647bc42bc0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 4 Sep 2024 15:25:26 +0100 Subject: [PATCH 6/8] type:ignore[call-arg] for Range(Iter)-constructors --- guppylang/prelude/builtins.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/guppylang/prelude/builtins.py b/guppylang/prelude/builtins.py index 7e027926..3517f2ed 100644 --- a/guppylang/prelude/builtins.py +++ b/guppylang/prelude/builtins.py @@ -879,7 +879,7 @@ class Range: @guppy(builtins) def __iter__(self: "Range") -> "RangeIter": - return RangeIter(0, self.stop) + return RangeIter(0, self.stop) # type: ignore[call-arg] @guppy.struct(builtins) @@ -900,7 +900,7 @@ def __next__(self: "RangeIter") -> tuple[int, "RangeIter"]: # Fine not to check bounds while we can only be called from inside a `for` loop. # if self.start >= self.stop: # raise StopIteration - return (self.next, RangeIter(self.next + 1, self.stop)) + return (self.next, RangeIter(self.next + 1, self.stop)) # type: ignore[call-arg] @guppy(builtins) def __end__(self: "RangeIter") -> None: @@ -911,7 +911,7 @@ def __end__(self: "RangeIter") -> None: def range(stop: int) -> Range: """Limited version of python range(). Only a single argument (stop/limit) is supported.""" - return Range(stop) + return Range(stop) # type: ignore[call-arg] @guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) From 6d40278e9cbd03de3123f28c038bc344ee5a8230 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 4 Sep 2024 17:18:59 +0100 Subject: [PATCH 7/8] Test tidies --- tests/integration/test_range.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/integration/test_range.py b/tests/integration/test_range.py index 39a3997f..32d33a61 100644 --- a/tests/integration/test_range.py +++ b/tests/integration/test_range.py @@ -4,25 +4,23 @@ from tests.util import compile_guppy def test_range(validate, run_int_fn): - module = GuppyModule("test_aug_assign_loop") + module = GuppyModule("test_range") @guppy(module) def main() -> int: total = 0 - xs = range(5) - for x in xs: - total += 100 + x + for x in range(5): + total += x + 100 # Make the initial 0 obvious return total @guppy(module) def negative() -> int: total = 0 - xs = range(-3) - for x in xs: + for x in range(-3): total += 100 + x return total compiled = module.compile() validate(compiled) run_int_fn(compiled, expected=510) - run_int_fn(compiled, expected=0, fn_name="negative") \ No newline at end of file + run_int_fn(compiled, expected=0, fn_name="negative") From 7f72debfeb868b02eb7806b335c8d4fb6d125e4e Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 4 Sep 2024 17:24:32 +0100 Subject: [PATCH 8/8] RangeChecker/RangeCompiler to support 1/2/3 args --- guppylang/prelude/_internal/checker.py | 21 +++++++++ guppylang/prelude/_internal/compiler.py | 34 +++++++++++++++ guppylang/prelude/builtins.py | 22 ++++++---- tests/integration/test_range.py | 57 ++++++++++++++++++++++++- 4 files changed, 123 insertions(+), 11 deletions(-) diff --git a/guppylang/prelude/_internal/checker.py b/guppylang/prelude/_internal/checker.py index a7bc400a..e2c7b84e 100644 --- a/guppylang/prelude/_internal/checker.py +++ b/guppylang/prelude/_internal/checker.py @@ -75,6 +75,27 @@ def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: return expr, ty +class RangeChecker(DefaultCallChecker): + """Provides supports for optional args of "range" function + (where a single argument goes to the *second* parameter)""" + + @staticmethod + def _mutate_args(args: list[ast.expr]) -> list[ast.expr]: + if len(args) == 1: + # provided is "stop" index; "start" index of 0 goes first + args.insert(0, ast.Constant(value=0)) + if len(args) == 2: + # missing step - default 1 + args.append(ast.Constant(1)) + return args + + def check(self, args: list[ast.expr], ty: Type) -> tuple[ast.expr, Subst]: + return super().check(self._mutate_args(args), ty) + + def synthesize(self, args: list[ast.expr]) -> tuple[ast.expr, Type]: + return super().synthesize(self._mutate_args(args)) + + class FailingChecker(CustomCallChecker): """Call checker for Python functions that are not available in Guppy. diff --git a/guppylang/prelude/_internal/compiler.py b/guppylang/prelude/_internal/compiler.py index 27fb9be8..320b4656 100644 --- a/guppylang/prelude/_internal/compiler.py +++ b/guppylang/prelude/_internal/compiler.py @@ -438,3 +438,37 @@ def build_array_set( ) array, swapped_elem = iter(builder.add_op(op, array, idx, elem)) return array, swapped_elem + + +class RangeCompiler(CustomCallCompiler): + def compile(self, args: list[Wire]) -> list[Wire]: + from typing import cast + + from guppylang.compiler.expr_compiler import python_value_to_hugr + from guppylang.definition.custom import CustomFunctionDef + + # Find the builtins module. Brute-force search! + # (But we hope this to be quick as we expect most entries to be in builtins) + builtins = next( + def_id.module + for def_id in self.globals + if def_id.module is not None and def_id.module.name == "builtins" + ) + id_Range = builtins._globals["Range"].id + id_new = builtins._globals.impls[id_Range]["__new__"] + ty = NumericType(NumericType.Kind.Int) + if len(args) == 1: + const0 = python_value_to_hugr(0, ty) + assert const0 is not None + args.insert(0, self.builder.load(const0)) + if len(args) == 2: + const1 = python_value_to_hugr(1, ty) + assert const1 is not None + args.append(self.builder.load(const1)) + # python would raise ValueError if step == 0, we have no way to indicate here + # so will produce a Range that just repeats its first argument + call_rets = cast(CustomFunctionDef, self.globals[id_new]).compile_call( + args, [], self.dfg, self.globals, self.node + ) + assert len(call_rets.inout_returns) == 0 + return call_rets.regular_returns diff --git a/guppylang/prelude/builtins.py b/guppylang/prelude/builtins.py index 21db3ad7..07e3cdb2 100644 --- a/guppylang/prelude/builtins.py +++ b/guppylang/prelude/builtins.py @@ -15,6 +15,7 @@ DunderChecker, FailingChecker, NewArrayChecker, + RangeChecker, ResultChecker, ReversingChecker, UnsupportedChecker, @@ -29,6 +30,7 @@ IntTruedivCompiler, NatTruedivCompiler, NewArrayCompiler, + RangeCompiler, ) from guppylang.prelude._internal.util import ( custom_op, @@ -884,17 +886,20 @@ def property(x): ... @guppy.struct(builtins) class Range: + start: int stop: int + step: int @guppy(builtins) def __iter__(self: "Range") -> "RangeIter": - return RangeIter(0, self.stop) # type: ignore[call-arg] + return RangeIter(self.start, self.stop, self.step) # type:ignore[call-arg] @guppy.struct(builtins) class RangeIter: - next: int + start: int stop: int + step: int @guppy(builtins) def __iter__(self: "RangeIter") -> "RangeIter": @@ -902,25 +907,24 @@ def __iter__(self: "RangeIter") -> "RangeIter": @guppy(builtins) def __hasnext__(self: "RangeIter") -> tuple[bool, "RangeIter"]: - return (self.next < self.stop, self) + res = (self.start < self.stop) if self.step >= 0 else (self.start > self.stop) + return (res, self) @guppy(builtins) def __next__(self: "RangeIter") -> tuple[int, "RangeIter"]: # Fine not to check bounds while we can only be called from inside a `for` loop. # if self.start >= self.stop: # raise StopIteration - return (self.next, RangeIter(self.next + 1, self.stop)) # type: ignore[call-arg] + new_self = RangeIter(self.start + self.step, self.stop, self.step) # type: ignore[call-arg] + return (self.start, new_self) @guppy(builtins) def __end__(self: "RangeIter") -> None: pass -@guppy(builtins) -def range(stop: int) -> Range: - """Limited version of python range(). - Only a single argument (stop/limit) is supported.""" - return Range(stop) # type: ignore[call-arg] +@guppy.custom(builtins, RangeCompiler(), RangeChecker()) +def range(start: int, stop: int, step: int) -> Range: ... @guppy.custom(builtins, checker=UnsupportedChecker(), higher_order_value=False) diff --git a/tests/integration/test_range.py b/tests/integration/test_range.py index 32d33a61..4ef4919f 100644 --- a/tests/integration/test_range.py +++ b/tests/integration/test_range.py @@ -3,8 +3,9 @@ from guppylang.module import GuppyModule from tests.util import compile_guppy -def test_range(validate, run_int_fn): - module = GuppyModule("test_range") + +def test_stop_only(validate, run_int_fn): + module = GuppyModule("test_range1") @guppy(module) def main() -> int: @@ -24,3 +25,55 @@ def negative() -> int: validate(compiled) run_int_fn(compiled, expected=510) run_int_fn(compiled, expected=0, fn_name="negative") + +def test_start_stop(validate, run_int_fn): + module = GuppyModule("test_range2") + + @guppy(module) + def simple() -> int: + total = 0 + for x in range(3, 5): + total += x + return total + + @guppy(module) + def empty() -> int: + total = 0 + for x in range(5, 3): + total += x + 100 + return total + + compiled=module.compile() + validate(compiled) + run_int_fn(compiled, expected=7, fn_name="simple") + run_int_fn(compiled, expected=0, fn_name="empty") + +def test_with_step(validate, run_int_fn): + module =GuppyModule("test_range3") + + @guppy(module) + def evens() -> int: + total = 0 + for x in range(2, 7, 2): + total += x + return total + + @guppy(module) + def negative_step() -> int: + total = 0 + for x in range(5, 3, -1): + total += x + return total + + @guppy(module) + def empty() -> int: + total = 0 + for x in range(3, 5, -1): + total += 100 + x + return total + + compiled = module.compile() + validate(compiled) + run_int_fn(compiled, expected=12, fn_name="evens") + run_int_fn(compiled, expected=9, fn_name="negative_step") + run_int_fn(compiled, expected=0, fn_name="empty")