Skip to content

Commit e85a061

Browse files
authored
fix multi assign (#348)
I just realize I misinterpreted what it means when there are multiple targets in `ast.Assign` it actually means the following ```py x = y.z = 1 ``` this also fix: - statement with multiple return values. Allow one create such statements without using a tuple. - named expr, e.g `x := 1` - `ast.AugAssign`, e.g `x += 1`
1 parent 0943deb commit e85a061

2 files changed

Lines changed: 120 additions & 12 deletions

File tree

src/kirin/dialects/py/assign.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -133,24 +133,42 @@ def lower_Assign(self, state: lowering.State, node: ast.Assign) -> lowering.Resu
133133
stmt.result.name = lhs_name
134134
current_frame.defs[lhs_name] = current_frame.push(stmt).result
135135
case _:
136-
for target, value in zip(node.targets, result.data):
137-
self.assign_item(state, target, value)
136+
for target in node.targets:
137+
self.assign_item(state, target, result)
138138

139139
def lower_AnnAssign(
140140
self, state: lowering.State, node: ast.AnnAssign
141141
) -> lowering.Result:
142142
type_hint = self.get_hint(state, node.annotation)
143143
value = state.lower(node.value).expect_one()
144144
stmt = state.current_frame.push(TypeAssert(got=value, expected=type_hint))
145-
self.assign_item(state, node.target, stmt.result)
145+
self.assign_item_value(state, node.target, stmt.result)
146146

147147
def lower_AugAssign(
148148
self, state: lowering.State, node: ast.AugAssign
149149
) -> lowering.Result:
150-
self.assign_item(state, node.target, state.lower(node.value).expect_one())
150+
match node.target:
151+
case ast.Name(name, ast.Store()):
152+
rhs = ast.Name(name, ast.Load())
153+
case ast.Attribute(obj, attr, ast.Store()):
154+
rhs = ast.Attribute(obj, attr, ast.Load())
155+
case ast.Subscript(obj, slice, ast.Store()):
156+
rhs = ast.Subscript(obj, slice, ast.Load())
157+
self.assign_item_value(
158+
state,
159+
node.target,
160+
state.lower(ast.BinOp(rhs, node.op, node.value)).expect_one(),
161+
)
151162

152-
@staticmethod
153-
def assign_item(state: lowering.State, target, value: ir.SSAValue):
163+
def lower_NamedExpr(
164+
self, state: lowering.State, node: ast.NamedExpr
165+
) -> lowering.Result:
166+
value = state.lower(node.value).expect_one()
167+
self.assign_item_value(state, node.target, value)
168+
return value
169+
170+
@classmethod
171+
def assign_item_value(cls, state: lowering.State, target, value: ir.SSAValue):
154172
current_frame = state.current_frame
155173
match target:
156174
case ast.Name(name, ast.Store()):
@@ -168,9 +186,15 @@ def assign_item(state: lowering.State, target, value: ir.SSAValue):
168186
case _:
169187
raise lowering.BuildError(f"unsupported target {target}")
170188

171-
@staticmethod
172-
def assert_assign_value_type(value: ir.SSAValue, type_hint: types.TypeAttribute):
173-
value_type = value.type.meet(type_hint)
174-
if value_type is value_type.bottom():
175-
raise lowering.BuildError(f"Cannot assign {value.type} to {type_hint}")
176-
return value_type
189+
@classmethod
190+
def assign_item(cls, state: lowering.State, target, result: lowering.State.Result):
191+
match target:
192+
case ast.Tuple(elts, ast.Store()):
193+
if len(elts) != len(result.data):
194+
raise lowering.BuildError(
195+
f"tuple assignment length mismatch: {len(elts)} != {len(result.data)}"
196+
)
197+
for target, value in zip(elts, result.data):
198+
cls.assign_item_value(state, target, value)
199+
case _:
200+
cls.assign_item_value(state, target, result.expect_one())

test/lowering/test_assign.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import pytest
2+
3+
from kirin import ir, lowering
4+
from kirin.decl import info, statement
5+
from kirin.prelude import basic_no_opt
6+
from kirin.dialects import cf, py
7+
8+
dialect = ir.Dialect("test")
9+
10+
11+
@statement(dialect=dialect)
12+
class MultiResult(ir.Statement):
13+
traits = frozenset({lowering.FromPythonCall()})
14+
result_a: ir.ResultValue = info.result()
15+
result_b: ir.ResultValue = info.result()
16+
17+
18+
dummy_dialect = basic_no_opt.add(dialect)
19+
20+
21+
def test_multi_result():
22+
@dummy_dialect
23+
def multi_assign():
24+
(x, y) = MultiResult() # type: ignore
25+
return x, y
26+
27+
stmt = multi_assign.callable_region.blocks[0].stmts.at(0)
28+
assert isinstance(stmt, MultiResult)
29+
assert stmt.result_a.name == "x"
30+
assert stmt.result_b.name == "y"
31+
32+
with pytest.raises(lowering.BuildError):
33+
34+
@dummy_dialect
35+
def multi_assign_error():
36+
(x, y, z) = MultiResult() # type: ignore
37+
return x, y, z
38+
39+
40+
def test_chain_assign_setattr():
41+
42+
@dummy_dialect
43+
def chain_assign(y):
44+
x = y.z = 1
45+
return x, y
46+
47+
stmt = chain_assign.callable_region.blocks[0].stmts.at(1)
48+
assert isinstance(stmt, py.assign.SetAttribute)
49+
assert stmt.obj.name == "y"
50+
assert stmt.attr == "z"
51+
assert stmt.value.name == "x"
52+
53+
54+
def test_aug_assign():
55+
@dummy_dialect
56+
def aug_assign(y):
57+
y += 1
58+
return y
59+
60+
y = aug_assign.callable_region.blocks[0].args[1]
61+
const = aug_assign.callable_region.blocks[0].stmts.at(0)
62+
assert isinstance(const, py.Constant)
63+
assert const.value.unwrap() == 1
64+
add = aug_assign.callable_region.blocks[0].stmts.at(1)
65+
assert isinstance(add, py.binop.Add)
66+
assert add.lhs is y
67+
assert add.rhs is const.result
68+
69+
70+
def test_named_expr():
71+
72+
@dummy_dialect
73+
def named_expr(y):
74+
if y := y + 1:
75+
return y
76+
return y
77+
78+
stmt = named_expr.callable_region.blocks[0].stmts.at(1)
79+
y = named_expr.callable_region.blocks[0].args[1]
80+
assert isinstance(stmt, py.binop.Add)
81+
assert stmt.lhs is y
82+
br = named_expr.callable_region.blocks[0].stmts.at(2)
83+
assert isinstance(br, cf.ConditionalBranch)
84+
assert stmt.result is br.cond

0 commit comments

Comments
 (0)