From 42d7f268845c930b5a665857d121d1e81d3c6dfe Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Sat, 1 Mar 2025 01:29:24 +0000
Subject: [PATCH 1/7] use transformers

Signed-off-by: Sami Jaghouar
---
 src/zeroband/models/llama/__init__.py | 42 +++++++++++++++++++--------
 src/zeroband/train.py                 |  4 +--
 src/zeroband/utils/__init__.py        | 12 ++++----
 3 files changed, 38 insertions(+), 20 deletions(-)

diff --git a/src/zeroband/models/llama/__init__.py b/src/zeroband/models/llama/__init__.py
index 55ce25e8..1495ac26 100644
--- a/src/zeroband/models/llama/__init__.py
+++ b/src/zeroband/models/llama/__init__.py
@@ -9,7 +9,13 @@
 from zeroband.config import Config
 from zeroband.models.llama.model import ModelArgs, Transformer
 
-
+from transformers import (
+    AutoTokenizer,
+    LlamaConfig,
+    LlamaForCausalLM,
+    Qwen2Config,
+    Qwen2ForCausalLM,
+)
 __all__ = ["Transformer"]
 
 llama2_configs = {
@@ -88,16 +94,28 @@ def get_model(
 ) -> tuple[Transformer, ModelArgs]:
     """get the transformer model"""
 
-    if config.type_model == "llama2":
-        model_config = llama2_configs[config.name_model]
-    elif config.type_model == "llama3":
-        model_config = llama3_configs[config.name_model]
-    else:
-        raise ValueError(f"Model type {config.type_model} not supported")
+    config_model = LlamaConfig.from_pretrained("meta-llama/Meta-Llama-3-8B", attn_implementation="flex_attention")
+    model = LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path="meta-llama/Meta-Llama-3-8B", config=config_model)
+
+    return model, config_model
+
+name_to_hf_tokenizer = {
+    "debugmodel": "mistralai/Mistral-7B-v0.1",
+    "150M": "mistralai/Mistral-7B-v0.1",
+    "1B": "mistralai/Mistral-7B-v0.1",
+    "Qwen1.5B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+    "Qwen7B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
+    "Qwen32B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
+    "Llama8B": "meta-llama/Meta-Llama-3-8B",
+}
 
-    model_config.vocab_size = vocab_size
-    model_config.max_seq_len = config.data.seq_length
-    model_config.attn_fn = config.train.attn_fn
-    model_config.fused_linear_ce = config.train.fused_linear_ce
+name_to_class = {
+    "debugmodel": (LlamaConfig, LlamaForCausalLM),
+    "150M": (LlamaConfig, LlamaForCausalLM),
+    "1B": (LlamaConfig, LlamaForCausalLM),
+    "Qwen1.5B": (Qwen2Config, Qwen2ForCausalLM),
+    "Qwen7B": (Qwen2Config, Qwen2ForCausalLM),
+    "Qwen32B": (Qwen2Config, Qwen2ForCausalLM),
+    "Llama8B": (LlamaConfig, LlamaForCausalLM),
+}
 
-    return Transformer(model_config), model_config
diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index 06585bcc..df26b84d 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -147,9 +147,9 @@ def train(config: Config):
 
     offload_policy = CPUOffloadPolicy(pin_memory=True) if config.train.fsdp_cpu_offload else None
 
-    for layer_id, transformer_block in model.layers.items():
+    for layer_id, transformer_block in enumerate(model.model.layers):
         if config.train.reshard_after_forward:
-            reshard_after_forward = int(layer_id) < len(model.layers) - 1
+            reshard_after_forward = int(layer_id) < len(model.model.layers) - 1
         else:
             reshard_after_forward = False
         fully_shard(
diff --git a/src/zeroband/utils/__init__.py b/src/zeroband/utils/__init__.py
index fafa9c7b..eedbc2f7 100644
--- a/src/zeroband/utils/__init__.py
+++ b/src/zeroband/utils/__init__.py
@@ -48,11 +48,11 @@ def get_peak_flops(device_name: str) -> int:
         return 312e12
 
 
-def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
+def get_num_flop_per_token(num_params: int, model_config, seq_len: int) -> int:
     l, h, q, t = ( # noqa: E741
-        model_config.n_layers,
-        model_config.n_heads,
-        model_config.dim // model_config.n_heads,
+        model_config.num_hidden_layers,
+        model_config.num_attention_heads,
+        model_config.hidden_size // model_config.num_attention_heads,
         seq_len,
     )
     # Reasoning behind the factor of 12 for the self-attention part of the formula:
@@ -66,10 +66,10 @@ def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
     return flop_per_token
 
 
-def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
+def get_num_params(model, exclude_embedding: bool = False) -> int:
     num_params = sum(p.numel() for p in model.parameters())
     if exclude_embedding:
-        num_params -= model.tok_embeddings.weight.numel()
+        num_params -= model.lm_head.weight.numel()
     return num_params
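The utils change in patch 1 swaps the custom ModelArgs fields for their Hugging Face LlamaConfig equivalents. A minimal sketch of how the resulting per-token FLOPs estimate comes together (the 6 * num_params dense term plus the 12 * l * h * q * t self-attention term referenced in the diff), assuming a stock LlamaConfig; the sequence length and parameter count below are illustrative, not taken from the series:

    from transformers import LlamaConfig

    cfg = LlamaConfig.from_pretrained("meta-llama/Meta-Llama-3-8B")
    l = cfg.num_hidden_layers    # transformer layers
    h = cfg.num_attention_heads  # attention heads
    q = cfg.hidden_size // h     # per-head dimension
    t = 8192                     # sequence length (illustrative)
    num_params = 8_000_000_000   # rough total, assumption
    flop_per_token = 6 * num_params + 12 * l * h * q * t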
From f436ef7734ff9f0dc47e795467771185456f1e06 Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Sat, 1 Mar 2025 01:31:37 +0000
Subject: [PATCH 2/7] use transformers

Signed-off-by: Sami Jaghouar
---
 src/zeroband/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index df26b84d..fa7c7522 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -305,7 +305,7 @@ def train(config: Config):
             block_mask = batch["block_mask"]
 
             with sw.record_block("Run forward()"):
-                logits = model(tokens=input_ids, block_mask=block_mask).contiguous()
+                logits = model(input_ids=input_ids).logits.contiguous()
                 flatten_logits = logits.reshape(-1, logits.size(-1))  # b seq vocab -> (b * seq) vocab
                 flatten_labels = labels.reshape(-1)  # b seq -> (b * seq)
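Patch 2 routes the forward pass through the Hugging Face CausalLM interface, which returns an output object whose .logits field holds the (batch, seq, vocab) tensor rather than a raw tensor, and the flex-attention block_mask argument is dropped. A minimal sketch of the surrounding loss computation, where only the reshape lines appear in the diff and the cross-entropy call is an assumption about the rest of the loop:

    import torch.nn.functional as F

    logits = model(input_ids=input_ids).logits             # (b, seq, vocab)
    flatten_logits = logits.reshape(-1, logits.size(-1))   # (b * seq, vocab)
    flatten_labels = labels.reshape(-1)                    # (b * seq,)
    loss = F.cross_entropy(flatten_logits, flatten_labels)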
From b4c3f0e5bf198ea906973b572fd9c70dba66e8cf Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Sat, 1 Mar 2025 01:39:33 +0000
Subject: [PATCH 3/7] use transformers

Signed-off-by: Sami Jaghouar
---
 src/zeroband/models/llama/__init__.py | 21 ---------------------
 1 file changed, 21 deletions(-)

diff --git a/src/zeroband/models/llama/__init__.py b/src/zeroband/models/llama/__init__.py
index 1495ac26..e36c3342 100644
--- a/src/zeroband/models/llama/__init__.py
+++ b/src/zeroband/models/llama/__init__.py
@@ -98,24 +98,3 @@ def get_model(
     model = LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path="meta-llama/Meta-Llama-3-8B", config=config_model)
 
     return model, config_model
-
-name_to_hf_tokenizer = {
-    "debugmodel": "mistralai/Mistral-7B-v0.1",
-    "150M": "mistralai/Mistral-7B-v0.1",
-    "1B": "mistralai/Mistral-7B-v0.1",
-    "Qwen1.5B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
-    "Qwen7B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
-    "Qwen32B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
-    "Llama8B": "meta-llama/Meta-Llama-3-8B",
-}
-
-name_to_class = {
-    "debugmodel": (LlamaConfig, LlamaForCausalLM),
-    "150M": (LlamaConfig, LlamaForCausalLM),
-    "1B": (LlamaConfig, LlamaForCausalLM),
-    "Qwen1.5B": (Qwen2Config, Qwen2ForCausalLM),
-    "Qwen7B": (Qwen2Config, Qwen2ForCausalLM),
-    "Qwen32B": (Qwen2Config, Qwen2ForCausalLM),
-    "Llama8B": (LlamaConfig, LlamaForCausalLM),
-}
-

From 72a30ba791f608cac288b74ce88ae3bf6847e732 Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Sat, 1 Mar 2025 01:48:35 +0000
Subject: [PATCH 4/7] use transformers

Signed-off-by: Sami Jaghouar
---
 src/zeroband/models/llama/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/zeroband/models/llama/__init__.py b/src/zeroband/models/llama/__init__.py
index e36c3342..4c77d68e 100644
--- a/src/zeroband/models/llama/__init__.py
+++ b/src/zeroband/models/llama/__init__.py
@@ -94,7 +94,7 @@ def get_model(
 ) -> tuple[Transformer, ModelArgs]:
     """get the transformer model"""
 
-    config_model = LlamaConfig.from_pretrained("meta-llama/Meta-Llama-3-8B", attn_implementation="flex_attention")
+    config_model = LlamaConfig.from_pretrained("meta-llama/Meta-Llama-3-8B", attn_implementation="sdpa")
     model = LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path="meta-llama/Meta-Llama-3-8B", config=config_model)
 
     return model, config_model

From edc6408b299f261785b4ee4be3739912053af460 Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Sat, 1 Mar 2025 02:02:44 +0000
Subject: [PATCH 5/7] use transformers

Signed-off-by: Sami Jaghouar
---
 src/zeroband/models/llama/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/zeroband/models/llama/__init__.py b/src/zeroband/models/llama/__init__.py
index 4c77d68e..a12e9948 100644
--- a/src/zeroband/models/llama/__init__.py
+++ b/src/zeroband/models/llama/__init__.py
@@ -94,7 +94,7 @@ def get_model(
 ) -> tuple[Transformer, ModelArgs]:
     """get the transformer model"""
 
-    config_model = LlamaConfig.from_pretrained("meta-llama/Meta-Llama-3-8B", attn_implementation="sdpa")
+    config_model = LlamaConfig.from_pretrained("meta-llama/Meta-Llama-3-8B", attn_implementation="flash_attention_2")
     model = LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path="meta-llama/Meta-Llama-3-8B", config=config_model)
 
     return model, config_model

From aeef728ac632d63cbc1490389fbb4c2df77fff7e Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Sat, 1 Mar 2025 02:10:32 +0000
Subject: [PATCH 6/7] use transformers

Signed-off-by: Sami Jaghouar
---
 src/zeroband/config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/zeroband/config.py b/src/zeroband/config.py
index 11c27af5..8ff60dc8 100644
--- a/src/zeroband/config.py
+++ b/src/zeroband/config.py
@@ -156,7 +156,7 @@ def validate_remote_data_path(self):
 
 class Config(BaseConfig):
     # main config
-    name_model: Literal["debugmodel", "70M","150M", "271M", "1B", "7B", "10B", "13B", "26B", "70B"] = "150M"
+    name_model: Literal["debugmodel", "70M","150M", "271M", "1B", "7B", "8B", "10B", "13B", "26B", "70B"] = "8B"
     type_model: Literal["llama2", "llama3"] = "llama3"
 
     # Project/Run

From 30018e814149f3c8807891201090209910e5fa83 Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Sat, 1 Mar 2025 02:18:56 +0000
Subject: [PATCH 7/7] use transformers

Signed-off-by: Sami Jaghouar
---
 src/zeroband/models/llama/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/zeroband/models/llama/__init__.py b/src/zeroband/models/llama/__init__.py
index a12e9948..e36c3342 100644
--- a/src/zeroband/models/llama/__init__.py
+++ b/src/zeroband/models/llama/__init__.py
@@ -94,7 +94,7 @@ def get_model(
 ) -> tuple[Transformer, ModelArgs]:
     """get the transformer model"""
 
-    config_model = LlamaConfig.from_pretrained("meta-llama/Meta-Llama-3-8B", attn_implementation="flash_attention_2")
+    config_model = LlamaConfig.from_pretrained("meta-llama/Meta-Llama-3-8B", attn_implementation="flex_attention")
     model = LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path="meta-llama/Meta-Llama-3-8B", config=config_model)
 
     return model, config_model
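Across patches 4, 5, and 7 the attn_implementation flag cycles from flex_attention to sdpa to flash_attention_2 and back to flex_attention, while the rest of get_model stays fixed. A sketch of the end state with the backend pulled out as a parameter; the load_llama helper and its signature are illustrative, not part of the series:

    from transformers import LlamaConfig, LlamaForCausalLM

    def load_llama(attn_impl: str = "flex_attention"):
        # flex_attention assumes a recent transformers/PyTorch pairing;
        # "sdpa" is the more portable fallback.
        config = LlamaConfig.from_pretrained(
            "meta-llama/Meta-Llama-3-8B", attn_implementation=attn_impl
        )
        model = LlamaForCausalLM.from_pretrained(
            "meta-llama/Meta-Llama-3-8B", config=config
        )
        return model, config

    model, config = load_llama()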