@@ -49,9 +49,9 @@ impl EgglogOp for RowAdd {
4949 (= ?sa (Add ?shape ?a ?a_stride ?b ?b_stride ?out_stride))
5050 (= ?row_width (nth_from_end ?shape 0))
5151 ; assert the row is contiguous
52- (= (MIter ) (nth_from_end ?a_stride 0))
53- (= (MIter ) (nth_from_end ?b_stride 0))
54- (= (MIter ) (nth_from_end ?out_stride 0))
52+ (= (MNum 1 ) (nth_from_end ?a_stride 0))
53+ (= (MNum 1 ) (nth_from_end ?b_stride 0))
54+ (= (MNum 1 ) (nth_from_end ?out_stride 0))
5555 ;(= (F32) (dtype ?a))
5656 )
5757 (
@@ -167,34 +167,34 @@ impl EgglogOp for RowSwishMul {
167167 (= ?sigmoid (Sigmoid
168168 (ECons ?batch (ECons ?width (ENil)))
169169 ?self
170- (ECons (MMul (MIter) ?width) (ECons (MIter ) (ENil)))
171- (ECons (MMul (MIter) ?width) (ECons (MIter ) (ENil)))
170+ (ECons ?width (ECons (MNum 1 ) (ENil)))
171+ (ECons ?width (ECons (MNum 1 ) (ENil)))
172172 ))
173173 (= ?swish (Mul
174174 (ECons ?batch (ECons ?width (ENil)))
175175 ?self
176- (ECons (MMul (MIter) ?width) (ECons (MIter ) (ENil)))
176+ (ECons ?width (ECons (MNum 1 ) (ENil)))
177177 ?sigmoid
178- (ECons (MMul (MIter) ?width) (ECons (MIter ) (ENil)))
179- (ECons (MMul (MIter) ?width) (ECons (MIter ) (ENil)))
178+ (ECons ?width (ECons (MNum 1 ) (ENil)))
179+ (ECons ?width (ECons (MNum 1 ) (ENil)))
180180 ))
181181 (= ?swishmul (Mul
182182 (ECons ?batch (ECons ?width (ENil)))
183183 ?swish
184- (ECons (MMul (MIter) ?width) (ECons (MIter ) (ENil)))
184+ (ECons ?width (ECons (MNum 1 ) (ENil)))
185185 ?other
186- (ECons (MMul (MIter) ?width) (ECons (MIter ) (ENil)))
187- (ECons (MMul (MIter) ?width) (ECons (MIter ) (ENil)))
186+ (ECons ?width (ECons (MNum 1 ) (ENil)))
187+ (ECons ?width (ECons (MNum 1 ) (ENil)))
188188 ))
189189 ;(= (F32) (dtype ?self))
190190 )
191191 (
192192 (let ?rsm (RowSwishMul
193193 (ECons ?batch (ENil))
194194 ?self
195- (ECons (MMul (MIter) ?width) (ENil))
195+ (ECons ?width (ENil))
196196 ?other
197- (ECons (MMul (MIter) ?width) (ENil))
197+ (ECons ?width (ENil))
198198 ?width
199199 ))
200200 (union ?swishmul ?rsm)
@@ -312,52 +312,52 @@ impl EgglogOp for RowRMSNorm {
312312 (ECons ?batch (ENil))
313313 ?width
314314 ?square
315- (ECons (MMul (MIter) ?width) (ENil))
316- (MIter )
317- (ECons (MIter ) (ENil))
315+ (ECons ?width (ENil))
316+ (MNum 1 )
317+ (ECons (MNum 1 ) (ENil))
318318 )
319319 )
320320 (= ?inv_div_factor
321321 (Recip (ECons ?batch (ENil)) (Cast (Iota ?width (MNum 1)) (F32))
322322 (ECons (MNum 0) (ENil)) ; broadcast the constant
323- (ECons (MIter ) (ENil)))) ; produce per-batch vector
323+ (ECons (MNum 1 ) (ENil)))) ; produce per-batch vector
324324
325325 (= ?mean
326326 (Mul (ECons ?batch (ENil))
327- ?square_summed (ECons (MIter ) (ENil))
328- ?inv_div_factor (ECons (MIter ) (ENil))
329- (ECons (MIter ) (ENil))))
327+ ?square_summed (ECons (MNum 1 ) (ENil))
328+ ?inv_div_factor (ECons (MNum 1 ) (ENil))
329+ (ECons (MNum 1 ) (ENil))))
330330 (= ?eps_add
331331 (Add
332332 (ECons ?batch (ENil))
333333 ?mean
334- (ECons (MIter ) (ENil))
334+ (ECons (MNum 1 ) (ENil))
335335 (Constant ?eps)
336336 (ECons (MNum 0) (ENil))
337- (ECons (MIter ) (ENil))
337+ (ECons (MNum 1 ) (ENil))
338338 )
339339 )
340340 (= ?sqrt
341341 (Sqrt
342342 (ECons ?batch (ENil))
343343 ?eps_add
344- (ECons (MIter ) (ENil))
345- (ECons (MIter ) (ENil))
344+ (ECons (MNum 1 ) (ENil))
345+ (ECons (MNum 1 ) (ENil))
346346 )
347347 )
348348 (= ?recip
349349 (Recip
350350 (ECons ?batch (ENil))
351351 ?sqrt
352- (ECons (MIter ) (ENil))
353- (ECons (MIter ) (ENil))
352+ (ECons (MNum 1 ) (ENil))
353+ (ECons (MNum 1 ) (ENil))
354354 )
355355 )
356356 (= ?std_normed
357357 (Mul
358358 ?inp_range
359359 ?recip
360- (ECons (MIter ) (ECons (MNum 0) (ENil)))
360+ (ECons (MNum 1 ) (ECons (MNum 0) (ENil)))
361361 ?x
362362 ?inp_stride
363363 ?inp_stride
@@ -369,7 +369,7 @@ impl EgglogOp for RowRMSNorm {
369369 ?std_normed
370370 ?inp_stride
371371 ?weight
372- (ECons (MNum 0) (ECons (MIter ) (ENil)))
372+ (ECons (MNum 0) (ECons (MNum 1 ) (ENil)))
373373 ?inp_stride
374374 )
375375 )
@@ -380,7 +380,7 @@ impl EgglogOp for RowRMSNorm {
380380 (RowRMSNorm
381381 (ECons ?batch (ENil))
382382 ?x
383- (ECons (MMul (MIter) ?width) (ENil))
383+ (ECons ?width (ENil))
384384 ?width
385385 ?weight
386386 )
@@ -673,10 +673,10 @@ impl EgglogOp for TileMatmul {
673673 ; get tile sum
674674 (= ?ts (TileSum ?sum_shape ?untiled_sum_shape ?iters ?cm ?sum_in_stride ?sum_in_m_stride ?sum_in_n_stride ?sum_in_k_stride ?sum_out_stride ?sum_out_m_stride ?sum_out_n_stride))
675675 ; assert k stride on the intermediate is 1
676- (= ?out_k_stride (MIter ))
677- (= ?sum_in_k_stride (MIter ))
676+ (= ?out_k_stride (MNum 1 ))
677+ (= ?sum_in_k_stride (MNum 1 ))
678678 ; assert matmul strides
679- (= ?b_n_stride (MIter ))
679+ (= ?b_n_stride (MNum 1 ))
680680 ; get dimensions
681681 (= ?t_n (nth_from_end ?mul_shape 1))
682682 (= ?t_k (nth_from_end ?mul_shape 0))
@@ -706,14 +706,14 @@ impl EgglogOp for TileMatmul {
706706 ?sum_out_stride ?sum_out_m_stride ?sum_out_n_stride))
707707
708708 ; assert k stride on the intermediate is 1 (contiguous)
709- (= ?out_k_stride (MIter ))
710- (= ?sum_in_k_stride (MIter ))
709+ (= ?out_k_stride (MNum 1 ))
710+ (= ?sum_in_k_stride (MNum 1 ))
711711
712712 ; A row-major (contiguous in its last dim k)
713- (= ?a_k_stride (MIter ))
713+ (= ?a_k_stride (MNum 1 ))
714714
715715 ; B col-major (contiguous in its first dim k)
716- (= ?b_k_stride (MIter ))
716+ (= ?b_k_stride (MNum 1 ))
717717
718718 ; get tile dims
719719 (= ?t_n (nth_from_end ?mul_shape 1))
@@ -731,7 +731,7 @@ impl EgglogOp for TileMatmul {
731731 ; - C row-major tile strides: m -> t_n*32, n -> 1
732732 (let ?tm (TileMatmul ?sum_shape ?untiled_sum_shape ?iters
733733 ?a ?new_a_stride (MMul ?t_k (MNum 32)) (MNum 1)
734- ?b ?new_b_stride (MReplace ?b_k_stride (MIter) (MNum 1)) (MMul ?t_k (MNum 32))
734+ ?b ?new_b_stride ?b_k_stride (MMul ?t_k (MNum 32))
735735 ?sum_out_stride (MMul ?t_n (MNum 32)) (MNum 1)))
736736 (union ?ts ?tm)
737737 (set (dtype ?tm) (F32))
0 commit comments