 
 @pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS)
 @pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS)
-def test_attention_simple(stream: mx.Stream, precision: np.dtype):
+def test_attention_week_1_day_1_task_1(stream: mx.Stream, precision: np.dtype):
     with mx.stream(stream):
         BATCH_SIZE = 3
         DIM_N = 4
@@ -35,18 +35,18 @@ def test_attention_simple(stream: mx.Stream, precision: np.dtype):
 @pytest.mark.parametrize(
     "qkv_shape", [True, False], ids=["with_seq_len", "without_seq_len"]
 )
-def test_attention_with_mask(stream: mx.Stream, precision: np.dtype, qkv_shape: bool):
+def test_attention_with_mask_week_1_day_1_task_1(stream: mx.Stream, precision: np.dtype, qkv_shape: bool):
     with mx.stream(stream):
         BATCH_SIZE = 3
         SEQ_LEN = 10
-        DIM_N = 4
-        DIM_M = 5
+        H = 4
+        D = 5
         if qkv_shape:
-            qkv_shape = (BATCH_SIZE, SEQ_LEN, DIM_N, DIM_M)
-            mask_shape = (BATCH_SIZE, SEQ_LEN, DIM_N, DIM_N)
+            qkv_shape = (BATCH_SIZE, H, SEQ_LEN, D)
+            mask_shape = (BATCH_SIZE, H, SEQ_LEN, SEQ_LEN)
         else:
-            qkv_shape = (BATCH_SIZE, DIM_N, DIM_M)
-            mask_shape = (BATCH_SIZE, DIM_N, DIM_N)
+            qkv_shape = (BATCH_SIZE, H, SEQ_LEN, D)
+            mask_shape = (BATCH_SIZE, H, SEQ_LEN, SEQ_LEN)
         for _ in range(100):
             query = np.random.rand(*qkv_shape).astype(precision)
             key = np.random.rand(*qkv_shape).astype(precision)
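
A minimal NumPy sketch of the scaled dot-product attention that the (BATCH_SIZE, H, SEQ_LEN, D) inputs and additive (BATCH_SIZE, H, SEQ_LEN, SEQ_LEN) masks above imply; the name and exact signature of the function under test are not visible in this diff, so treat this as an assumption rather than the repository's implementation:

```python
import numpy as np

def scaled_dot_product_attention_ref(query, key, value, mask=None):
    # query/key/value: (..., SEQ_LEN, D); mask: additive, broadcastable to (..., SEQ_LEN, SEQ_LEN)
    d = query.shape[-1]
    scores = query @ key.swapaxes(-1, -2) / np.sqrt(d)
    if mask is not None:
        scores = scores + mask
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ value  # (..., SEQ_LEN, D)
```
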
@@ -72,33 +72,31 @@ def test_attention_with_mask(stream: mx.Stream, precision: np.dtype, qkv_shape:
 
 
 @pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS)
 @pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS)
-def test_multi_head_attention(stream: mx.Stream, precision: np.dtype):
+def test_multi_head_attention_week_1_day_1_task_2(stream: mx.Stream, precision: np.dtype):
     with mx.stream(stream):
-        BATCH_SIZE = 7
-        DIM_N = 11
-        DIM_M = 9
-        NUM_HEADS = 3
+        SEQ_LEN = 11
+        D = 9
+        H = 3
+        BATCH_SIZE = 10
         for _ in range(100):
-            query = np.random.rand(BATCH_SIZE, DIM_N, DIM_M).astype(precision)
-            key = np.random.rand(BATCH_SIZE, DIM_N, DIM_M).astype(precision)
-            value = np.random.rand(BATCH_SIZE, DIM_N, DIM_M).astype(precision)
-            q_proj_weight = np.random.rand(DIM_M, DIM_M).astype(precision)
-            k_proj_weight = np.random.rand(DIM_M, DIM_M).astype(precision)
-            v_proj_weight = np.random.rand(DIM_M, DIM_M).astype(precision)
-            out_proj_weight = np.random.rand(DIM_M, DIM_M).astype(precision)
-            mask = np.random.rand(DIM_N * NUM_HEADS, BATCH_SIZE, BATCH_SIZE).astype(
-                precision
-            )
+            query = np.random.rand(BATCH_SIZE, SEQ_LEN, H * D).astype(precision)
+            key = np.random.rand(BATCH_SIZE, SEQ_LEN, H * D).astype(precision)
+            value = np.random.rand(BATCH_SIZE, SEQ_LEN, H * D).astype(precision)
+            q_proj_weight = np.random.rand(H * D, H * D).astype(precision)
+            k_proj_weight = np.random.rand(H * D, H * D).astype(precision)
+            v_proj_weight = np.random.rand(H * D, H * D).astype(precision)
+            out_proj_weight = np.random.rand(H * D, H * D).astype(precision)
+            mask = np.random.rand(SEQ_LEN, SEQ_LEN).astype(precision)
 
             reference_output, _ = torch.nn.functional.multi_head_attention_forward(
-                torch.tensor(query, device=TORCH_DEVICE),
-                torch.tensor(key, device=TORCH_DEVICE),
-                torch.tensor(value, device=TORCH_DEVICE),
-                num_heads=NUM_HEADS,
+                torch.tensor(query, device=TORCH_DEVICE).transpose(0, 1),
+                torch.tensor(key, device=TORCH_DEVICE).transpose(0, 1),
+                torch.tensor(value, device=TORCH_DEVICE).transpose(0, 1),
+                num_heads=H,
                 q_proj_weight=torch.tensor(q_proj_weight, device=TORCH_DEVICE),
                 k_proj_weight=torch.tensor(k_proj_weight, device=TORCH_DEVICE),
                 v_proj_weight=torch.tensor(v_proj_weight, device=TORCH_DEVICE),
                 out_proj_weight=torch.tensor(out_proj_weight, device=TORCH_DEVICE),
-                embed_dim_to_check=DIM_M,
+                embed_dim_to_check=H * D,
                 in_proj_weight=None,
                 in_proj_bias=None,
                 bias_k=None,
@@ -109,9 +107,10 @@ def test_multi_head_attention(stream: mx.Stream, precision: np.dtype):
                 use_separate_proj_weight=True,
                 attn_mask=torch.tensor(mask, device=TORCH_DEVICE),
             )
+            reference_output = reference_output.transpose(0, 1)
             user_output = MultiHeadAttention(
-                DIM_M,
-                NUM_HEADS,
+                H * D,
+                H,
                 mx.array(q_proj_weight),
                 mx.array(k_proj_weight),
                 mx.array(v_proj_weight),
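
The transpose(0, 1) calls appear because torch.nn.functional.multi_head_attention_forward consumes (SEQ_LEN, BATCH, EMBED) tensors, so the (BATCH_SIZE, SEQ_LEN, H * D) inputs are transposed on the way in and the reference output is transposed back before comparison. For orientation, a minimal NumPy sketch of the multi-head computation that separate q/k/v/out projection weights imply; the head-splitting layout and the function name here are assumptions, not code from the repository:

```python
import numpy as np

def multi_head_attention_ref(x, wq, wk, wv, wo, num_heads, mask=None):
    # Hypothetical reference: x is (BATCH, SEQ_LEN, E) with E = num_heads * head_dim,
    # and wq/wk/wv/wo are (E, E) weights applied as in torch.nn.functional.linear (x @ w.T).
    batch, seq_len, embed = x.shape
    head_dim = embed // num_heads

    def split_heads(t):
        # (BATCH, SEQ_LEN, E) -> (BATCH, H, SEQ_LEN, D)
        return t.reshape(batch, seq_len, num_heads, head_dim).transpose(0, 2, 1, 3)

    q, k, v = split_heads(x @ wq.T), split_heads(x @ wk.T), split_heads(x @ wv.T)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim)  # (BATCH, H, SEQ_LEN, SEQ_LEN)
    if mask is not None:
        scores = scores + mask  # additive (SEQ_LEN, SEQ_LEN) mask, broadcast over batch and heads
    scores = scores - scores.max(axis=-1, keepdims=True)      # numerically stable softmax
    attn = np.exp(scores)
    attn = attn / attn.sum(axis=-1, keepdims=True)
    out = (attn @ v).transpose(0, 2, 1, 3).reshape(batch, seq_len, embed)
    return out @ wo.T  # (BATCH, SEQ_LEN, E)
```
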