diff --git a/src/zeroband/config.py b/src/zeroband/config.py
index 11c27af5..8ff60dc8 100644
--- a/src/zeroband/config.py
+++ b/src/zeroband/config.py
@@ -156,7 +156,7 @@ def validate_remote_data_path(self):
 
 class Config(BaseConfig):
     # main config
-    name_model: Literal["debugmodel", "70M","150M", "271M", "1B", "7B", "10B", "13B", "26B", "70B"] = "150M"
+    name_model: Literal["debugmodel", "70M","150M", "271M", "1B", "7B", "8B", "10B", "13B", "26B", "70B"] = "8B"
     type_model: Literal["llama2", "llama3"] = "llama3"
 
     # Project/Run
diff --git a/src/zeroband/models/llama/__init__.py b/src/zeroband/models/llama/__init__.py
index 55ce25e8..e36c3342 100644
--- a/src/zeroband/models/llama/__init__.py
+++ b/src/zeroband/models/llama/__init__.py
@@ -9,7 +9,13 @@
 from zeroband.config import Config
 from zeroband.models.llama.model import ModelArgs, Transformer
-
+from transformers import (
+    AutoTokenizer,
+    LlamaConfig,
+    LlamaForCausalLM,
+    Qwen2Config,
+    Qwen2ForCausalLM,
+)
 
 __all__ = ["Transformer"]
 
 llama2_configs = {
@@ -88,16 +94,7 @@ def get_model(
 ) -> tuple[Transformer, ModelArgs]:
     """get the transformer model"""
 
-    if config.type_model == "llama2":
-        model_config = llama2_configs[config.name_model]
-    elif config.type_model == "llama3":
-        model_config = llama3_configs[config.name_model]
-    else:
-        raise ValueError(f"Model type {config.type_model} not supported")
-
-    model_config.vocab_size = vocab_size
-    model_config.max_seq_len = config.data.seq_length
-    model_config.attn_fn = config.train.attn_fn
-    model_config.fused_linear_ce = config.train.fused_linear_ce
-
-    return Transformer(model_config), model_config
+    config_model = LlamaConfig.from_pretrained("meta-llama/Meta-Llama-3-8B", attn_implementation="flex_attention")
+    model = LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path="meta-llama/Meta-Llama-3-8B", config=config_model)
+
+    return model, config_model
diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index 06585bcc..fa7c7522 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -147,9 +147,9 @@ def train(config: Config):
 
     offload_policy = CPUOffloadPolicy(pin_memory=True) if config.train.fsdp_cpu_offload else None
 
-    for layer_id, transformer_block in model.layers.items():
+    for layer_id, transformer_block in enumerate(model.model.layers):
         if config.train.reshard_after_forward:
-            reshard_after_forward = int(layer_id) < len(model.layers) - 1
+            reshard_after_forward = int(layer_id) < len(model.model.layers) - 1
         else:
             reshard_after_forward = False
         fully_shard(
@@ -305,7 +305,7 @@ def train(config: Config):
                 block_mask = batch["block_mask"]
 
                 with sw.record_block("Run forward()"):
-                    logits = model(tokens=input_ids, block_mask=block_mask).contiguous()
+                    logits = model(input_ids=input_ids).logits.contiguous()
 
                 flatten_logits = logits.reshape(-1, logits.size(-1))  # b seq vocab -> (b * seq) vocab
                 flatten_labels = labels.reshape(-1)  # b seq -> (b * seq)
diff --git a/src/zeroband/utils/__init__.py b/src/zeroband/utils/__init__.py
index fafa9c7b..eedbc2f7 100644
--- a/src/zeroband/utils/__init__.py
+++ b/src/zeroband/utils/__init__.py
@@ -48,11 +48,11 @@ def get_peak_flops(device_name: str) -> int:
             return 312e12
 
 
-def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
+def get_num_flop_per_token(num_params: int, model_config, seq_len: int) -> int:
     l, h, q, t = (  # noqa: E741
-        model_config.n_layers,
-        model_config.n_heads,
-        model_config.dim // model_config.n_heads,
+        model_config.num_hidden_layers,
+        model_config.num_attention_heads,
+        model_config.hidden_size // model_config.num_attention_heads,
         seq_len,
     )
     # Reasoning behind the factor of 12 for the self-attention part of the formula:
@@ -66,10 +66,10 @@ def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
     return flop_per_token
 
 
-def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
+def get_num_params(model, exclude_embedding: bool = False) -> int:
     num_params = sum(p.numel() for p in model.parameters())
     if exclude_embedding:
-        num_params -= model.tok_embeddings.weight.numel()
+        num_params -= model.lm_head.weight.numel()
     return num_params