We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 3c69c8d, commit 33f57b2 (Copy full SHA for 33f57b2)
1 file changed
src/parallax/server/shard_loader.py
@@ -223,7 +223,12 @@ def load(
223
is_needed = True
224
remapped_key = key.replace("model.", "", 1)
225
if model_shard.is_last_shard and config.get("tie_word_embeddings", False):
226
- shard_weights["lm_head.weight"] = mx.array(f.get_tensor(key))
+ if key == "model.embed_tokens.weight":
227
+ shard_weights["lm_head.weight"] = mx.array(f.get_tensor(key))
228
+ elif key == "model.embed_tokens.scales":
229
+ shard_weights["lm_head.scales"] = mx.array(f.get_tensor(key))
230
+ elif key == "model.embed_tokens.biases":
231
+ shard_weights["lm_head.biases"] = mx.array(f.get_tensor(key))
232
elif model_shard.is_last_shard:
233
if "model.norm" in key:
234
0 commit comments