@@ -23,10 +23,12 @@ def test_grouped_gemm_2d_3d(use_fast_accum, strided):
23
23
device = "cuda"
24
24
s_int = int (strided )
25
25
m , n , k , n_groups = 16 , 32 , 16 , 4
26
- a = torch .randn (m * n_groups , k * (1 + s_int ), device = device , requires_grad = True )[:, :k ]
27
- b = torch .randn (n_groups * (1 + s_int ), n , k * (1 + s_int ), device = device , requires_grad = True )[
28
- :: (1 + s_int ), :, :k
26
+ a = torch .randn (m * n_groups , k * (1 + s_int ), device = device , requires_grad = True )[
27
+ :, :k
29
28
]
29
+ b = torch .randn (
30
+ n_groups * (1 + s_int ), n , k * (1 + s_int ), device = device , requires_grad = True
31
+ )[:: (1 + s_int ), :, :k ]
30
32
offs = torch .arange (m , n_groups * m + 1 , m , device = "cuda" , dtype = torch .int32 )
31
33
result = _grouped_scaled_mm (
32
34
a ,
@@ -62,12 +64,12 @@ def test_grouped_gemm_3d_3d(use_fast_accum, strided):
62
64
device = "cuda"
63
65
s_int = int (strided )
64
66
m , n , k , n_groups = 16 , 32 , 16 , 4
65
- a = torch .randn (n_groups * ( 1 + s_int ), m , k * ( 1 + s_int ), device = device , requires_grad = True )[
66
- :: (1 + s_int ), :, : k
67
- ]
68
- b = torch .randn (n_groups * ( 1 + s_int ), n , k * ( 1 + s_int ), device = device , requires_grad = True )[
69
- :: (1 + s_int ), :, : k
70
- ]
67
+ a = torch .randn (
68
+ n_groups * (1 + s_int ), m , k * ( 1 + s_int ), device = device , requires_grad = True
69
+ )[:: ( 1 + s_int ), :, : k ]
70
+ b = torch .randn (
71
+ n_groups * (1 + s_int ), n , k * ( 1 + s_int ), device = device , requires_grad = True
72
+ )[:: ( 1 + s_int ), :, : k ]
71
73
result = _grouped_scaled_mm (
72
74
a ,
73
75
b .transpose (- 2 , - 1 ),
@@ -99,12 +101,12 @@ def test_grouped_gemm_2d_2d(use_fast_accum, strided):
99
101
out_dtype = torch .bfloat16
100
102
device = "cuda"
101
103
m , n , k , n_groups = 16 , 16 , 16 , 4 # all sizes have to be divisible by 16
102
- a = torch .randn (m , k * n_groups + k * int ( strided ), device = device , requires_grad = True )[
103
- :, : k * n_groups
104
- ]
105
- b = torch .randn (n , k * n_groups + k * int ( strided ), device = device , requires_grad = True )[
106
- :, : k * n_groups
107
- ]
104
+ a = torch .randn (
105
+ m , k * n_groups + k * int ( strided ), device = device , requires_grad = True
106
+ )[:, : k * n_groups ]
107
+ b = torch .randn (
108
+ n , k * n_groups + k * int ( strided ), device = device , requires_grad = True
109
+ )[:, : k * n_groups ]
108
110
offs = torch .arange (k , n_groups * k + 1 , k , device = device , dtype = torch .int32 )
109
111
110
112
# Compute result.
0 commit comments