* [PyTorch MultiHeadAttention API](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) (assume dim_k=dim_v=dim_q and H_k=H_v=H_q)
* [MLX MultiHeadAttention API](https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.MultiHeadAttention.html) (assume dim_k=dim_v=dim_q and H_k=H_v=H_q)

Implement `MultiHeadAttention`. The layer takes a batch of vectors `x`, maps it through the K, V, Q weight matrices, and
uses the attention function we implemented in day 1 to compute the result. The output needs to be mapped using the O
weight matrix.

You will also need to implement the `linear` function first. `linear` takes a tensor of the shape `N.. x I`, a weight
matrix of the shape `O x I`, and a bias vector of the shape `O`. The output is of the shape `N.. x O`, where `I` is the
input dimension and `O` is the output dimension.
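As a reference for the shapes involved, here is a minimal sketch of such a `linear` function in MLX. This is a sketch under stated assumptions, not the reference solution: the function name and argument order follow the description above, and treating the bias as optional is an assumption.

```python
import mlx.core as mx


def linear(x: mx.array, w: mx.array, bias=None) -> mx.array:
    """Sketch: x is N.. x I, w is O x I, bias is O; the result is N.. x O."""
    # Contract the last dimension I of x against the I dimension of w by
    # transposing w to I x O before the matrix multiplication.
    y = mx.matmul(x, mx.swapaxes(w, -2, -1))
    if bias is not None:
        # The O-sized bias vector broadcasts over all leading N.. dimensions.
        y = y + bias
    return y
```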
For the `MultiHeadAttention` layer, the input tensor `x` has the shape `N x L x E`, where `E` is the dimension of the
embedding for a given token in the sequence. The `K/Q/V` weight matrices will map the tensor into key, value, and query
separately, where the dimension `E` will be mapped into a dimension of size `H x D`, which means that the token embedding
gets mapped into `H` heads, each with a dimension of `D`. You can directly reshape the tensor to split the `H x D` dimension
into two dimensions of `H` and `D` to get `H` heads for the token. Then, apply the attention function to each of the heads
(this requires a transpose, using `swapaxes` in MLX). The attention function takes `N.. x H x L x D` as input so that it
produces an output for each head of each token. Then, you can transpose it into `N.. x L x H x D` and reshape it so that
all heads get merged back together with a shape of `N.. x L x (H x D)`. Map it through the output weight matrix to get
the final output.
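Putting the steps above together, here is one possible sketch of the layer in MLX. It is a sketch under stated assumptions, not the reference solution: `scaled_dot_product_attention` stands in for whatever you named your day 1 attention function, the weights are passed to the constructor already materialized, and they are assumed to be stored in the `O x I` layout that `linear` expects (so `wq`/`wk`/`wv` are `(H x D) x E` and `wo` is `E x (H x D)`).

```python
import mlx.core as mx


class MultiHeadAttention:
    """Sketch only; see the assumptions in the paragraph above."""

    def __init__(self, hidden_size: int, num_heads: int,
                 wq: mx.array, wk: mx.array, wv: mx.array, wo: mx.array):
        assert hidden_size % num_heads == 0
        self.num_heads = num_heads                 # H
        self.head_dim = hidden_size // num_heads   # D
        self.wq, self.wk, self.wv, self.wo = wq, wk, wv, wo

    def __call__(self, x: mx.array) -> mx.array:
        N, L, _ = x.shape
        H, D = self.num_heads, self.head_dim
        # Project x (N x L x E) to N x L x (H x D), then split the last
        # dimension into H heads of size D: N x L x H x D.
        q = linear(x, self.wq).reshape(N, L, H, D)
        k = linear(x, self.wk).reshape(N, L, H, D)
        v = linear(x, self.wv).reshape(N, L, H, D)
        # Move the head axis before the sequence axis: N x H x L x D.
        q = mx.swapaxes(q, 1, 2)
        k = mx.swapaxes(k, 1, 2)
        v = mx.swapaxes(v, 1, 2)
        # The day 1 attention function runs independently on each head and
        # returns N x H x L x D. (Rename to match your own implementation.)
        o = scaled_dot_product_attention(q, k, v)
        # Undo the transpose and merge the heads back: N x L x (H x D).
        o = mx.swapaxes(o, 1, 2).reshape(N, L, H * D)
        # Final projection through the O weight matrix.
        return linear(o, self.wo)
```

This sketch only handles a single leading batch dimension `N`; supporting the more general `N..` batch shape used in the notation below would require flattening and restoring the extra leading dimensions around the same sequence of reshape and `swapaxes` calls.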
```
E is hidden_size or embed_dim or dims or model_dim
H is num_heads
D is head_dim
L is seq_len, in PyTorch API it's S (source len)

W_q/k/v: E x (H x D)
x: N x L x E
W_o: (H x D) x O
```
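As a quick sanity check of the notation, here is a hedged usage example built on the `linear` and `MultiHeadAttention` sketches above; the concrete sizes (N=2, L=6, E=64, H=4, hence D=16) are made up for illustration, and the random weights are stored in the `O x I` layout that `linear` expects.

```python
import mlx.core as mx

N, L, E, H = 2, 6, 64, 4
D = E // H  # head_dim = 16

x = mx.random.normal((N, L, E))
wq = mx.random.normal((H * D, E))  # maps E -> (H x D)
wk = mx.random.normal((H * D, E))
wv = mx.random.normal((H * D, E))
wo = mx.random.normal((E, H * D))  # maps (H x D) -> E
layer = MultiHeadAttention(E, H, wq, wk, wv, wo)

print(layer(x).shape)  # (2, 6, 64), i.e. N x L x E
```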