Skip to content

Commit 33f57b2

Browse files
committed
fix bug when tie_word_embeddings is true
1 parent 3c69c8d commit 33f57b2

1 file changed

Lines changed: 6 additions & 1 deletion

File tree

src/parallax/server/shard_loader.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,12 @@ def load(
223223
is_needed = True
224224
remapped_key = key.replace("model.", "", 1)
225225
if model_shard.is_last_shard and config.get("tie_word_embeddings", False):
226-
shard_weights["lm_head.weight"] = mx.array(f.get_tensor(key))
226+
if key == "model.embed_tokens.weight":
227+
shard_weights["lm_head.weight"] = mx.array(f.get_tensor(key))
228+
elif key == "model.embed_tokens.scales":
229+
shard_weights["lm_head.scales"] = mx.array(f.get_tensor(key))
230+
elif key == "model.embed_tokens.biases":
231+
shard_weights["lm_head.biases"] = mx.array(f.get_tensor(key))
227232
elif model_shard.is_last_shard:
228233
if "model.norm" in key:
229234
is_needed = True

0 commit comments

Comments
 (0)