Fix PI0 embed_tokens.weight loading + WandB list metric logging#3116

Open
he-yufeng wants to merge 1 commit into huggingface:main from he-yufeng:fix/pi0-embed-tokens-wandb-list

Conversation

@he-yufeng
What this PR does

Fixes two issues reported in #3109:

  1. PI0 state dict `embed_tokens.weight` missing after loading: `_fix_pytorch_state_dict_keys` only creates this key from `lm_head.weight` (weight tying), but some checkpoints store `embed_tokens.weight` directly, so the key was dropped during remapping. Added explicit handling so a directly stored `embed_tokens.weight` is preserved.

  2. WandB silently drops `loss_per_dim`: the logging wrapper only handles scalar metrics. List/tuple metrics like `loss_per_dim` are now expanded into per-index entries (`loss_per_dim_0`, `loss_per_dim_1`, ...) so each component is logged.
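The state dict handling in fix 1 can be sketched as follows. This is a minimal illustration, not LeRobot's actual `_fix_pytorch_state_dict_keys` implementation; the function name `fix_state_dict_keys` and the bare key names are assumptions for the example.

```python
def fix_state_dict_keys(state_dict):
    """Ensure embed_tokens.weight is present after remapping.

    Handles both checkpoint layouts described in the PR:
    - embed_tokens.weight stored directly: keep it as-is.
    - only lm_head.weight stored (weight tying): derive the
      embedding matrix from the tied lm_head weights.
    """
    embed_key = "embed_tokens.weight"  # illustrative key name
    lm_head_key = "lm_head.weight"     # illustrative key name
    if embed_key not in state_dict and lm_head_key in state_dict:
        # Weight tying: the embedding matrix and lm_head share weights,
        # so the missing key can be filled from lm_head.weight.
        state_dict[embed_key] = state_dict[lm_head_key]
    return state_dict
```

Without the explicit pass-through branch, a checkpoint that stores `embed_tokens.weight` directly would lose the key, and `load_state_dict` would report it as missing.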

Fixes #3109
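The list-metric expansion from fix 2 amounts to flattening sequence-valued metrics into scalar entries before they reach `wandb.log`. A minimal sketch (the helper name `expand_list_metrics` is illustrative, not the wrapper's actual API):

```python
def expand_list_metrics(metrics):
    """Expand list/tuple metric values into per-index scalar entries.

    E.g. {"loss_per_dim": [0.1, 0.2]} becomes
    {"loss_per_dim_0": 0.1, "loss_per_dim_1": 0.2}, so WandB logs
    each component instead of silently dropping the whole key.
    """
    expanded = {}
    for key, value in metrics.items():
        if isinstance(value, (list, tuple)):
            for i, v in enumerate(value):
                expanded[f"{key}_{i}"] = v
        else:
            expanded[key] = value
    return expanded
```

Keeping the per-index suffix on the original key name means the components group together naturally in the WandB dashboard.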

@github-actions github-actions bot added the policies Items related to robot policies label Mar 9, 2026

Development

Successfully merging this pull request may close these issues:

PI0 WARNING: Vision embedding key might need handling && WandB logging of key "loss_per_dim" was ignored
