Skip to content

Commit 583718f

Browse files
Advaith AnandAdvaith Anand
authored andcommitted
addressed linting errors with precommit
1 parent ae2357b commit 583718f

3 files changed

Lines changed: 201 additions & 160 deletions

File tree

src/kirin/dialects/math/__init__.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,120 +1,160 @@
1-
"math dialect, modeling functions in python's `math` stdlib"# This file is generated by gen.py
1+
"math dialect, modeling functions in python's `math` stdlib" # This file is generated by gen.py
2+
3+
import math as pymath
4+
5+
from kirin import lowering
26
from kirin.dialects.math.dialect import dialect as dialect
7+
38
from . import stmts as stmts, interp as interp
4-
import math as pymath
9+
510
pi = pymath.pi
611
e = pymath.e
712
tau = pymath.tau
8-
from kirin import lowering
13+
914

1015
@lowering.wraps(stmts.acos)
1116
def acos(x: float) -> float: ...
1217

18+
1319
@lowering.wraps(stmts.asin)
1420
def asin(x: float) -> float: ...
1521

22+
1623
@lowering.wraps(stmts.asinh)
1724
def asinh(x: float) -> float: ...
1825

26+
1927
@lowering.wraps(stmts.atan)
2028
def atan(x: float) -> float: ...
2129

30+
2231
@lowering.wraps(stmts.atan2)
2332
def atan2(y: float, x: float) -> float: ...
2433

34+
2535
@lowering.wraps(stmts.atanh)
2636
def atanh(x: float) -> float: ...
2737

38+
2839
@lowering.wraps(stmts.ceil)
2940
def ceil(x: float) -> int: ...
3041

42+
3143
@lowering.wraps(stmts.copysign)
3244
def copysign(x: float, y: float) -> float: ...
3345

46+
3447
@lowering.wraps(stmts.cos)
3548
def cos(x: float) -> float: ...
3649

50+
3751
@lowering.wraps(stmts.cosh)
3852
def cosh(x: float) -> float: ...
3953

54+
4055
@lowering.wraps(stmts.degrees)
4156
def degrees(x: float) -> float: ...
4257

58+
4359
@lowering.wraps(stmts.erf)
4460
def erf(x: float) -> float: ...
4561

62+
4663
@lowering.wraps(stmts.erfc)
4764
def erfc(x: float) -> float: ...
4865

66+
4967
@lowering.wraps(stmts.exp)
5068
def exp(x: float) -> float: ...
5169

70+
5271
@lowering.wraps(stmts.expm1)
5372
def expm1(x: float) -> float: ...
5473

74+
5575
@lowering.wraps(stmts.fabs)
5676
def fabs(x: float) -> float: ...
5777

78+
5879
@lowering.wraps(stmts.floor)
5980
def floor(x: float) -> int: ...
6081

82+
6183
@lowering.wraps(stmts.fmod)
6284
def fmod(x: float, y: float) -> float: ...
6385

86+
6487
@lowering.wraps(stmts.gamma)
6588
def gamma(x: float) -> float: ...
6689

90+
6791
@lowering.wraps(stmts.isfinite)
6892
def isfinite(x: float) -> bool: ...
6993

94+
7095
@lowering.wraps(stmts.isinf)
7196
def isinf(x: float) -> bool: ...
7297

98+
7399
@lowering.wraps(stmts.isnan)
74100
def isnan(x: float) -> bool: ...
75101

102+
76103
@lowering.wraps(stmts.lgamma)
77104
def lgamma(x: float) -> float: ...
78105

106+
79107
@lowering.wraps(stmts.log)
80108
def log(x: float, base: float) -> float: ...
81109

110+
82111
@lowering.wraps(stmts.log10)
83112
def log10(x: float) -> float: ...
84113

114+
85115
@lowering.wraps(stmts.log1p)
86116
def log1p(x: float) -> float: ...
87117

118+
88119
@lowering.wraps(stmts.log2)
89120
def log2(x: float) -> float: ...
90121

122+
91123
@lowering.wraps(stmts.pow)
92124
def pow(x: float, y: float) -> float: ...
93125

126+
94127
@lowering.wraps(stmts.radians)
95128
def radians(x: float) -> float: ...
96129

130+
97131
@lowering.wraps(stmts.remainder)
98132
def remainder(x: float, y: float) -> float: ...
99133

134+
100135
@lowering.wraps(stmts.sin)
101136
def sin(x: float) -> float: ...
102137

138+
103139
@lowering.wraps(stmts.sinh)
104140
def sinh(x: float) -> float: ...
105141

142+
106143
@lowering.wraps(stmts.sqrt)
107144
def sqrt(x: float) -> float: ...
108145

146+
109147
@lowering.wraps(stmts.tan)
110148
def tan(x: float) -> float: ...
111149

150+
112151
@lowering.wraps(stmts.tanh)
113152
def tanh(x: float) -> float: ...
114153

154+
115155
@lowering.wraps(stmts.trunc)
116156
def trunc(x: float) -> int: ...
117157

158+
118159
@lowering.wraps(stmts.ulp)
119160
def ulp(x: float) -> float: ...
120-

src/kirin/dialects/math/interp.py

Lines changed: 3 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# This file is generated by gen.py
22
import math
3-
from kirin.dialects.math.dialect import dialect
3+
4+
from kirin.interp import Frame, MethodTable, impl
45
from kirin.dialects.math import stmts
5-
from kirin.interp import MethodTable, Frame, impl
6+
from kirin.dialects.math.dialect import dialect
67

78

89
@dialect.register
@@ -13,217 +14,181 @@ def acos(self, interp, frame: Frame, stmt: stmts.acos):
1314
values = frame.get_values(stmt.args)
1415
return (math.acos(values[0]),)
1516

16-
1717
@impl(stmts.asin)
1818
def asin(self, interp, frame: Frame, stmt: stmts.asin):
1919
values = frame.get_values(stmt.args)
2020
return (math.asin(values[0]),)
2121

22-
2322
@impl(stmts.asinh)
2423
def asinh(self, interp, frame: Frame, stmt: stmts.asinh):
2524
values = frame.get_values(stmt.args)
2625
return (math.asinh(values[0]),)
2726

28-
2927
@impl(stmts.atan)
3028
def atan(self, interp, frame: Frame, stmt: stmts.atan):
3129
values = frame.get_values(stmt.args)
3230
return (math.atan(values[0]),)
3331

34-
3532
@impl(stmts.atan2)
3633
def atan2(self, interp, frame: Frame, stmt: stmts.atan2):
3734
values = frame.get_values(stmt.args)
3835
return (math.atan2(values[0], values[1]),)
3936

40-
4137
@impl(stmts.atanh)
4238
def atanh(self, interp, frame: Frame, stmt: stmts.atanh):
4339
values = frame.get_values(stmt.args)
4440
return (math.atanh(values[0]),)
4541

46-
4742
@impl(stmts.ceil)
4843
def ceil(self, interp, frame: Frame, stmt: stmts.ceil):
4944
values = frame.get_values(stmt.args)
5045
return (math.ceil(values[0]),)
5146

52-
5347
@impl(stmts.copysign)
5448
def copysign(self, interp, frame: Frame, stmt: stmts.copysign):
5549
values = frame.get_values(stmt.args)
5650
return (math.copysign(values[0], values[1]),)
5751

58-
5952
@impl(stmts.cos)
6053
def cos(self, interp, frame: Frame, stmt: stmts.cos):
6154
values = frame.get_values(stmt.args)
6255
return (math.cos(values[0]),)
6356

64-
6557
@impl(stmts.cosh)
6658
def cosh(self, interp, frame: Frame, stmt: stmts.cosh):
6759
values = frame.get_values(stmt.args)
6860
return (math.cosh(values[0]),)
6961

70-
7162
@impl(stmts.degrees)
7263
def degrees(self, interp, frame: Frame, stmt: stmts.degrees):
7364
values = frame.get_values(stmt.args)
7465
return (math.degrees(values[0]),)
7566

76-
7767
@impl(stmts.erf)
7868
def erf(self, interp, frame: Frame, stmt: stmts.erf):
7969
values = frame.get_values(stmt.args)
8070
return (math.erf(values[0]),)
8171

82-
8372
@impl(stmts.erfc)
8473
def erfc(self, interp, frame: Frame, stmt: stmts.erfc):
8574
values = frame.get_values(stmt.args)
8675
return (math.erfc(values[0]),)
8776

88-
8977
@impl(stmts.exp)
9078
def exp(self, interp, frame: Frame, stmt: stmts.exp):
9179
values = frame.get_values(stmt.args)
9280
return (math.exp(values[0]),)
9381

94-
9582
@impl(stmts.expm1)
9683
def expm1(self, interp, frame: Frame, stmt: stmts.expm1):
9784
values = frame.get_values(stmt.args)
9885
return (math.expm1(values[0]),)
9986

100-
10187
@impl(stmts.fabs)
10288
def fabs(self, interp, frame: Frame, stmt: stmts.fabs):
10389
values = frame.get_values(stmt.args)
10490
return (math.fabs(values[0]),)
10591

106-
10792
@impl(stmts.floor)
10893
def floor(self, interp, frame: Frame, stmt: stmts.floor):
10994
values = frame.get_values(stmt.args)
11095
return (math.floor(values[0]),)
11196

112-
11397
@impl(stmts.fmod)
11498
def fmod(self, interp, frame: Frame, stmt: stmts.fmod):
11599
values = frame.get_values(stmt.args)
116100
return (math.fmod(values[0], values[1]),)
117101

118-
119102
@impl(stmts.gamma)
120103
def gamma(self, interp, frame: Frame, stmt: stmts.gamma):
121104
values = frame.get_values(stmt.args)
122105
return (math.gamma(values[0]),)
123106

124-
125107
@impl(stmts.isfinite)
126108
def isfinite(self, interp, frame: Frame, stmt: stmts.isfinite):
127109
values = frame.get_values(stmt.args)
128110
return (math.isfinite(values[0]),)
129111

130-
131112
@impl(stmts.isinf)
132113
def isinf(self, interp, frame: Frame, stmt: stmts.isinf):
133114
values = frame.get_values(stmt.args)
134115
return (math.isinf(values[0]),)
135116

136-
137117
@impl(stmts.isnan)
138118
def isnan(self, interp, frame: Frame, stmt: stmts.isnan):
139119
values = frame.get_values(stmt.args)
140120
return (math.isnan(values[0]),)
141121

142-
143122
@impl(stmts.lgamma)
144123
def lgamma(self, interp, frame: Frame, stmt: stmts.lgamma):
145124
values = frame.get_values(stmt.args)
146125
return (math.lgamma(values[0]),)
147126

148-
149127
@impl(stmts.log)
150128
def log(self, interp, frame: Frame, stmt: stmts.log):
151129
values = frame.get_values(stmt.args)
152130
return (math.log(values[0], values[1]),)
153131

154-
155132
@impl(stmts.log10)
156133
def log10(self, interp, frame: Frame, stmt: stmts.log10):
157134
values = frame.get_values(stmt.args)
158135
return (math.log10(values[0]),)
159136

160-
161137
@impl(stmts.log1p)
162138
def log1p(self, interp, frame: Frame, stmt: stmts.log1p):
163139
values = frame.get_values(stmt.args)
164140
return (math.log1p(values[0]),)
165141

166-
167142
@impl(stmts.log2)
168143
def log2(self, interp, frame: Frame, stmt: stmts.log2):
169144
values = frame.get_values(stmt.args)
170145
return (math.log2(values[0]),)
171146

172-
173147
@impl(stmts.pow)
174148
def pow(self, interp, frame: Frame, stmt: stmts.pow):
175149
values = frame.get_values(stmt.args)
176150
return (math.pow(values[0], values[1]),)
177151

178-
179152
@impl(stmts.radians)
180153
def radians(self, interp, frame: Frame, stmt: stmts.radians):
181154
values = frame.get_values(stmt.args)
182155
return (math.radians(values[0]),)
183156

184-
185157
@impl(stmts.remainder)
186158
def remainder(self, interp, frame: Frame, stmt: stmts.remainder):
187159
values = frame.get_values(stmt.args)
188160
return (math.remainder(values[0], values[1]),)
189161

190-
191162
@impl(stmts.sin)
192163
def sin(self, interp, frame: Frame, stmt: stmts.sin):
193164
values = frame.get_values(stmt.args)
194165
return (math.sin(values[0]),)
195166

196-
197167
@impl(stmts.sinh)
198168
def sinh(self, interp, frame: Frame, stmt: stmts.sinh):
199169
values = frame.get_values(stmt.args)
200170
return (math.sinh(values[0]),)
201171

202-
203172
@impl(stmts.sqrt)
204173
def sqrt(self, interp, frame: Frame, stmt: stmts.sqrt):
205174
values = frame.get_values(stmt.args)
206175
return (math.sqrt(values[0]),)
207176

208-
209177
@impl(stmts.tan)
210178
def tan(self, interp, frame: Frame, stmt: stmts.tan):
211179
values = frame.get_values(stmt.args)
212180
return (math.tan(values[0]),)
213181

214-
215182
@impl(stmts.tanh)
216183
def tanh(self, interp, frame: Frame, stmt: stmts.tanh):
217184
values = frame.get_values(stmt.args)
218185
return (math.tanh(values[0]),)
219186

220-
221187
@impl(stmts.trunc)
222188
def trunc(self, interp, frame: Frame, stmt: stmts.trunc):
223189
values = frame.get_values(stmt.args)
224190
return (math.trunc(values[0]),)
225191

226-
227192
@impl(stmts.ulp)
228193
def ulp(self, interp, frame: Frame, stmt: stmts.ulp):
229194
values = frame.get_values(stmt.args)

0 commit comments

Comments
 (0)