
Commit d46bb45

quick fix conversion
1 parent bc1f5a6 commit d46bb45

4 files changed: +42 −33 lines changed


examples/llama/convert_weights.py

Lines changed: 1 addition & 0 deletions
@@ -84,6 +84,7 @@ def get_config_mapping(nt_to_hf: bool = True) -> dict[str, str]:
         "use_cache": "use_cache",
         "vocab_size": "vocab_size",
         "attention_bias": "attention_bias",
+        "rope_interleaved": "rope_interleaved",
     }
     if nt_to_hf:
         return {nt: hf for hf, nt in hf_to_nt_map.items()}
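
The mapping is declared HF -> Nanotron and inverted on demand. Below is a minimal, self-contained sketch of that inversion using a trimmed-down map; the two keys shown are illustrative, not the full dict in convert_weights.py.

# Sketch of the direction flip in get_config_mapping (trimmed map, illustrative only).
hf_to_nt_map = {
    "attention_bias": "attention_bias",
    "rope_interleaved": "rope_interleaved",  # key added by this commit
}

def get_config_mapping(nt_to_hf: bool = True) -> dict[str, str]:
    # Nanotron -> HF is just the inverse of the HF -> Nanotron mapping.
    if nt_to_hf:
        return {nt: hf for hf, nt in hf_to_nt_map.items()}
    return hf_to_nt_map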

examples/llama/tests/test_conversion.py

Lines changed: 15 additions & 10 deletions
@@ -1,9 +1,14 @@
 # ruff: noqa: E402
 import dataclasses
 import json
+import sys
 from pathlib import Path
 from typing import Optional
 
+_this_file = Path(__file__).resolve()
+_nanotron_root = _this_file.parent.parent.parent.parent  # tests -> llama -> examples -> nanotron
+sys.path.insert(0, str(_nanotron_root))
+
 import pytest
 import torch
 from transformers import AutoModelForCausalLM, LlamaForCausalLM
@@ -12,6 +17,7 @@
 set_system_path()
 
 import nanotron
+
 from nanotron import distributed as dist
 from nanotron.config import LlamaConfig as NanotronLlamaConfig
 from nanotron.config import NanotronConfigs
@@ -84,7 +90,7 @@
     "tie_word_embeddings": False,
     "use_cache": True,
     "vocab_size": 4096,
-    "_attn_implementation": "sdpa",
+    "_attn_implementation": "flash_attention_2",
     "attention_bias": False,
     "rope_interleaved": False,
 }
@@ -117,8 +123,7 @@ def create_huggingface_model(model_name: Optional[str] = None) -> LlamaForCausal
         with init_on_device_and_dtype(torch.device("cuda"), torch.bfloat16):
            model_hf = LlamaForCausalLM._from_config(get_hf_config(CONFIG))
     else:
-        with init_on_device_and_dtype(torch.device("cuda"), torch.bfloat16):
-            model_hf = AutoModelForCausalLM.from_pretrained(model_name)
+        model_hf = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).cuda()
     return model_hf
 
 
@@ -336,13 +341,13 @@ def test_tensor_parallel_conversion(input_ids: torch.Tensor):
     # run all tests
     # test_nt_to_hf(input_ids=torch.randint(0, CONFIG.vocab_size, size=(BATCH_SIZE, SEQUENCE_LENGTH), device="cuda"))
     # test_hf_to_nt(input_ids=torch.randint(0, CONFIG.vocab_size, size=(BATCH_SIZE, SEQUENCE_LENGTH), device="cuda"))
-    test_tensor_parallel_conversion(
-        input_ids=torch.randint(0, CONFIG.vocab_size, size=(BATCH_SIZE, SEQUENCE_LENGTH), device="cuda")
-    )
+    # test_tensor_parallel_conversion(
+    #     input_ids=torch.randint(0, CONFIG.vocab_size, size=(BATCH_SIZE, SEQUENCE_LENGTH), device="cuda")
+    # )
 
     # Warning: Converting from HF to Nanotron is a better test because we don't initialize weights in standard way. (e.g. Layernorms)
     # Test SmolLM2-135M
-    # test_hf_to_nt(
-    #     input_ids=torch.randint(0, CONFIG.vocab_size, size=(BATCH_SIZE, SEQUENCE_LENGTH), device="cuda"),
-    #     model_name="HuggingFaceTB/SmolLM2-135M",
-    # )
+    test_hf_to_nt(
+        input_ids=torch.randint(0, CONFIG.vocab_size, size=(BATCH_SIZE, SEQUENCE_LENGTH), device="cuda"),
+        model_name="HuggingFaceTB/SmolLM2-135M",
+    )
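
As a standalone illustration of the new pretrained-loading branch in create_huggingface_model: a sketch only, with the model name mirroring the SmolLM2-135M test above and the helper name invented here for clarity.

# Sketch: load an HF checkpoint directly in bfloat16 and move it to the GPU,
# mirroring the updated else-branch of create_huggingface_model.
import torch
from transformers import AutoModelForCausalLM

def load_hf_model_bf16(model_name: str = "HuggingFaceTB/SmolLM2-135M"):
    # from_pretrained materializes the real checkpoint weights; passing torch_dtype
    # loads them in bfloat16 so no separate cast is needed afterwards.
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
    return model.cuda()  # place the already-casted weights on the GPU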

src/nanotron/models/qwen.py

Lines changed: 23 additions & 23 deletions
@@ -208,7 +208,7 @@ def __init__(
             max_seq_len=config.max_position_embeddings,
             base=config.rope_theta,
             interleaved=config.rope_interleaved,
-            seq_len_scaling_factor=config.rope_seq_len_scaling_factor,
+            seq_len_scaling_factor=config.rope_seq_len_interpolation_factor,
             fused=config._fused_rotary_emb,
         )
         self.attention = CoreAttention(config, tp_pg, cp_pg, layer_idx)
@@ -238,28 +238,28 @@ def forward(
 
         if self._use_qkv_packed:
             attn_output = self._forward_packed(qkv, seq_length, position_ids, cu_seqlens)
-        # else:
-        #     q, k, v = qkv.split(
-        #         [self.local_q_size, self.local_kv_size, self.local_kv_size], dim=-1
-        #     )  # [batch_size*seq_length, q_size], [batch_size*seq_length, kv_size]
-        #     q = q.view(-1, self.local_num_heads, self.head_dim)  # [b*s, num_heads, head_dim]
-        #     k = k.view(-1, self.local_num_kv_heads, self.head_dim)  # [b*s, num_kv_heads, head_dim]
-        #     v = v.view(-1, self.local_num_kv_heads, self.head_dim)  # [b*s, num_kv_heads, head_dim]
-        #     if self.config.no_rope_layer is None or (self.layer_idx + 1) % self.config.no_rope_layer != 0:
-        #         rotary_pos_emb = self.rotary_emb(
-        #             position_ids=position_ids if not self.simple_causal_mask else None, seq_length=seq_length
-        #         )  # [b*s, dim] or [seq_length, dim]
-        #         q = self.rotary_emb.apply_rotary_pos_emb(
-        #             q, rotary_pos_emb, seq_length=seq_length
-        #         )  # [b*s, num_heads, head_dim]
-        #         k = self.rotary_emb.apply_rotary_pos_emb(
-        #             k, rotary_pos_emb, seq_length=seq_length
-        #         )  # [b*s, num_kv_heads, head_dim]
-        #     else:
-        #         log_rank(f"skipping rotary for layer {self.layer_idx + 1}", logger=logger, level=logging.DEBUG, rank=0)
-        #     attn_output = self.attention(
-        #         q, k, v, position_ids=position_ids, seq_length=seq_length, cu_seqlens=cu_seqlens
-        #     )
+        else:
+            q, k, v = qkv.split(
+                [self.local_q_size, self.local_kv_size, self.local_kv_size], dim=-1
+            )  # [batch_size*seq_length, q_size], [batch_size*seq_length, kv_size]
+            q = q.view(-1, self.local_num_heads, self.head_dim)  # [b*s, num_heads, head_dim]
+            k = k.view(-1, self.local_num_kv_heads, self.head_dim)  # [b*s, num_kv_heads, head_dim]
+            v = v.view(-1, self.local_num_kv_heads, self.head_dim)  # [b*s, num_kv_heads, head_dim]
+            if self.config.no_rope_layer is None or (self.layer_idx + 1) % self.config.no_rope_layer != 0:
+                rotary_pos_emb = self.rotary_emb(
+                    position_ids=position_ids if not self.simple_causal_mask else None, seq_length=seq_length
+                )  # [b*s, dim] or [seq_length, dim]
+                q = self.rotary_emb.apply_rotary_pos_emb(
+                    q, rotary_pos_emb, seq_length=seq_length
+                )  # [b*s, num_heads, head_dim]
+                k = self.rotary_emb.apply_rotary_pos_emb(
+                    k, rotary_pos_emb, seq_length=seq_length
+                )  # [b*s, num_kv_heads, head_dim]
+            else:
+                log_rank(f"skipping rotary for layer {self.layer_idx + 1}", logger=logger, level=logging.DEBUG, rank=0)
+            attn_output = self.attention(
+                q, k, v, position_ids=position_ids, seq_length=seq_length, cu_seqlens=cu_seqlens
+            )
         output = self.o_proj(attn_output)
         # Return original position_ids shape
         return {"hidden_states": output, "position_ids": position_ids.view(-1, seq_length)}

src/nanotron/nn/rotary.py

Lines changed: 3 additions & 0 deletions
@@ -141,6 +141,9 @@ def apply_rotary_pos_emb(self, tensor, freqs, multi_latent_attention=False, msca
             self.rotate_half(rotary_part) * self.sin_values.unsqueeze(1)
         )
 
+        # Reshape back to [b*s, nheads, dim]
+        rotated_tensor = rotated_tensor.view(-1, rotated_tensor.shape[2], rotated_tensor.shape[3])
+
         # Concatenate with the pass-through part (if any)
         if pass_through_part is not None and pass_through_part.shape[-1] > 0:
             return torch.cat((rotated_tensor, pass_through_part), dim=-1)
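
A tiny illustration of what the added reshape does; the shapes here are arbitrary, whereas in the real code the 4-D tensor comes out of the cos/sin rotation just above.

# Illustration of the added reshape: collapse [b, s, nheads, dim] -> [b*s, nheads, dim].
import torch

b, s, nheads, dim = 2, 8, 4, 16
rotated_tensor = torch.randn(b, s, nheads, dim)

# shape[2] is nheads and shape[3] is the head dim; -1 folds batch and sequence together.
rotated_tensor = rotated_tensor.view(-1, rotated_tensor.shape[2], rotated_tensor.shape[3])
print(rotated_tensor.shape)  # torch.Size([16, 4, 16])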
