Skip to content

Commit 4ec6994

Browse files
authored
fix type for math dialect (#626)
1 parent de63b92 commit 4ec6994

3 files changed

Lines changed: 8 additions & 6 deletions

File tree

src/kirin/dialects/math/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def atanh(x: float) -> float: ...
3737

3838

3939
@lowering.wraps(stmts.ceil)
40-
def ceil(x: float) -> float: ...
40+
def ceil(x: float) -> int: ...
4141

4242

4343
@lowering.wraps(stmts.copysign)
@@ -77,7 +77,7 @@ def fabs(x: float) -> float: ...
7777

7878

7979
@lowering.wraps(stmts.floor)
80-
def floor(x: float) -> float: ...
80+
def floor(x: float) -> int: ...
8181

8282

8383
@lowering.wraps(stmts.fmod)
@@ -149,7 +149,7 @@ def tanh(x: float) -> float: ...
149149

150150

151151
@lowering.wraps(stmts.trunc)
152-
def trunc(x: float) -> float: ...
152+
def trunc(x: float) -> int: ...
153153

154154

155155
@lowering.wraps(stmts.ulp)

src/kirin/dialects/math/_gen.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ class MathMethodTable(MethodTable):
114114
for name, obj, sig in builtin_math_functions():
115115
if "is" in name:
116116
ret_type = "bool"
117+
elif name in {"trunc", "ceil", "floor"}:
118+
ret_type = "int"
117119
else:
118120
ret_type = "float"
119121
f.write(textwrap.dedent(f"""

src/kirin/dialects/math/stmts.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class ceil(ir.Statement):
7272
name = "ceil"
7373
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
7474
x: ir.SSAValue = info.argument(types.Float)
75-
result: ir.ResultValue = info.result(types.Float)
75+
result: ir.ResultValue = info.result(types.Int)
7676

7777

7878
@statement(dialect=dialect)
@@ -173,7 +173,7 @@ class floor(ir.Statement):
173173
name = "floor"
174174
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
175175
x: ir.SSAValue = info.argument(types.Float)
176-
result: ir.ResultValue = info.result(types.Float)
176+
result: ir.ResultValue = info.result(types.Int)
177177

178178

179179
@statement(dialect=dialect)
@@ -356,7 +356,7 @@ class trunc(ir.Statement):
356356
name = "trunc"
357357
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
358358
x: ir.SSAValue = info.argument(types.Float)
359-
result: ir.ResultValue = info.result(types.Float)
359+
result: ir.ResultValue = info.result(types.Int)
360360

361361

362362
@statement(dialect=dialect)

0 commit comments

Comments
 (0)