
Commit c128fe9

assume dim the same
Signed-off-by: Alex Chi <[email protected]>
1 parent 0607e8b commit c128fe9

1 file changed, +5 −5 lines changed


book/src/week1-01-attention.md

@@ -37,8 +37,8 @@ You may use `softmax` provided by mlx and implement it later in week 2.
 **📚 Readings**
 
 * [Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer/)
-* [PyTorch API](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (assume `enable_gqa=False`)
-* [MLX API](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.fast.scaled_dot_product_attention.html)
+* [PyTorch API](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (assume `enable_gqa=False`, assume dim_k=dim_v=dim_q and H_k=H_v=H_q)
+* [MLX API](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.fast.scaled_dot_product_attention.html) (assume dim_k=dim_v=dim_q and H_k=H_v=H_q)
 * [Attention is All You Need](https://arxiv.org/abs/1706.03762)
 
 ## Task 2: Implement `MultiHeadAttention`
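For reference, the constraint added in this hunk (dim_k = dim_v = dim_q and H_k = H_v = H_q) means Q, K, and V share the same head count and head dimension, so one scaled dot product covers the whole task. A minimal sketch under that assumption — the function name and signature are illustrative, not the book's reference solution:

```python
import math
import mlx.core as mx

def scaled_dot_product_attention(q, k, v, mask=None):
    # Assumes dim_k = dim_v = dim_q = D and H_k = H_v = H_q,
    # so q, k, v all share the trailing shape (..., L, D).
    scale = 1.0 / math.sqrt(q.shape[-1])
    scores = mx.matmul(q, mx.swapaxes(k, -1, -2)) * scale   # (..., L, L)
    if mask is not None:
        scores = scores + mask   # additive mask, e.g. -inf on disallowed positions
    return mx.matmul(mx.softmax(scores, axis=-1), v)         # (..., L, D)
```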
@@ -77,15 +77,15 @@ transpose it to get the right shape.
 **📚 Readings**
 
 * [Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer/)
-* [PyTorch API](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html)
-* [MLX API](https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.MultiHeadAttention.html)
+* [PyTorch API](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) (assume dim_k=dim_v=dim_q and H_k=H_v=H_q)
+* [MLX API](https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.MultiHeadAttention.html) (assume dim_k=dim_v=dim_q and H_k=H_v=H_q)
 
 At the end of the day, you should be able to pass the following tests:
 
 ```
 poetry run pytest tests -k test_attention_simple
 poetry run pytest tests -k test_attention_with_mask
-poetry run pytest tests-k test_multi_head_attention
+poetry run pytest tests -k test_multi_head_attention
 ```
 
 {{#include copyright.md}}
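The second hunk's context line ("transpose it to get the right shape") refers to splitting the projected Q/K/V into heads before calling the attention function above. A rough sketch of that reshape/transpose step under the same equal-dim assumption — weight names and layout here are hypothetical, not taken from the book:

```python
import mlx.core as mx

def multi_head_attention(q, k, v, wq, wk, wv, wo, num_heads, mask=None):
    # Assumes dim_k = dim_v = dim_q = E and H_k = H_v = H_q = num_heads.
    # q, k, v: (N, L, E); wq/wk/wv/wo: (E, E) projections applied as x @ w.
    N, L, E = q.shape
    head_dim = E // num_heads

    def split_heads(x, w):
        # Project, then reshape (N, L, E) -> (N, L, H, E/H) -> (N, H, L, E/H).
        return mx.matmul(x, w).reshape(N, L, num_heads, head_dim).transpose(0, 2, 1, 3)

    q, k, v = split_heads(q, wq), split_heads(k, wk), split_heads(v, wv)
    out = scaled_dot_product_attention(q, k, v, mask=mask)   # (N, H, L, E/H), sketch above
    out = out.transpose(0, 2, 1, 3).reshape(N, L, E)          # merge heads back to (N, L, E)
    return mx.matmul(out, wo)
```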
