Skip to content

Commit 04149a3

Browse files
authored
add test for week 1 day 5 test 1: Qwen2TransformerBlock (#59)
1 parent 1c9369a commit 04149a3

File tree

1 file changed

+75
-2
lines changed

1 file changed

+75
-2
lines changed

tests_refsol/test_week_1_day_5.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,81 @@
33
from .tiny_llm_base import Qwen2ModelWeek1, Embedding, dequantize_linear, qwen2_week1
44
from mlx_lm import load
55

6-
# TODO: task 1 tests
7-
6+
@pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS)
@pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS)
@pytest.mark.parametrize("mask", [None, "causal"], ids=["no_mask", "causal_mask"])
def test_task_1_transformer_block(
    stream: mx.Stream, precision: mx.Dtype, mask: str | None
):
    """Compare the user Qwen2TransformerBlock against the mlx_lm reference.

    Builds a tiny single-layer reference block, copies all of its weights
    into the user implementation, runs both on the same random input, and
    asserts the outputs agree within a loose tolerance.
    """
    with mx.stream(stream):
        from mlx_lm.models import qwen2

        # Deliberately tiny configuration so the test is fast.
        BATCH_SIZE = 1
        SEQ_LEN = 10
        NUM_ATTENTION_HEAD = 4
        NUM_KV_HEADS = 2
        HIDDEN_SIZE = 32
        INTERMEDIATE_SIZE = HIDDEN_SIZE * 4

        args = qwen2.ModelArgs(
            model_type="qwen2",
            hidden_size=HIDDEN_SIZE,
            num_hidden_layers=1,
            intermediate_size=INTERMEDIATE_SIZE,
            num_attention_heads=NUM_ATTENTION_HEAD,
            num_key_value_heads=NUM_KV_HEADS,
            rms_norm_eps=1e-6,
            vocab_size=1000,
        )

        reference_block = qwen2.TransformerBlock(args)

        # Hand every reference parameter to the user block so both start
        # from identical weights.
        attn = reference_block.self_attn
        mlp = reference_block.mlp
        user_block = qwen2_week1.Qwen2TransformerBlock(
            num_attention_heads=NUM_ATTENTION_HEAD,
            num_kv_heads=NUM_KV_HEADS,
            hidden_size=HIDDEN_SIZE,
            intermediate_size=INTERMEDIATE_SIZE,
            rms_norm_eps=1e-6,
            wq=attn.q_proj.weight,
            wk=attn.k_proj.weight,
            wv=attn.v_proj.weight,
            wo=attn.o_proj.weight,
            bq=attn.q_proj.bias,
            bk=attn.k_proj.bias,
            bv=attn.v_proj.bias,
            w_gate=mlp.gate_proj.weight,
            w_up=mlp.up_proj.weight,
            w_down=mlp.down_proj.weight,
            w_input_layernorm=reference_block.input_layernorm.weight,
            w_post_attention_layernorm=reference_block.post_attention_layernorm.weight,
        )

        # Fixed seed keeps the random input (and hence the comparison)
        # deterministic across runs.
        mx.random.seed(42)
        x = mx.random.uniform(
            shape=(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE), dtype=precision
        )

        user_output = user_block(x, mask=mask)
        reference_output = reference_block(x, mask=mask, cache=None)

        assert_allclose(user_output, reference_output, precision=precision, rtol=1e-1)
881

982
@pytest.mark.skipif(
1083
not qwen_2_05b_model_exists(), reason="Qwen2-0.5B-Instruct-MLX model not found"

0 commit comments

Comments
 (0)