-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
354 lines (310 loc) · 13.6 KB
/
train.py
File metadata and controls
354 lines (310 loc) · 13.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
import torch
import torch.nn as nn
from torch.nn import functional as F
import dataclasses
from typing import Optional
import re
from typing import Any, List, Sequence, Tuple, Union
import os
import requests
import wandb
import time
from sentencepiece import SentencePieceProcessor
# Start the Weights & Biases run that the training loop below logs metrics to.
wandb.init(name="urdu_training_run", project="tiny_gemma")
def load_dataset():
    """Read ``input.txt`` and derive its character-level vocabulary.

    Prints a 200-character preview of the corpus, the sorted vocabulary and
    the vocabulary size.

    Returns:
        A ``(text, chars, v)`` tuple: the complete file contents, the sorted
        list of distinct characters in it, and the number of those characters.
    """
    with open('input.txt', 'r', encoding='utf-8') as corpus_file:
        text = corpus_file.read()

    preview = text[:200]
    print("First 200 characters:")
    print(preview)

    # Sorting makes the character -> id assignment deterministic across runs.
    chars = sorted(set(text))
    v = len(chars)
    print(f'\nVocabulary: {chars}')
    print(f'Vocabulary size: {v}')
    return text, chars, v
class CharacterTokenizer:
    """Character-level tokenizer: a character's id is its position in ``chars``."""

    def __init__(self, chars: List[str]):
        """Build the character -> id (``stoi``) and id -> character (``itos``) tables."""
        self.stoi = {}
        self.itos = {}
        for index, ch in enumerate(chars):
            self.stoi[ch] = index
            self.itos[index] = ch

    def encode(self, s: str) -> List[int]:
        """Return the id of every character of ``s``; raises ``KeyError`` on an unknown character."""
        lookup = self.stoi
        return [lookup[ch] for ch in s]

    def decode(self, t: List[int]) -> str:
        """Return the string spelled by the ids in ``t``; raises ``KeyError`` on an unknown id."""
        lookup = self.itos
        return ''.join(lookup[token_id] for token_id in t)
class OriginalGemmaTokenizer:
    """Thin wrapper around a SentencePiece model loaded from a file."""

    def __init__(self, model_path: Optional[str]):
        """Load the SentencePiece model at ``model_path`` and cache its special ids.

        Args:
            model_path: Path to the SentencePiece model file; asserted to be an
                existing file.
        """
        assert os.path.isfile(model_path), model_path
        processor = SentencePieceProcessor(model_file=model_path)
        self.sp_model = processor
        self.n_words: int = processor.vocab_size()
        self.bos_id: int = processor.bos_id()
        self.eos_id: int = processor.eos_id()
        self.pad_id: int = processor.pad_id()
        assert processor.vocab_size() == processor.get_piece_size()

    def encode(self, s: str, bos: bool = True, eos: bool = False) -> List[int]:
        """Tokenize ``s``, optionally wrapping the ids with BOS / EOS.

        Args:
            s: Text to tokenize; asserted to be a ``str``.
            bos: When true, prepend ``self.bos_id``.
            eos: When true, append ``self.eos_id``.

        Returns:
            The list of token ids.
        """
        assert isinstance(s, str)
        token_ids = self.sp_model.encode(s)
        if bos:
            token_ids = [self.bos_id] + token_ids
        if eos:
            token_ids = token_ids + [self.eos_id]
        return token_ids

    def decode(self, t: List[int]) -> str:
        """Return the text the underlying SentencePiece model decodes from ``t``."""
        return self.sp_model.decode(t)
@dataclasses.dataclass
class GemmaConfig:
    """Hyperparameters shared by the tokenizer, model and training code.

    The defaults describe the small character-level model; the 2b / 7b presets
    override them by keyword.

    Every attribute carries a type annotation on purpose: ``dataclasses`` only
    turns *annotated* class attributes into fields, and only fields are
    accepted as constructor keywords (e.g. ``GemmaConfig(rope_theta=10000.0)``).
    """

    # Number of rows of the token embedding table (and of output logits).
    vocab_size: int = 65
    # Side length of the causal mask and length of each training sequence.
    max_position_embeddings: int = 256
    # Number of stacked decoder layers.
    num_hidden_layers: int = 4
    # Number of query heads per attention layer.
    num_attention_heads: int = 4
    # Number of key/value heads; must divide num_attention_heads.
    num_key_value_heads: int = 1
    # Width of the residual stream.
    hidden_size: int = 128
    # Width of the gated MLP's inner projection.
    intermediate_size: int = 512
    # Per-head dimension used by the attention projections and rotary embedding.
    head_dim: int = 32
    # Epsilon added inside the RMSNorm square root.
    rms_norm_eps: float = 1e-6
    # Path of the SentencePiece model file, when one is used.
    tokenizer: Optional[str] = None
    # Base of the rotary-embedding frequency schedule.
    rope_theta: float = 100.0
    # Device the attention mask and training batches are placed on.
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
def get_config_for_7b() -> GemmaConfig:
    """Return a ``GemmaConfig`` populated with the '7b' variant's hyperparameters."""
    hyperparameters = {
        'vocab_size': 256128,
        'max_position_embeddings': 8192,
        'num_hidden_layers': 28,
        'num_attention_heads': 16,
        'num_key_value_heads': 16,
        'hidden_size': 3072,
        'intermediate_size': 24576,
        'head_dim': 256,
        'tokenizer': 'tokenizer/tokenizer.model',
        'rope_theta': 10000.0,
    }
    return GemmaConfig(**hyperparameters)
def get_config_for_2b() -> GemmaConfig:
    """Return a ``GemmaConfig`` populated with the '2b' variant's hyperparameters."""
    hyperparameters = {
        'vocab_size': 256128,
        'max_position_embeddings': 8192,
        'num_hidden_layers': 18,
        'num_attention_heads': 8,
        'num_key_value_heads': 1,
        'hidden_size': 2048,
        'intermediate_size': 16384,
        'head_dim': 256,
        'tokenizer': 'tokenizer/tokenizer.model',
        'rope_theta': 10000.0,
    }
    return GemmaConfig(**hyperparameters)
def download_original_tokenizer():
    """Download the tokenizer model into ``./tokenizer`` unless it is already there.

    The outcome is reported on stdout. A non-200 response is reported and the
    function returns without writing anything.

    Raises:
        requests.exceptions.RequestException: If the request fails at the
            network level, including when the server does not respond within
            the timeout.
    """
    DESTINATION_FOLDER_PATH = './tokenizer'
    FILE_NAME = 'tokenizer.model'
    local_file_path = os.path.join(DESTINATION_FOLDER_PATH, FILE_NAME)
    if os.path.exists(local_file_path):
        print('File already exists')
        return

    url = f'https://raw.githubusercontent.com/google/gemma_pytorch/main/tokenizer/{FILE_NAME}'
    # requests applies no timeout by default, so an unresponsive server would
    # block the whole script indefinitely.
    response = requests.get(url, timeout=60)
    if response.status_code != 200:
        print(f'Failed to download the file. HTTP status code: {response.status_code}')
        return

    os.makedirs(DESTINATION_FOLDER_PATH, exist_ok=True)
    with open(local_file_path, 'wb') as file:
        file.write(response.content)
    print(f'File successfully downloaded to {local_file_path}')
def get_model_config(variant: Optional[str] = None, vocab_size: int = 65):
    """Build the model configuration and matching tokenizer for ``variant``.

    Args:
        variant: ``'7b'`` or ``'2b'`` selects the corresponding preset and the
            SentencePiece tokenizer (downloaded if missing). Any other value,
            including ``None``, selects the default small configuration with a
            character tokenizer built from the dataset.
        vocab_size: Minimum vocabulary size for the character-level
            configuration. Ignored for the ``'7b'`` / ``'2b'`` variants.

    Returns:
        ``(config, tokenizer)`` for ``'7b'`` / ``'2b'``; otherwise
        ``(config, tokenizer, text)`` where ``text`` is the loaded dataset.
        Note that the tuple length depends on the variant.
    """
    if variant in ('7b', '2b'):
        download_original_tokenizer()
        config = get_config_for_7b() if variant == '7b' else get_config_for_2b()
        tokenizer = OriginalGemmaTokenizer('tokenizer/tokenizer.model')
        return config, tokenizer

    config = GemmaConfig()
    text, chars, v = load_dataset()
    # The character tokenizer emits ids in [0, v). The embedding table must
    # have at least v rows, otherwise a corpus with more distinct characters
    # than `vocab_size` produces out-of-range embedding lookups.
    config.vocab_size = max(vocab_size, v)
    tokenizer = CharacterTokenizer(chars)
    return config, tokenizer, text
def apply_rotary_emb(x: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    """Rotate ``x`` by position-dependent angles (rotary position embedding).

    The last axis of ``x`` is split into two halves; element ``j`` of the first
    half and element ``j`` of the second half are treated as the real and
    imaginary part of one complex number, which is multiplied by
    ``exp(i * position * theta ** (-2j / dim))``.

    Args:
        x: 4-D tensor whose axis 1 is the sequence position and whose last
            axis has size ``dim``.
        dim: Size of the last axis of ``x``.
        theta: Base of the frequency schedule.

    Returns:
        A tensor with the same shape and dtype as ``x``; the rotated real
        parts occupy the first half of the last axis and the imaginary parts
        the second half.
    """
    target_device = x.device
    positions = torch.arange(x.size(1), device=target_device)
    exponents = torch.arange(0, dim, 2, device=target_device).float() / dim
    inv_freq = 1.0 / (theta ** exponents)

    # angles[p, j] is the rotation applied to pair j at position p.
    angles = torch.outer(positions, inv_freq).float()
    rotations = torch.polar(torch.ones_like(angles), angles)

    # Move the sequence axis to position 2 so it lines up with `rotations`,
    # then pair the two halves of the last axis as (real, imag).
    halves = torch.chunk(x.transpose(1, 2).float(), 2, dim=-1)
    as_complex = torch.view_as_complex(torch.stack(halves, dim=-1))

    rotated = torch.view_as_real(as_complex * rotations.unsqueeze(0)).type_as(x)

    # Lay the result out as [all real parts, all imaginary parts].
    merged = torch.cat(torch.chunk(rotated, 2, dim=-1), dim=-2)
    lead0, lead1, lead2 = merged.shape[0], merged.shape[1], merged.shape[2]
    return merged.reshape(lead0, lead1, lead2, -1).transpose(1, 2)
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization over the last axis with a learned scale."""

    def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = True):
        """Create the layer.

        Args:
            dim: Size of the learned ``weight`` vector.
            eps: Constant added to the mean square before the reciprocal
                square root.
            add_unit_offset: When true the output is scaled by ``1 + weight``
                instead of ``weight``, so the zero-initialized weight starts
                out as an identity scale.
        """
        super().__init__()
        self.eps = eps
        self.add_unit_offset = add_unit_offset
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """Divide ``x`` by the root of its mean square (plus ``eps``) along the last axis."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalize ``x`` in float32, cast back to ``x``'s dtype, then apply the scale."""
        normed = self._norm(x.float()).type_as(x)
        scale = 1 + self.weight if self.add_unit_offset else self.weight
        return normed * scale
class GemmaMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(gelu(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        """Create the three linear projections.

        Args:
            hidden_size: Input and output feature size.
            intermediate_size: Feature size of the gated inner representation.
        """
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size)
        self.up_proj = nn.Linear(hidden_size, intermediate_size)
        self.down_proj = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x):
        """Apply the gated MLP to ``x`` along its last axis."""
        activated_gate = F.gelu(self.gate_proj(x))
        return self.down_proj(activated_gate * self.up_proj(x))
class GemmaAttention(nn.Module):
    """Causal self-attention with rotary embeddings and shared key/value heads."""

    def __init__(self, config: GemmaConfig):
        """Create the projections and the causal mask from ``config``.

        Reads ``num_attention_heads``, ``num_key_value_heads``, ``hidden_size``,
        ``head_dim``, ``rope_theta``, ``max_position_embeddings`` and ``device``.
        The number of query heads must be a multiple of the number of
        key/value heads.
        """
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        assert self.num_heads % self.num_kv_heads == 0
        # How many query heads share each key/value head.
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim
        self.theta = config.rope_theta

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        # One fused projection produces queries, keys and values.
        fused_width = (self.num_heads + 2 * self.num_kv_heads) * self.head_dim
        self.qkv_proj = nn.Linear(self.hidden_size, fused_width, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # Strictly-upper-triangular matrix of large negative values: added to
        # the scores so each position gets ~zero weight on later positions.
        max_len = config.max_position_embeddings
        negatives = torch.full((1, 1, max_len, max_len), -2.3819763e38).to(torch.float)
        causal_mask = torch.triu(negatives, diagonal=1).to(config.device)
        self.register_buffer('mask', causal_mask)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Attend over ``hidden_states``.

        Args:
            hidden_states: 3-D tensor ``(batch, sequence, hidden_size)``.

        Returns:
            A tensor of shape ``(batch, sequence, hidden_size)``.
        """
        batch_size, input_len, _ = hidden_states.shape

        fused = self.qkv_proj(hidden_states)
        xq, xk, xv = fused.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Split the feature axis into (heads, head_dim).
        xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
        xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
        xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)

        # Position information is injected into queries and keys only.
        xq = apply_rotary_emb(xq, self.head_dim, self.theta)
        xk = apply_rotary_emb(xk, self.head_dim, self.theta)

        if self.num_kv_heads != self.num_heads:
            # Duplicate each key/value head for the query heads that share it.
            xk = torch.repeat_interleave(xk, self.num_queries_per_kv, dim=2)
            xv = torch.repeat_interleave(xv, self.num_queries_per_kv, dim=2)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        q, k, v = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)

        attn = torch.matmul(q, k.transpose(2, 3)) * self.scaling
        attn = attn + self.mask[..., :input_len, :input_len]
        attn = F.softmax(attn, dim=-1)

        context = torch.matmul(attn, v)
        # Merge the heads back into a single feature axis.
        merged = context.transpose(1, 2).contiguous().view(batch_size, input_len, -1)
        return self.o_proj(merged)
class GemmaDecoderLayer(nn.Module):
    """One pre-norm transformer block: attention and MLP, each with a residual connection."""

    def __init__(self, config: GemmaConfig):
        """Create the attention, MLP and the two RMSNorm layers from ``config``."""
        super().__init__()
        self.self_attn = GemmaAttention(config)
        self.mlp = GemmaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` after the attention and MLP sub-layers."""
        # Attention sub-layer: normalize, attend, add back the input.
        attn_out = self.self_attn(hidden_states=self.input_layernorm(hidden_states))
        hidden_states = hidden_states + attn_out

        # MLP sub-layer: normalize, transform, add back the input.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out
class Gemma(nn.Module):
    """Stack of ``config.num_hidden_layers`` decoder layers followed by a final RMSNorm."""

    def __init__(self, config: GemmaConfig):
        """Create the decoder layers and the final norm from ``config``."""
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        decoder_layers = [GemmaDecoderLayer(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run ``hidden_states`` through every layer in order, then normalize."""
        out = hidden_states
        for decoder_layer in self.layers:
            out = decoder_layer(hidden_states=out)
        return self.norm(out)
class tinyGemma(nn.Module):
    """Language model: token embedding, ``Gemma`` decoder stack, tied output projection."""

    def __init__(self, config: GemmaConfig, tokenizer):
        """Create the embedding table, decoder stack and loss from ``config``.

        Args:
            config: Model hyperparameters; ``hidden_size`` must be divisible by
                ``num_attention_heads``.
            tokenizer: Stored on the module as ``self.tokenizer``.
        """
        super().__init__()
        self.config = config
        assert config.hidden_size % config.num_attention_heads == 0
        self.max_seq_len = config.max_position_embeddings
        self.head_dim = config.head_dim
        self.vocab_size = config.vocab_size
        self.tokenizer = tokenizer
        self.embedder = nn.Embedding(self.vocab_size, config.hidden_size)
        self.model = Gemma(config)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, input_token_ids: torch.Tensor, target_token_ids: torch.Tensor = None) -> torch.Tensor:
        """Compute logits and, when targets are supplied, the cross-entropy loss.

        Args:
            input_token_ids: Token ids to embed.
            target_token_ids: Optional target ids, one per input position.

        Returns:
            ``logits`` alone when ``target_token_ids`` is ``None``; otherwise
            the tuple ``(logits, loss)``.
        """
        # Embeddings are scaled by sqrt(hidden_size) before entering the stack.
        embedded = self.embedder(input_token_ids) * (self.config.hidden_size**0.5)
        features = self.model(hidden_states=embedded)

        # The output projection reuses (ties) the embedding matrix.
        logits = features @ self.embedder.weight.t()
        if target_token_ids is None:
            return logits

        flat_logits = logits.view(-1, logits.size(-1))
        loss = self.criterion(flat_logits, target_token_ids.view(-1))
        return logits, loss
def get_batch(split, batch_size, train_data, val_data, config):
    """Sample a random batch of input/target sequences.

    Args:
        split: ``'train'`` draws from ``train_data``; any other value draws
            from ``val_data``.
        batch_size: Number of sequences in the batch.
        train_data: Sequence of token ids for training.
        val_data: Sequence of token ids for validation.
        config: Object providing ``max_position_embeddings`` (the sequence
            length) and ``device``.

    Returns:
        ``(x, y)``, both of shape ``(batch_size, max_position_embeddings)`` on
        ``config.device``; ``y`` is ``x`` shifted one position to the right in
        the data, i.e. the next-token targets.

    Raises:
        RuntimeError: From ``torch.randint`` when the chosen data holds no
            more than ``max_position_embeddings`` tokens.
    """
    data = train_data if split == 'train' else val_data
    block_len = config.max_position_embeddings

    # Convert the sampled offsets to Python ints once, then slice. Indexing
    # the data element by element with tensor-valued indices costs one
    # tensor -> int conversion per token.
    starts = torch.randint(len(data) - block_len, (batch_size,)).tolist()
    x = torch.stack([torch.tensor(data[s:s + block_len]) for s in starts])
    y = torch.stack([torch.tensor(data[s + 1:s + block_len + 1]) for s in starts])
    return x.to(config.device), y.to(config.device)
@torch.no_grad()
def estimate_loss(model, batch_size, train_data, val_data, config, eval_iters=10):
    """Estimate the mean loss and perplexity on the train and validation data.

    The model is switched to eval mode for the measurement and back to train
    mode afterwards. Autograd is disabled for the whole function: these
    forward passes are never back-propagated, so recording a graph for them
    only costs memory and time.

    Args:
        model: Callable as ``model(X, Y)`` returning ``(logits, loss)``.
        batch_size: Batch size of every evaluation batch.
        train_data: Token ids to sample the ``'train'`` batches from.
        val_data: Token ids to sample the ``'val'`` batches from.
        config: Passed through to ``get_batch``.
        eval_iters: Number of random batches averaged per split.

    Returns:
        ``{'train': {...}, 'val': {...}}`` where each inner dict holds
        ``'loss'`` (mean loss) and ``'ppl'`` (``exp`` of the mean loss), both
        as 0-dim tensors.
    """
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, batch_size, train_data, val_data, config)
            _, loss = model(X, Y)
            losses[k] = loss.item()
        mean_loss = losses.mean()
        out[split] = {
            "loss": mean_loss,
            "ppl": torch.exp(mean_loss),
        }
    model.train()
    return out
def save_model(model, config, path="best_model.pt"):
    """Write ``model``'s state dict to ``path``.

    Args:
        model: Module whose ``state_dict()`` is saved.
        config: Accepted but not used.
        path: Destination file.
    """
    weights = model.state_dict()
    torch.save(weights, path)
# ---------------------------------------------------------------------------
# Training script: character-level model on the text loaded by load_dataset().
# ---------------------------------------------------------------------------
config, tokenizer, text = get_model_config()

# The last 1000 characters are held out for validation.
train_data = tokenizer.encode(text[:-1000])
val_data = tokenizer.encode(text[-1000:])

model = tinyGemma(config, tokenizer).to(config.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

max_iters = 5000
eval_interval = 10
batch_size = 16
accum_steps = 2  # micro-batches accumulated per optimizer step

# Mixed precision only applies on CUDA; on CPU both autocast and the scaler
# are disabled and act as pass-throughs.
use_amp = str(config.device).startswith('cuda')
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
best_val_loss = float('inf')

for step in range(max_iters):
    xb, yb = get_batch('train', batch_size, train_data, val_data, config)
    with torch.cuda.amp.autocast(enabled=use_amp):
        logits, loss = model(xb, yb)

    # Scale the loss so small fp16 gradients do not underflow. Dividing by
    # accum_steps makes the accumulated gradient an average over micro-batches.
    scaler.scale(loss / accum_steps).backward()

    if (step + 1) % accum_steps == 0:
        # Gradients must be unscaled before clipping, otherwise max_norm
        # would be compared against scaled values.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        # step() skips the update if the gradients contain inf/NaN;
        # update() then adjusts the scale factor.
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

    if step % eval_interval == 0:
        losses = estimate_loss(model, batch_size, train_data, val_data, config)
        # Plain floats are easier to log and compare than 0-dim tensors.
        train_loss = float(losses['train']['loss'])
        val_loss = float(losses['val']['loss'])
        val_ppl = float(losses['val']['ppl'])
        print(f"Iter {step}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}, val_ppl={val_ppl:.4f}")
        wandb.log({
            "train_loss": train_loss,
            "val_loss": val_loss,
            "val_ppl": val_ppl,
        })
        # Keep only the checkpoint with the lowest validation loss so far.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_model(model, config, path="best_model.pt")