diff --git a/xtuner/utils/zero_to_any_dtype.py b/xtuner/utils/zero_to_any_dtype.py
index 13eac61f5..e68f4a3be 100644
--- a/xtuner/utils/zero_to_any_dtype.py
+++ b/xtuner/utils/zero_to_any_dtype.py
@@ -116,7 +116,7 @@ def get_model_state_files(checkpoint_dir):
 def parse_model_states(files, dtype=DEFAULT_DTYPE):
     zero_model_states = []
     for file in files:
-        state_dict = torch.load(file, map_location=device)
+        state_dict = torch.load(file, map_location=device, weights_only=False)
 
         if BUFFER_NAMES not in state_dict:
             raise ValueError(f"{file} is not a model state checkpoint")
@@ -169,7 +169,7 @@ def parse_optim_states(files, ds_checkpoint_dir, dtype=DEFAULT_DTYPE):
     total_files = len(files)
     flat_groups = []
     for f in tqdm(files, desc="Load Checkpoints"):
-        state_dict = torch.load(f, map_location=device)
+        state_dict = torch.load(f, map_location=device, weights_only=False)
 
         if ZERO_STAGE not in state_dict[OPTIMIZER_STATE_DICT]:
             raise ValueError(f"{f} is not a zero checkpoint")