@@ -31,7 +31,7 @@ fn intsqrt[n: Int]() -> Int:
3131
3232
3333@register_passable (" trivial" )
34- struct Layout (Copyable , Movable , Writable ):
34+ struct Layout (Copyable , Writable ):
3535 var shape : IndexList[2 ]
3636 var strides : IndexList[2 ]
3737
@@ -57,31 +57,31 @@ struct Layout(Copyable, Movable, Writable):
5757
5858
5959struct Matrix[Type: DType]:
60- var data : UnsafePointer[Scalar[Type], MutAnyOrigin]
60+ var data : UnsafePointer[Scalar[Self. Type], MutAnyOrigin]
6161 var layout : Layout
6262
6363 fn __init__ (out self , shape : Tuple[Int, Int]):
64- self .data = alloc[Scalar[Type]](shape[0 ] * shape[1 ])
64+ self .data = alloc[Scalar[Self. Type]](shape[0 ] * shape[1 ])
6565 self .layout = Layout(shape)
6666
6767 @always_inline (" nodebug" )
6868 fn __init__ (
69- out self , data : UnsafePointer[Scalar[Type], MutAnyOrigin], var layout : Layout
69+ out self , data : UnsafePointer[Scalar[Self. Type], MutAnyOrigin], var layout : Layout
7070 ):
7171 self .data = data
7272 self .layout = layout
7373
7474 @always_inline (" nodebug" )
7575 fn __init__ (
76- out self , data : UnsafePointer[Scalar[Type], MutAnyOrigin], shape : Tuple[Int, Int]
76+ out self , data : UnsafePointer[Scalar[Self. Type], MutAnyOrigin], shape : Tuple[Int, Int]
7777 ):
7878 self .data = data
7979 self .layout = Layout(shape)
8080
8181 @always_inline (" nodebug" )
8282 fn __getitem__ (
8383 ref [_]self , i : Int, j : Int
84- ) -> ref [origin_of(self )] Scalar[Type]:
84+ ) -> ref [origin_of(self )] Scalar[Self. Type]:
8585 var offset = self .layout(i, j)
8686 return (self .data + offset)[]
8787
@@ -104,7 +104,7 @@ struct Matrix[Type: DType]:
104104 random.rand(self .data, self .layout.size())
105105
106106 @always_inline (" nodebug" )
107- fn load [width : Int, *, dim : Int](self , i : Int, j : Int) -> SIMD [Type, width]:
107+ fn load [width : Int, *, dim : Int](self , i : Int, j : Int) -> SIMD [Self. Type, width]:
108108 var offset = self .layout(i, j)
109109 var ptr = self .data + offset
110110
@@ -117,7 +117,7 @@ struct Matrix[Type: DType]:
117117 @always_inline (" nodebug" )
118118 fn store [
119119 width : Int, *, dim : Int
120- ](self , value : SIMD [Type, width], i : Int, j : Int):
120+ ](self , value : SIMD [Self. Type, width], i : Int, j : Int):
121121 var offset = self .layout(i, j)
122122 var ptr = self .data + offset
123123
@@ -149,19 +149,20 @@ fn pack_A[
149149 fn pack_panel (idx : Int):
150150 var i = idx * mr
151151 # for i in range(0, Ac.shape[0](), mr):
152+ var Ac_stride = Ac.stride[0 ]()
152153 var dst_ptr = Ac_buffer + i * Ac.shape[1 ]()
153- var src_ptr = Ac.data + i * Ac.stride[ 0 ]()
154+ var src_ptr = Ac.data + i * Ac_stride
154155 for _ in range (Ac.shape[1 ]()):
155156
156157 @parameter
157- fn pack_col [width : Int](l : Int):
158+ fn pack_col [width : Int](l : Int) unified { mut } :
158159 (dst_ptr + l).store(
159- (src_ptr + l * Ac.stride[ 0 ]() ).strided_load[
160+ (src_ptr + l * Ac_stride ).strided_load[
160161 width=width
161- ](Ac.stride[ 0 ]() ),
162+ ](Ac_stride ),
162163 )
163164
164- vectorize[pack_col, simd_width_of[Type]()](min (Ac.shape[0 ]() - i, mr))
165+ vectorize[simd_width_of[Type]()](min (Ac.shape[0 ]() - i, mr), pack_col )
165166
166167 for l in range (min (Ac.shape[0 ]() - i, mr), mr):
167168 dst_ptr[l] = Scalar[Type](0 )
@@ -189,18 +190,17 @@ fn pack_B[
189190 for _ in range (Bc.shape[0 ]()):
190191
191192 @parameter
192- fn pack_row [width : Int](l : Int):
193+ fn pack_row [width : Int](l : Int) unified { mut } :
193194 (dst_ptr + l).store[
194195 alignment = size_of[Type]() * simd_width_of[Type]()
195196 ](
196197 (src_ptr + l).load[width=width](),
197198 )
198199
199200 vectorize[
200- pack_row,
201201 simd_width_of[Type](),
202202 unroll_factor = nr // simd_width_of[Type](),
203- ](min (Bc.shape[1 ]() - i, nr))
203+ ](min (Bc.shape[1 ]() - i, nr), pack_row )
204204
205205 for l in range (min (Bc.shape[1 ]() - i, nr), nr):
206206 dst_ptr[l] = Scalar[Type](0 )
@@ -367,12 +367,12 @@ fn micro_kernel[
367367 if i < Cr.shape[0 ]():
368368
369369 @parameter
370- fn load_col [width : Int](j : Int):
370+ fn load_col [width : Int](j : Int) unified { mut } :
371371 (cr_ptr + (i * nr + j)).store(
372372 (Cr_ptr + (i * Cr.stride[0 ]() + j)).load[width=width](),
373373 )
374374
375- vectorize[load_col, simd_width](Cr.shape[1 ]())
375+ vectorize[simd_width](Cr.shape[1 ](), load_col )
376376 else :
377377
378378 @parameter
@@ -418,12 +418,12 @@ fn micro_kernel[
418418 if i < Cr.shape[0 ]():
419419
420420 @parameter
421- fn store_row [width : Int](j : Int):
421+ fn store_row [width : Int](j : Int) unified { mut } :
422422 (Cr_ptr + (i * Cr.stride[0 ]() + j)).store(
423423 (cr_ptr + (i * nr + j)).load[width=width](),
424424 )
425425
426- vectorize[store_row, simd_width](Cr.shape[1 ]())
426+ vectorize[simd_width](Cr.shape[1 ](), store_row )
427427 else :
428428
429429 @parameter
@@ -467,7 +467,7 @@ fn matmul_params[Type: DType]() -> IndexList[5]:
467467 else :
468468
469469 @parameter
470- if Type is DType.int64:
470+ if Type == DType.int64:
471471
472472 @parameter
473473 if CompilationTarget.has_avx512f():
0 commit comments