Merged
Changes from 4 commits
megatron/model/norms.py (15 changes: 5 additions & 10 deletions)
@@ -85,18 +85,13 @@ def __init__(self, dim, p=-1.0, eps=1e-8, bias=False):

     def forward(self, x):
         dtype = x.dtype
-        if self.p < 0.0 or self.p > 1.0:
-            norm_x = x.norm(2, dim=-1, keepdim=True)
-            d_x = self.d
-        else:
+        if self.p >= 0.0 and self.p <= 1.0:
             partial_size = int(self.d * self.p)
-            partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1)
-
-            norm_x = partial_x.norm(2, dim=-1, keepdim=True)
-            d_x = partial_size
+            x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1)

-        rms_x = norm_x * d_x ** (-1.0 / 2)
-        x_normed = x / (rms_x + self.eps)
+        x = x.to(torch.float32)
+        variance = x.pow(2).mean(-1, keepdim=True)
+        x_normed = x * torch.rsqrt(variance + self.eps)

         if self.bias:
             return self.scale * x_normed + self.offset
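The hunk above replaces the 2-norm formulation of RMSNorm with a float32 variance/rsqrt formulation. A minimal sketch (not part of the PR; the sizes and epsilon are illustrative) showing that the two formulations compute the same normalization apart from where epsilon enters, while the new path does its reduction in float32:

```python
import torch

torch.manual_seed(0)
d, eps = 64, 1e-8
x = torch.randn(4, d)

# Old path: L2 norm scaled by d ** (-1/2), then divide by the rms.
norm_x = x.norm(2, dim=-1, keepdim=True)
rms_x = norm_x * d ** (-1.0 / 2)
old = x / (rms_x + eps)

# New path (as in the diff): upcast, mean of squares, multiply by rsqrt.
x32 = x.to(torch.float32)
variance = x32.pow(2).mean(-1, keepdim=True)
new = x32 * torch.rsqrt(variance + eps)

print(torch.allclose(old, new, atol=1e-6))  # True: same normalization, eps placement aside
```

The practical difference shows up in half-precision runs: the old path reduced in the input dtype, while the new path accumulates the mean of squares in float32; the `dtype = x.dtype` context line suggests the result is cast back to the input dtype further down in the function.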
tools/ckpts/convert_hf_llama_to_neox.py (11 changes: 3 additions & 8 deletions)
@@ -51,24 +51,19 @@ def convert_model(hf_state_dict, hf_config, tp_ranks):
         q = hf_state_dict[f"model.layers.{layer_num}.self_attn.q_proj.weight"]
         k = hf_state_dict[f"model.layers.{layer_num}.self_attn.k_proj.weight"]
         v = hf_state_dict[f"model.layers.{layer_num}.self_attn.v_proj.weight"]
-        # The GQA code splits the heads by the num_q_heads so we also do that
-        # here to ensure it matches...
-        q = q.view(num_q_heads, -1, q.shape[-1])
-        k = k.view(num_q_heads, -1, q.shape[-1])
-        v = v.view(num_q_heads, -1, q.shape[-1])

         # Chunk for tensor parallelism...
         for i, q_chunk, k_chunk, v_chunk in zip(
             range(tp_ranks),
             torch.chunk(q, tp_ranks, dim=0),
             torch.chunk(k, tp_ranks, dim=0),
             torch.chunk(v, tp_ranks, dim=0),
         ):
-            # Need to join the heads across q, k, v...
+            # The GQA code simply expects concatenated q,k,v weights for each tp partition
             conv_state_dicts[i][
                 f"sequential.{layer_num+2}.attention.query_key_value.weight"
             ] = (
-                torch.cat([q_chunk, k_chunk, v_chunk], dim=1)
-                .view(-1, q.shape[-1])
+                torch.cat([q_chunk, k_chunk, v_chunk], dim=0)
                 .clone()
                 .detach()
             )
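A toy-shape walkthrough (hypothetical sizes, not taken from the script) of the chunking above: splitting q, k, and v along dim 0 and concatenating per rank gives each tensor-parallel partition one [(hidden_size + 2 * kv_hidden_size) / tp_ranks, hidden_size] block, which is the layout the new comment says the GQA attention code expects:

```python
import torch

# Hypothetical GQA sizes, chosen only to make the shapes easy to follow.
hidden_size, num_q_heads, num_kv_heads, tp_ranks = 16, 8, 4, 2
head_dim = hidden_size // num_q_heads
kv_hidden_size = head_dim * num_kv_heads

q = torch.randn(hidden_size, hidden_size)     # q_proj.weight: [16, 16]
k = torch.randn(kv_hidden_size, hidden_size)  # k_proj.weight: [8, 16]
v = torch.randn(kv_hidden_size, hidden_size)  # v_proj.weight: [8, 16]

for rank, (q_chunk, k_chunk, v_chunk) in enumerate(
    zip(
        torch.chunk(q, tp_ranks, dim=0),
        torch.chunk(k, tp_ranks, dim=0),
        torch.chunk(v, tp_ranks, dim=0),
    )
):
    # Each rank gets its q rows, then its k rows, then its v rows,
    # stacked along the output dimension: (16 + 2 * 8) / 2 = 16 rows here.
    qkv = torch.cat([q_chunk, k_chunk, v_chunk], dim=0)
    print(rank, qkv.shape)  # torch.Size([16, 16]) on both ranks
```

The removed view/cat(dim=1) path produced a per-query-head interleaving of q, k, and v instead of this concatenated layout.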
tools/ckpts/convert_neox_to_hf.py (86 changes: 22 additions & 64 deletions)
@@ -365,75 +365,33 @@ def reshard_and_split_qkv(
), "Must map QKV to precisely 3 resulting weight matrices."

for key, hf_keys in param_mapping.items():
# we first merge the QKV proj. across TP ranks
sharded_qkv = torch.stack(
# We first merge the QKV proj. across TP ranks
tp_sharded_qkv = torch.stack(
get_state(loaded_tp_ranks, key, layer_idx, sequential), dim=0
)
# should now have shape [TP_SIZE, (hidden_size + 2 * kv_hidden_size) / TP_SIZE, hidden_size].

sharded_qkv = sharded_qkv.view(
len(loaded_tp_ranks),
hf_config.num_attention_heads // len(loaded_tp_ranks),
int(
hf_config.hidden_size
// hf_config.num_attention_heads
* (
1
+ 2 * hf_config.num_key_value_heads / hf_config.num_attention_heads
)
),
hf_config.hidden_size,
) # is meant to convert to shape [TP_SIZE, NUM_QUERY_HEADS_PER_SHARD, dims_per_head * (1 + 2 * kv-to-q head ratio), hidden_size]

# We should now have shape [TP_SIZE, (hidden_size + 2 * kv_hidden_size) / TP_SIZE, hidden_size].
# At this point, for each TP rank, q, k, and v are concatenated

# Next, we split tp_harded_qkv into q, k, v along dim 1
hidden_size_per_attention_head = hf_config.hidden_size // hf_config.num_attention_heads
kv_hidden_size = int(hidden_size_per_attention_head * hf_config.num_key_value_heads)
tensor_parallel_size = len(loaded_tp_ranks)

q, k, v = torch.split(
sharded_qkv,
tp_sharded_qkv,
[
hf_config.hidden_size // hf_config.num_attention_heads,
int(
(hf_config.num_key_value_heads / hf_config.num_attention_heads)
* hf_config.hidden_size
// hf_config.num_attention_heads
),
int(
(hf_config.num_key_value_heads / hf_config.num_attention_heads)
* hf_config.hidden_size
// hf_config.num_attention_heads
),
hf_config.hidden_size // tensor_parallel_size,
kv_hidden_size // tensor_parallel_size,
kv_hidden_size // tensor_parallel_size,
],
dim=2,
)
# splits along the (dims_per_head * (1 + 2 * kv-to-q head ratio)_ dim to get 3 tensors:
# 1 x [TP_SIZE, NUM_Q_HEADS_PER_SHARD, dims_per_head, hidden_size] and 2 x [TP_SIZE, NUM_Q_HEADS_PER_SHARD, (dims_per_head / kv-to-q head ratio), hidden_size]
# these are the Q, and K, V tensors respectively.

# we have to do additional reshape for each individual tensor now,
# into the expected square (or smaller than square, for K/V tensors) shape
q, k, v = q.squeeze(dim=2), k.squeeze(dim=2), v.squeeze(dim=2)
q = q.view(
hf_config.num_attention_heads,
hf_config.hidden_size // hf_config.num_attention_heads,
hf_config.hidden_size,
).reshape(hf_config.hidden_size, hf_config.hidden_size)
k = k.reshape(
hf_config.num_key_value_heads,
hf_config.hidden_size // hf_config.num_attention_heads,
hf_config.hidden_size,
).reshape(
hf_config.hidden_size
// hf_config.num_attention_heads
* hf_config.num_key_value_heads,
hf_config.hidden_size,
)
v = v.reshape(
hf_config.num_key_value_heads,
hf_config.hidden_size // hf_config.num_attention_heads,
hf_config.hidden_size,
).reshape(
hf_config.hidden_size
// hf_config.num_attention_heads
* hf_config.num_key_value_heads,
hf_config.hidden_size,
)
dim=1,
) # New shapes:
# q-->[TP_SIZE, hidden_size/TP_SIZE, hidden_size]
# k-->[TP_SIZE, kv_hidden_size/TP_SIZE, hidden_size]
# v-->[TP_SIZE, kv_hidden_size/TP_SIZE, hidden_size]

# Finally, we flatten the first two dimensions merging the TP partitions
q, k, v = q.reshape(-1, q.shape[2]), k.reshape(-1, k.shape[2]), v.reshape(-1, k.shape[2])

# return these
state_dict = {}
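The reworked reshard_and_split_qkv above is the inverse operation. With the same hypothetical toy sizes (again only for shape checking, not the script's real configuration), the stack/split/flatten sequence looks like this:

```python
import torch

# Same hypothetical sizes as the sketch above.
hidden_size, num_attention_heads, num_key_value_heads, tp_ranks = 16, 8, 4, 2
head_dim = hidden_size // num_attention_heads
kv_hidden_size = head_dim * num_key_value_heads

# One [(hidden_size + 2 * kv_hidden_size) / tp_ranks, hidden_size] QKV shard per TP rank.
shards = [
    torch.randn((hidden_size + 2 * kv_hidden_size) // tp_ranks, hidden_size)
    for _ in range(tp_ranks)
]

tp_sharded_qkv = torch.stack(shards, dim=0)  # [2, 16, 16]
q, k, v = torch.split(
    tp_sharded_qkv,
    [hidden_size // tp_ranks, kv_hidden_size // tp_ranks, kv_hidden_size // tp_ranks],
    dim=1,
)  # [2, 8, 16], [2, 4, 16], [2, 4, 16]

# Flatten the TP dimension back into full HF-shaped weight matrices.
q = q.reshape(-1, q.shape[2])  # [16, 16] -> q_proj.weight
k = k.reshape(-1, k.shape[2])  # [8, 16]  -> k_proj.weight
v = v.reshape(-1, v.shape[2])  # [8, 16]  -> v_proj.weight
print(q.shape, k.shape, v.shape)
```

Because the stack puts the TP rank in dim 0, the split along dim 1 keeps each rank's q, k, and v rows contiguous, and the final reshape simply concatenates the rank slices in order.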