Bug description
When using torchtitan with HuggingFace checkpoints for both input and output, the saved HF model shows inconsistent, incorrect behaviour.
If you provide an HF model with a missing `model.safetensors.index.json`, you get a fair warning:

```
WARNING - model.safetensors.index.json not found at hf_assets_path: ./tests/assets/tokenizer/model.safetensors.index.json. Defaulting to saving a single safetensors file if checkpoint is saved in HF format
```

and the saved final HF checkpoint behaves as expected: it saves a single HF safetensors file (no matter whether there were split safetensors files before) plus an index.json.
However, if `model.safetensors.index.json` is present when loading the initial HF model, the saved final HF checkpoint behaves incorrectly: it keeps the safetensors splits for the output (based on the input model), but `model.safetensors.index.json` is completely missing.
See the output here for a debug run with a single safetensors file for the input and output model:

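Independent of that log, the missing index file can be confirmed by listing the final checkpoint folder. A minimal check (the path is a placeholder for the actual output folder of the run):

```python
import os

# Placeholder: point this at the final HF checkpoint folder of the run.
ckpt_dir = "./outputs/checkpoint/step-10"

files = sorted(os.listdir(ckpt_dir))
print("\n".join(files))

# Expected for an HF checkpoint: one or more model*.safetensors files plus
# model.safetensors.index.json. In the buggy case the index file is absent.
print("index present:", "model.safetensors.index.json" in files)
```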
I think the issue can be found in this if condition:
The first branch skips the step of saving `model.safetensors.index.json`, and the call to the method `consolidate_safetensors_files_on_every_rank` does not save it either.
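To make the suspicion concrete, here is a minimal sketch of the control flow as I understand it. This is not the actual torchtitan source; every name except `consolidate_safetensors_files_on_every_rank` is an illustrative stand-in:

```python
# Hypothetical sketch, not the real torchtitan code; it only illustrates how
# the "index was present on load" branch can end up never writing the index.

def write_single_safetensors_file(output_dir):
    print(f"writing {output_dir}/model.safetensors")

def write_index_json(output_dir):
    print(f"writing {output_dir}/model.safetensors.index.json")

def consolidate_safetensors_files_on_every_rank(output_dir):
    # In this sketch the split layout from the input model is reused, but
    # model.safetensors.index.json is never written -- the suspected gap.
    print(f"writing {output_dir}/model-0000N-of-0000M.safetensors splits")

def save_final_checkpoint_in_hf(output_dir, input_had_index):
    if not input_had_index:
        # Branch for a missing index on load: single file + index.json,
        # which matches the observed (correct) behaviour.
        write_single_safetensors_file(output_dir)
        write_index_json(output_dir)
    else:
        # Branch for a present index on load: splits are kept, but no call
        # on this path writes the index file.
        consolidate_safetensors_files_on_every_rank(output_dir)

save_final_checkpoint_in_hf("./outputs/hf", input_had_index=True)
```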

Can you have a look into that? Thank you! :)
Versions
- Tested on latest PyTorch nightly: 2.10.0.dev20251210+cu128
- Used the default debug model TOML `torchtitan/models/llama3/train_configs/debug_model.toml` with checkpoint enabled, and an adapted one where I pointed to a locally downloaded model with tokenizer and `model.safetensors.index.json`; see the config below and the sanity check after this list:
```toml
[model]
hf_assets_path = "<local-path>/llama3_debugmodel"

[checkpoint]
enable = true
folder = "checkpoint"
interval = 10
export_dtype = "float32"
async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]
last_save_model_only = true
last_save_in_hf = true
initial_load_path = "<local-path>/llama3_debugmodel"
initial_load_in_hf = true
```
- Tested the issue on torchtitan `run_train.sh` with no modifications
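And the sanity check referenced above, confirming the input folder actually ships an index file so the "index present" branch is taken (the path placeholder matches the config):

```python
import os

# Placeholder, matching hf_assets_path / initial_load_path in the config above.
model_dir = "<local-path>/llama3_debugmodel"

# The incorrect behaviour only reproduces when the input model has an index.
print("input index present:",
      os.path.exists(os.path.join(model_dir, "model.safetensors.index.json")))
```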