Commit ab4d46b

Author: Kye
Message: [TRAINER]
Parent: d198b62

File tree: 2 files changed (+85, -0 lines)

README.md

Lines changed: 4 additions & 0 deletions

@@ -49,5 +49,9 @@ print(output)
 
 ```
 
+## Train
+`python3 train.py`
+
+
 # License
 MIT
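
Since the new Train section just points at `python3 train.py`, a quick pre-flight check can confirm the script's dependencies are importable before training. This is a minimal sketch, not part of the commit; the module names are taken from the imports in train.py below, and `jamba` is assumed to be this repository installed in the current environment.

```python
import importlib

# Modules that train.py imports; `jamba` is assumed to be this repo installed locally.
for module in ("torch", "datasets", "transformers", "jamba"):
    try:
        importlib.import_module(module)
        print(f"{module}: OK")
    except ImportError as exc:
        print(f"{module}: missing ({exc})")
```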

train.py

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@ (new file)

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from jamba.model import Jamba

# Load the dataset from Hugging Face
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
tokenizer.pad_token = tokenizer.eos_token


# Tokenize the dataset
def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=100,
        return_tensors="pt",
    )


tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets.set_format(type="torch", columns=["input_ids"])


# DataLoader
def collate_fn(batch):
    input_ids = torch.stack([item["input_ids"] for item in batch])
    # Create targets by shifting input_ids one token to the left
    labels = torch.roll(input_ids, -1, dims=-1)
    return input_ids.squeeze(), labels.squeeze()


dataloader = DataLoader(
    tokenized_datasets,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
)

# Initialize the Jamba model with tokenizer's vocab size
model = Jamba(
    dim=512,
    depth=6,
    num_tokens=tokenizer.vocab_size,
    d_state=256,
    d_conv=128,
    heads=8,
    num_experts=8,
    num_experts_per_token=2,
)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 5
for epoch in range(epochs):
    for inputs, targets in dataloader:
        optimizer.zero_grad()  # Zero the gradients

        # Forward pass
        outputs = model(inputs)
        loss = criterion(
            outputs.transpose(1, 2), targets
        )  # Adjust for cross-entropy expecting class dimension at dim=1

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

print("Training complete!")
```

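train.py also runs on whatever device PyTorch defaults to and never saves weights, so `python3 train.py` leaves no artifact behind. Below is a hedged sketch of how the same loop might add device placement and per-epoch checkpointing; the checkpoint filename and the device handling are assumptions, not part of the commit, and `model`, `dataloader`, `criterion`, `optimizer`, and `epochs` are the objects defined in train.py above.

```python
import torch

# Sketch only: device placement and checkpointing on top of the committed loop.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

for epoch in range(epochs):
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs.transpose(1, 2), targets)
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item()}")
    # Assumed filename; persists weights so training leaves a reusable artifact.
    torch.save(model.state_dict(), f"jamba_epoch_{epoch+1}.pt")
```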