from .tiny_llm_base import Qwen2ModelWeek1, Embedding, dequantize_linear, qwen2_week1
from mlx_lm import load

-# TODO: task 1 tests
-
+@pytest.mark.parametrize("stream", AVAILABLE_STREAMS, ids=AVAILABLE_STREAMS_IDS)
+@pytest.mark.parametrize("precision", PRECISIONS, ids=PRECISION_IDS)
+@pytest.mark.parametrize("mask", [None, "causal"], ids=["no_mask", "causal_mask"])
+def test_task_1_transformer_block(
+    stream: mx.Stream, precision: mx.Dtype, mask: str | None
+):
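+    """Compare the user's Qwen2TransformerBlock against mlx_lm's reference TransformerBlock."""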
+    with mx.stream(stream):
+        from mlx_lm.models import qwen2
+
+        BATCH_SIZE = 1
+        SEQ_LEN = 10
+        NUM_ATTENTION_HEAD = 4
+        NUM_KV_HEADS = 2
+        HIDDEN_SIZE = 32
+        INTERMEDIATE_SIZE = HIDDEN_SIZE * 4
+
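+        # Build a minimal single-layer Qwen2 configuration for the mlx_lm reference block.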
+        args = qwen2.ModelArgs(
+            model_type="qwen2",
+            hidden_size=HIDDEN_SIZE,
+            num_hidden_layers=1,
+            intermediate_size=INTERMEDIATE_SIZE,
+            num_attention_heads=NUM_ATTENTION_HEAD,
+            num_key_value_heads=NUM_KV_HEADS,
+            rms_norm_eps=1e-6,
+            vocab_size=1000,
+        )
+
+        mlx_transformer_block = qwen2.TransformerBlock(args)
+
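+        # Pull the attention projection weights and biases out of the reference block.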
+        mlx_attention = mlx_transformer_block.self_attn
+        wq = mlx_attention.q_proj.weight
+        wk = mlx_attention.k_proj.weight
+        wv = mlx_attention.v_proj.weight
+        wo = mlx_attention.o_proj.weight
+        bq = mlx_attention.q_proj.bias
+        bk = mlx_attention.k_proj.bias
+        bv = mlx_attention.v_proj.bias
+
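+        # Pull the gate/up/down projection weights of the MLP.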
+        mlx_mlp = mlx_transformer_block.mlp
+        w_gate = mlx_mlp.gate_proj.weight
+        w_up = mlx_mlp.up_proj.weight
+        w_down = mlx_mlp.down_proj.weight
+
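+        # RMSNorm weights for the input (pre-attention) and post-attention normalization.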
+        w_input_layernorm = mlx_transformer_block.input_layernorm.weight
+        w_post_attention_layernorm = mlx_transformer_block.post_attention_layernorm.weight
+
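+        # Build the user implementation with the same weights so its output can be compared directly.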
+        user_transformer_block = qwen2_week1.Qwen2TransformerBlock(
+            num_attention_heads=NUM_ATTENTION_HEAD,
+            num_kv_heads=NUM_KV_HEADS,
+            hidden_size=HIDDEN_SIZE,
+            intermediate_size=INTERMEDIATE_SIZE,
+            rms_norm_eps=1e-6,
+            wq=wq,
+            wk=wk,
+            wv=wv,
+            wo=wo,
+            bq=bq,
+            bk=bk,
+            bv=bv,
+            w_gate=w_gate,
+            w_up=w_up,
+            w_down=w_down,
+            w_input_layernorm=w_input_layernorm,
+            w_post_attention_layernorm=w_post_attention_layernorm,
+        )
+
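+        # Feed both blocks the same random input and compare outputs within a loose tolerance.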
+        mx.random.seed(42)
+        x = mx.random.uniform(
+            shape=(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE), dtype=precision
+        )
+
+        user_output = user_transformer_block(x, mask=mask)
+        mlx_output = mlx_transformer_block(x, mask=mask, cache=None)
+
+        assert_allclose(user_output, mlx_output, precision=precision, rtol=1e-1)

@pytest.mark.skipif(
    not qwen_2_05b_model_exists(), reason="Qwen2-0.5B-Instruct-MLX model not found"