Commit fc1c497

committed
Small attention layer exercise
1 parent cb47d2c commit fc1c497

File tree

1 file changed (+211, -0 lines)

Lines changed: 211 additions & 0 deletions
@@ -0,0 +1,211 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img style=\"float: right;\" src=\"../../assets/htwlogo.svg\">\n",
"\n",
"# Exercise: Studying Attention Layers\n",
"\n",
"**Author**: _Erik Rodner_ <br>\n",
"\n",
"In this exercise, we will analyze scaled dot-product attention.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from transformers import BertTokenizer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tokenization\n",
"\n",
"Let's first tokenize some text so that we have a sequence of tokens to work with later on."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
"\n",
"# Tokenization and input preparation\n",
"sentence = \"Transformers are powerful models for natural language processing.\"\n",
"tokens = tokenizer.tokenize(sentence)\n",
"input_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
"input_tensor = torch.tensor([input_ids])\n",
"\n",
"print(f\"Sentence: '{sentence}'\")\n",
"print(f\"Tokens: {tokens}\")\n",
"print(f\"Input IDs: {input_ids}\")"
]
},
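{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Added side note (not part of the original exercise)*: in practice you would usually call the tokenizer directly, which also inserts BERT's special tokens `[CLS]` and `[SEP]`. The sketch below assumes the `tokenizer` and `sentence` defined above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: one-call tokenization via the tokenizer's __call__ interface.\n",
"# Note that this adds the special tokens [CLS] and [SEP], so the IDs differ\n",
"# slightly from the manual tokenize/convert_tokens_to_ids route above.\n",
"encoded = tokenizer(sentence, return_tensors=\"pt\")\n",
"print(f\"Input IDs with special tokens: {encoded['input_ids']}\")"
]
},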
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generate synthetic embedding data\n",
"\n",
"For simplicity, we'll use random values with a rather low embedding dimension here.\n",
"In a real setting, the embeddings might also start out random, but they would then be tuned during training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"embedding_dim = 8\n",
"# note that this construction also ignores the fact that identical tokens should initially get the same embedding\n",
"data = torch.rand((len(input_ids), embedding_dim))\n",
"print(f\"\\nGenerated Embedding Shape: {data.shape}\")"
]
},
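{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Added side note (not part of the original exercise)*: with a learnable `torch.nn.Embedding` lookup table, identical token IDs automatically share the same embedding vector. The sketch below assumes `input_tensor`, `embedding_dim`, and `tokenizer` from the cells above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: a learnable embedding lookup table instead of free random vectors.\n",
"# Identical token IDs map to the same row of the table, unlike the random `data` above.\n",
"embedding_layer = torch.nn.Embedding(tokenizer.vocab_size, embedding_dim)\n",
"embedded = embedding_layer(input_tensor)  # shape: (1, sequence_length, embedding_dim)\n",
"print(f\"Embedding layer output shape: {embedded.shape}\")"
]
},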
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Transformer Layer in Action: Scaled Dot-Product Attention\n",
"\n",
"Let's first generate random weight matrices for the queries, keys, and values.\n",
"Our $Q$, $K$, and $V$ matrices are then computed by multiplying the embedding matrix with these weight matrices."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dk = 4  # dimension of the query and key vectors\n",
"dv = 4  # dimension of the value vectors\n",
"query_weights = torch.rand((embedding_dim, dk))\n",
"key_weights = torch.rand((embedding_dim, dk))\n",
"value_weights = torch.rand((embedding_dim, dv))\n",
"\n",
"Q = torch.matmul(data, query_weights)\n",
"K = torch.matmul(data, key_weights)\n",
"V = torch.matmul(data, value_weights)\n",
"\n",
"print(f\"Query (Q) Shape: {Q.shape}\\n\", Q)\n",
"print(f\"Key (K) Shape: {K.shape}\\n\", K)\n",
"print(f\"Value (V) Shape: {V.shape}\\n\", V)"
]
},
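{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Added note*: for reference, scaled dot-product attention (Vaswani et al., 2017) is defined as\n",
"\n",
"$$\\mathrm{Attention}(Q, K, V) = \\mathrm{softmax}\\left(\\frac{Q K^\\top}{\\sqrt{d_k}}\\right) V,$$\n",
"\n",
"where $d_k$ is the dimension of the query and key vectors. The two exercises below implement exactly these two steps: the softmax of the scaled dot products, and the weighting of the values."
]
},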
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Scaled dot-product attention\n",
"\n",
"Let's apply scaled dot-product attention step by step.\n",
"\n",
"**Exercise 1**: Complete the following function to compute the attention scores."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def compute_attention_scores(Q, K):\n",
"    dk = Q.size(-1)\n",
"    scores = 0  # YOUR CODE HERE: compute the scaled dot product between Q and K properly :)\n",
"    attn_probs = F.softmax(scores, dim=-1)\n",
"    return attn_probs\n",
"\n",
"attention_scores = compute_attention_scores(Q, K)\n",
"print(f\"Attention Scores Shape: {attention_scores.shape}\\n\", attention_scores)"
]
},
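{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Added note (not part of the original exercise)*: one possible solution sketch for Exercise 1, for checking your own attempt. The name `compute_attention_scores_sketch` is made up for illustration and is not the official reference solution."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# One possible solution sketch for Exercise 1 (not necessarily the intended reference solution).\n",
"def compute_attention_scores_sketch(Q, K):\n",
"    dk = Q.size(-1)\n",
"    # scaled dot product: similarity of every query with every key, scaled by sqrt(dk)\n",
"    scores = torch.matmul(Q, K.transpose(-2, -1)) / (dk ** 0.5)\n",
"    # softmax over the key dimension turns the scores into attention weights per query\n",
"    return F.softmax(scores, dim=-1)\n",
"\n",
"print(compute_attention_scores_sketch(Q, K).shape)  # expected: (len(input_ids), len(input_ids))"
]
},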
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Exercise 2**: Now complete the following function to compute the final embeddings, i.e. the attention-weighted values."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def compute_weighted_values(attention_scores, V):\n",
"    return 0  # YOUR CODE HERE: compute the weighted values properly :)\n",
"\n",
"weighted_values = compute_weighted_values(attention_scores, V)\n",
"print(f\"Weighted Values Shape: {weighted_values.shape}\\n\", weighted_values)"
]
},
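{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Added note (not part of the original exercise)*: one possible solution sketch for Exercise 2. The name `compute_weighted_values_sketch` is made up for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# One possible solution sketch for Exercise 2 (not necessarily the intended reference solution).\n",
"def compute_weighted_values_sketch(attention_scores, V):\n",
"    # each output row is an attention-weighted combination of the value vectors\n",
"    return torch.matmul(attention_scores, V)\n",
"\n",
"# quick smoke test with uniform attention weights\n",
"uniform_scores = torch.full((V.size(0), V.size(0)), 1.0 / V.size(0))\n",
"print(compute_weighted_values_sketch(uniform_scores, V).shape)  # expected: (len(input_ids), dv)"
]
},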
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualization of the attention scores\n",
"\n",
"Let's visualize the attention scores in the following. Of course, since the embeddings and weights are random, the scores themselves are arbitrary, but you get an idea of their structure."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualization of Attention Weights\n",
"fig, ax = plt.subplots(figsize=(10, 6))\n",
"cax = ax.matshow(attention_scores.detach().numpy(), cmap='viridis')\n",
"plt.title(\"Attention Scores Heatmap\")\n",
"plt.xticks(range(len(tokens)), tokens, rotation=90)\n",
"plt.yticks(range(len(tokens)), tokens)\n",
"fig.colorbar(cax)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "ml-exercise-pip",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.20"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
