debug transformer #224

Open · wants to merge 7 commits into base: main
2 changes: 1 addition & 1 deletion src/zeroband/config.py
@@ -156,7 +156,7 @@ def validate_remote_data_path(self):

class Config(BaseConfig):
    # main config
-    name_model: Literal["debugmodel", "70M","150M", "271M", "1B", "7B", "10B", "13B", "26B", "70B"] = "150M"
+    name_model: Literal["debugmodel", "70M","150M", "271M", "1B", "7B", "8B", "10B", "13B", "26B", "70B"] = "8B"
    type_model: Literal["llama2", "llama3"] = "llama3"

    # Project/Run
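
For context on the change above: the diff adds an "8B" entry to the name_model literal and switches the default from "150M" to "8B". Below is a minimal sketch of how such a literal field validates, assuming a plain pydantic model as a stand-in for the repo's Config (the ModelSelection class and everything except the two fields shown are illustrative, not from the PR):

from typing import Literal

from pydantic import BaseModel, ValidationError


class ModelSelection(BaseModel):
    # Hypothetical stand-in for zeroband's Config; only the two fields touched by the diff.
    name_model: Literal["debugmodel", "70M", "150M", "271M", "1B", "7B", "8B", "10B", "13B", "26B", "70B"] = "8B"
    type_model: Literal["llama2", "llama3"] = "llama3"


print(ModelSelection())                   # name_model='8B' type_model='llama3' (new default)
print(ModelSelection(name_model="150M"))  # the old default is still accepted

try:
    ModelSelection(name_model="3B")       # not in the literal, so validation fails
except ValidationError as e:
    print(e)
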
25 changes: 11 additions & 14 deletions src/zeroband/models/llama/__init__.py
@@ -9,7 +9,13 @@

from zeroband.config import Config
from zeroband.models.llama.model import ModelArgs, Transformer

+from transformers import (
+    AutoTokenizer,
[Check failure · GitHub Actions / ruff · F401: `transformers.AutoTokenizer` imported but unused (src/zeroband/models/llama/__init__.py:13:5)]
+    LlamaConfig,
+    LlamaForCausalLM,
+    Qwen2Config,
[Check failure · GitHub Actions / ruff · F401: `transformers.Qwen2Config` imported but unused (src/zeroband/models/llama/__init__.py:16:5)]
+    Qwen2ForCausalLM,
[Check failure · GitHub Actions / ruff · F401: `transformers.Qwen2ForCausalLM` imported but unused (src/zeroband/models/llama/__init__.py:17:5)]
+)
__all__ = ["Transformer"]

llama2_configs = {
@@ -88,16 +94,7 @@
) -> tuple[Transformer, ModelArgs]:
    """get the transformer model"""

-    if config.type_model == "llama2":
-        model_config = llama2_configs[config.name_model]
-    elif config.type_model == "llama3":
-        model_config = llama3_configs[config.name_model]
-    else:
-        raise ValueError(f"Model type {config.type_model} not supported")
-
-    model_config.vocab_size = vocab_size
-    model_config.max_seq_len = config.data.seq_length
-    model_config.attn_fn = config.train.attn_fn
-    model_config.fused_linear_ce = config.train.fused_linear_ce
-
-    return Transformer(model_config), model_config
+    config_model = LlamaConfig.from_pretrained("meta-llama/Meta-Llama-3-8B", attn_implementation="flex_attention")
+    model = LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path="meta-llama/Meta-Llama-3-8B", config=config_model)
+
+    return model, config_model
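
A rough sketch of what the rewritten get_model now does, assuming transformers is installed: build a LlamaConfig, construct a LlamaForCausalLM from it, and return the HF model plus its config instead of the in-repo Transformer. The tiny sizes and the get_debug_model name below are illustrative, chosen so the sketch runs without downloading meta-llama/Meta-Llama-3-8B; the PR itself calls from_pretrained with attn_implementation="flex_attention". Note that the sketch only imports what it uses, which is what the ruff F401 failures above are asking for (AutoTokenizer, Qwen2Config, and Qwen2ForCausalLM are imported but never referenced).

from transformers import LlamaConfig, LlamaForCausalLM


def get_debug_model():
    # Mirrors the new get_model flow: config first, then the causal LM built from it.
    # Randomly initialized tiny weights stand in for the pretrained 8B checkpoint.
    config_model = LlamaConfig(
        vocab_size=128,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
    )
    model = LlamaForCausalLM(config_model)
    return model, config_model


model, config_model = get_debug_model()
print(config_model.num_hidden_layers, len(model.model.layers))  # 2 2
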
6 changes: 3 additions & 3 deletions src/zeroband/train.py
@@ -147,9 +147,9 @@

    offload_policy = CPUOffloadPolicy(pin_memory=True) if config.train.fsdp_cpu_offload else None

-    for layer_id, transformer_block in model.layers.items():
+    for layer_id, transformer_block in enumerate(model.model.layers):
        if config.train.reshard_after_forward:
-            reshard_after_forward = int(layer_id) < len(model.layers) - 1
+            reshard_after_forward = int(layer_id) < len(model.model.layers) - 1
        else:
            reshard_after_forward = False
        fully_shard(
@@ -302,10 +302,10 @@
                        batch = next(train_dataloader_iterator)
                        input_ids = batch["input_ids"]
                        labels = batch["labels"]
                        block_mask = batch["block_mask"]
[Check failure · GitHub Actions / ruff · F841: Local variable `block_mask` is assigned to but never used (src/zeroband/train.py:305:25)]

                        with sw.record_block("Run forward()"):
-                            logits = model(tokens=input_ids, block_mask=block_mask).contiguous()
+                            logits = model(input_ids=input_ids).logits.contiguous()
                            flatten_logits = logits.reshape(-1, logits.size(-1))  # b seq vocab -> (b * seq) vocab
                            flatten_labels = labels.reshape(-1)  # b seq -> (b * seq)
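
A minimal sketch of the new forward/loss path, assuming torch and transformers are available: the HF model is called with input_ids only (which is why block_mask is now unused, exactly what the ruff F841 failure flags), and the logits come from the .logits field of the returned output. A tiny randomly initialized LlamaForCausalLM stands in for the 8B checkpoint, and the stopwatch, dataloader, and fully_shard wiring are omitted.

import torch
import torch.nn.functional as F
from transformers import LlamaConfig, LlamaForCausalLM

# Tiny illustrative model, not the PR's meta-llama/Meta-Llama-3-8B.
config = LlamaConfig(
    vocab_size=128, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4,
)
model = LlamaForCausalLM(config)

# The decoder blocks iterated by the fully_shard loop above live under model.model.layers.
assert len(model.model.layers) == config.num_hidden_layers

input_ids = torch.randint(0, config.vocab_size, (2, 16))  # b seq
labels = torch.randint(0, config.vocab_size, (2, 16))     # b seq

logits = model(input_ids=input_ids).logits.contiguous()   # b seq vocab
flatten_logits = logits.reshape(-1, logits.size(-1))      # (b * seq) vocab
flatten_labels = labels.reshape(-1)                       # (b * seq)
loss = F.cross_entropy(flatten_logits, flatten_labels)
loss.backward()
print(loss.item())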

12 changes: 6 additions & 6 deletions src/zeroband/utils/__init__.py
@@ -48,11 +48,11 @@ def get_peak_flops(device_name: str) -> int:
        return 312e12


-def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
+def get_num_flop_per_token(num_params: int, model_config, seq_len: int) -> int:
    l, h, q, t = (  # noqa: E741
-        model_config.n_layers,
-        model_config.n_heads,
-        model_config.dim // model_config.n_heads,
+        model_config.num_hidden_layers,
+        model_config.num_attention_heads,
+        model_config.hidden_size // model_config.num_attention_heads,
        seq_len,
    )
    # Reasoning behind the factor of 12 for the self-attention part of the formula:
@@ -66,10 +66,10 @@ def get_num_flop_per_token(num_params: int, model_config, seq_len: int) -> int:
    return flop_per_token


-def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
+def get_num_params(model, exclude_embedding: bool = False) -> int:
    num_params = sum(p.numel() for p in model.parameters())
    if exclude_embedding:
-        num_params -= model.tok_embeddings.weight.numel()
+        num_params -= model.lm_head.weight.numel()
    return num_params
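
To illustrate the renamed fields: the FLOPs-per-token estimate now reads the layer count, head count, and head dimension from a Hugging Face config (num_hidden_layers, num_attention_heads, hidden_size) instead of the old ModelArgs attributes (n_layers, n_heads, dim). A sketch under the usual 6 * N + 12 * l * h * q * t estimate follows; the function body is not fully shown in the diff, so the return expression here is an assumption, and the FakeHFConfig values and the ~8.03B parameter count are simply the public Llama-3-8B numbers used for illustration.

def get_num_flop_per_token(num_params: int, model_config, seq_len: int) -> int:
    # 6 * params for the dense matmuls plus 12 * layers * heads * head_dim * seq_len
    # for self-attention (assumed; the diff only shows how l, h, q, t are gathered).
    l, h, q, t = (  # noqa: E741
        model_config.num_hidden_layers,
        model_config.num_attention_heads,
        model_config.hidden_size // model_config.num_attention_heads,
        seq_len,
    )
    return 6 * num_params + 12 * l * h * q * t


class FakeHFConfig:
    # Illustrative values matching the public Llama-3-8B config.
    num_hidden_layers = 32
    num_attention_heads = 32
    hidden_size = 4096


print(get_num_flop_per_token(num_params=8_030_000_000, model_config=FakeHFConfig(), seq_len=8192))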

