Skip to content

Commit 99c988a

Browse files
committed
cuda stride conversion
1 parent 8564970 commit 99c988a

File tree

1 file changed

+37
-37
lines changed
  • crates/luminal_cuda/src/block

1 file changed

+37
-37
lines changed

crates/luminal_cuda/src/block/ops.rs

Lines changed: 37 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -49,9 +49,9 @@ impl EgglogOp for RowAdd {
4949
(= ?sa (Add ?shape ?a ?a_stride ?b ?b_stride ?out_stride))
5050
(= ?row_width (nth_from_end ?shape 0))
5151
; assert the row is contiguous
52-
(= (MIter) (nth_from_end ?a_stride 0))
53-
(= (MIter) (nth_from_end ?b_stride 0))
54-
(= (MIter) (nth_from_end ?out_stride 0))
52+
(= (MNum 1) (nth_from_end ?a_stride 0))
53+
(= (MNum 1) (nth_from_end ?b_stride 0))
54+
(= (MNum 1) (nth_from_end ?out_stride 0))
5555
;(= (F32) (dtype ?a))
5656
)
5757
(
@@ -167,34 +167,34 @@ impl EgglogOp for RowSwishMul {
167167
(= ?sigmoid (Sigmoid
168168
(ECons ?batch (ECons ?width (ENil)))
169169
?self
170-
(ECons (MMul (MIter) ?width) (ECons (MIter) (ENil)))
171-
(ECons (MMul (MIter) ?width) (ECons (MIter) (ENil)))
170+
(ECons ?width (ECons (MNum 1) (ENil)))
171+
(ECons ?width (ECons (MNum 1) (ENil)))
172172
))
173173
(= ?swish (Mul
174174
(ECons ?batch (ECons ?width (ENil)))
175175
?self
176-
(ECons (MMul (MIter) ?width) (ECons (MIter) (ENil)))
176+
(ECons ?width (ECons (MNum 1) (ENil)))
177177
?sigmoid
178-
(ECons (MMul (MIter) ?width) (ECons (MIter) (ENil)))
179-
(ECons (MMul (MIter) ?width) (ECons (MIter) (ENil)))
178+
(ECons ?width (ECons (MNum 1) (ENil)))
179+
(ECons ?width (ECons (MNum 1) (ENil)))
180180
))
181181
(= ?swishmul (Mul
182182
(ECons ?batch (ECons ?width (ENil)))
183183
?swish
184-
(ECons (MMul (MIter) ?width) (ECons (MIter) (ENil)))
184+
(ECons ?width (ECons (MNum 1) (ENil)))
185185
?other
186-
(ECons (MMul (MIter) ?width) (ECons (MIter) (ENil)))
187-
(ECons (MMul (MIter) ?width) (ECons (MIter) (ENil)))
186+
(ECons ?width (ECons (MNum 1) (ENil)))
187+
(ECons ?width (ECons (MNum 1) (ENil)))
188188
))
189189
;(= (F32) (dtype ?self))
190190
)
191191
(
192192
(let ?rsm (RowSwishMul
193193
(ECons ?batch (ENil))
194194
?self
195-
(ECons (MMul (MIter) ?width) (ENil))
195+
(ECons ?width (ENil))
196196
?other
197-
(ECons (MMul (MIter) ?width) (ENil))
197+
(ECons ?width (ENil))
198198
?width
199199
))
200200
(union ?swishmul ?rsm)
@@ -312,52 +312,52 @@ impl EgglogOp for RowRMSNorm {
312312
(ECons ?batch (ENil))
313313
?width
314314
?square
315-
(ECons (MMul (MIter) ?width) (ENil))
316-
(MIter)
317-
(ECons (MIter) (ENil))
315+
(ECons ?width (ENil))
316+
(MNum 1)
317+
(ECons (MNum 1) (ENil))
318318
)
319319
)
320320
(= ?inv_div_factor
321321
(Recip (ECons ?batch (ENil)) (Cast (Iota ?width (MNum 1)) (F32))
322322
(ECons (MNum 0) (ENil)) ; broadcast the constant
323-
(ECons (MIter) (ENil)))) ; produce per-batch vector
323+
(ECons (MNum 1) (ENil)))) ; produce per-batch vector
324324
325325
(= ?mean
326326
(Mul (ECons ?batch (ENil))
327-
?square_summed (ECons (MIter) (ENil))
328-
?inv_div_factor (ECons (MIter) (ENil))
329-
(ECons (MIter) (ENil))))
327+
?square_summed (ECons (MNum 1) (ENil))
328+
?inv_div_factor (ECons (MNum 1) (ENil))
329+
(ECons (MNum 1) (ENil))))
330330
(= ?eps_add
331331
(Add
332332
(ECons ?batch (ENil))
333333
?mean
334-
(ECons (MIter) (ENil))
334+
(ECons (MNum 1) (ENil))
335335
(Constant ?eps)
336336
(ECons (MNum 0) (ENil))
337-
(ECons (MIter) (ENil))
337+
(ECons (MNum 1) (ENil))
338338
)
339339
)
340340
(= ?sqrt
341341
(Sqrt
342342
(ECons ?batch (ENil))
343343
?eps_add
344-
(ECons (MIter) (ENil))
345-
(ECons (MIter) (ENil))
344+
(ECons (MNum 1) (ENil))
345+
(ECons (MNum 1) (ENil))
346346
)
347347
)
348348
(= ?recip
349349
(Recip
350350
(ECons ?batch (ENil))
351351
?sqrt
352-
(ECons (MIter) (ENil))
353-
(ECons (MIter) (ENil))
352+
(ECons (MNum 1) (ENil))
353+
(ECons (MNum 1) (ENil))
354354
)
355355
)
356356
(= ?std_normed
357357
(Mul
358358
?inp_range
359359
?recip
360-
(ECons (MIter) (ECons (MNum 0) (ENil)))
360+
(ECons (MNum 1) (ECons (MNum 0) (ENil)))
361361
?x
362362
?inp_stride
363363
?inp_stride
@@ -369,7 +369,7 @@ impl EgglogOp for RowRMSNorm {
369369
?std_normed
370370
?inp_stride
371371
?weight
372-
(ECons (MNum 0) (ECons (MIter) (ENil)))
372+
(ECons (MNum 0) (ECons (MNum 1) (ENil)))
373373
?inp_stride
374374
)
375375
)
@@ -380,7 +380,7 @@ impl EgglogOp for RowRMSNorm {
380380
(RowRMSNorm
381381
(ECons ?batch (ENil))
382382
?x
383-
(ECons (MMul (MIter) ?width) (ENil))
383+
(ECons ?width (ENil))
384384
?width
385385
?weight
386386
)
@@ -673,10 +673,10 @@ impl EgglogOp for TileMatmul {
673673
; get tile sum
674674
(= ?ts (TileSum ?sum_shape ?untiled_sum_shape ?iters ?cm ?sum_in_stride ?sum_in_m_stride ?sum_in_n_stride ?sum_in_k_stride ?sum_out_stride ?sum_out_m_stride ?sum_out_n_stride))
675675
; assert k stride on the intermediate is 1
676-
(= ?out_k_stride (MIter))
677-
(= ?sum_in_k_stride (MIter))
676+
(= ?out_k_stride (MNum 1))
677+
(= ?sum_in_k_stride (MNum 1))
678678
; assert matmul strides
679-
(= ?b_n_stride (MIter))
679+
(= ?b_n_stride (MNum 1))
680680
; get dimensions
681681
(= ?t_n (nth_from_end ?mul_shape 1))
682682
(= ?t_k (nth_from_end ?mul_shape 0))
@@ -706,14 +706,14 @@ impl EgglogOp for TileMatmul {
706706
?sum_out_stride ?sum_out_m_stride ?sum_out_n_stride))
707707
708708
; assert k stride on the intermediate is 1 (contiguous)
709-
(= ?out_k_stride (MIter))
710-
(= ?sum_in_k_stride (MIter))
709+
(= ?out_k_stride (MNum 1))
710+
(= ?sum_in_k_stride (MNum 1))
711711
712712
; A row-major (contiguous in its last dim k)
713-
(= ?a_k_stride (MIter))
713+
(= ?a_k_stride (MNum 1))
714714
715715
; B col-major (contiguous in its first dim k)
716-
(= ?b_k_stride (MIter))
716+
(= ?b_k_stride (MNum 1))
717717
718718
; get tile dims
719719
(= ?t_n (nth_from_end ?mul_shape 1))
@@ -731,7 +731,7 @@ impl EgglogOp for TileMatmul {
731731
; - C row-major tile strides: m -> t_n*32, n -> 1
732732
(let ?tm (TileMatmul ?sum_shape ?untiled_sum_shape ?iters
733733
?a ?new_a_stride (MMul ?t_k (MNum 32)) (MNum 1)
734-
?b ?new_b_stride (MReplace ?b_k_stride (MIter) (MNum 1)) (MMul ?t_k (MNum 32))
734+
?b ?new_b_stride ?b_k_stride (MMul ?t_k (MNum 32))
735735
?sum_out_stride (MMul ?t_n (MNum 32)) (MNum 1)))
736736
(union ?ts ?tm)
737737
(set (dtype ?tm) (F32))

0 commit comments

Comments
 (0)