import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
from jamba.model import Jamba

# Load the dataset from Hugging Face
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

# Initialize the tokenizer; GPT-2 has no pad token, so reuse EOS for padding
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
tokenizer.pad_token = tokenizer.eos_token


# Tokenize the dataset; pad/truncate every example to a fixed length.
# No return_tensors here: set_format below handles the tensor conversion.
def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=100,
    )


tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets.set_format(type="torch", columns=["input_ids"])


# DataLoader
def collate_fn(batch):
    input_ids = torch.stack([item["input_ids"] for item in batch])
    # Create targets by shifting input_ids one token to the left;
    # the last position has no next token, so mask it out of the loss
    labels = torch.roll(input_ids, -1, dims=-1)
    labels[:, -1] = -100  # -100 is ignored by nn.CrossEntropyLoss by default
    return input_ids, labels


dataloader = DataLoader(
    tokenized_datasets,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
)

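# Optional sanity check (a sketch): pull one batch and confirm the shapes
# before training. Expected: both tensors of shape (batch_size, max_length),
# i.e. (32, 100) with the settings above.
sample_inputs, sample_labels = next(iter(dataloader))
print(sample_inputs.shape, sample_labels.shape)
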
# Initialize the Jamba model with the tokenizer's vocab size
model = Jamba(
    dim=512,
    depth=6,
    num_tokens=tokenizer.vocab_size,
    d_state=256,
    d_conv=128,
    heads=8,
    num_experts=8,
    num_experts_per_token=2,
)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
epochs = 5
for epoch in range(epochs):
    for inputs, targets in dataloader:
        optimizer.zero_grad()  # Zero the gradients

        # Forward pass: logits of shape (batch, seq_len, vocab_size)
        outputs = model(inputs)
        # CrossEntropyLoss expects the class dimension at dim=1,
        # so transpose the logits to (batch, vocab_size, seq_len)
        loss = criterion(outputs.transpose(1, 2), targets)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

print("Training complete!")
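
# A minimal sketch of persisting the trained weights with standard PyTorch
# checkpointing; the output filename "jamba_wikitext.pt" is illustrative.
torch.save(model.state_dict(), "jamba_wikitext.pt")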