diff --git a/fastchat/model/apply_delta.py b/fastchat/model/apply_delta.py index ba1c06d48..d22159b1c 100644 --- a/fastchat/model/apply_delta.py +++ b/fastchat/model/apply_delta.py @@ -34,7 +34,7 @@ def split_files(model_path, tmp_path, split_size): part = 0 try: for file_path in tqdm(files): - state_dict = torch.load(file_path) + state_dict = torch.load(file_path, weights_only=True) new_state_dict = {} current_size = 0 @@ -87,19 +87,19 @@ def apply_delta_low_cpu_mem(base_model_path, target_model_path, delta_path): base_files = glob.glob(base_pattern) delta_pattern = os.path.join(tmp_delta_path, "pytorch_model-*.bin") delta_files = glob.glob(delta_pattern) - delta_state_dict = torch.load(delta_files[0]) + delta_state_dict = torch.load(delta_files[0], weights_only=True) print("Applying the delta") weight_map = {} total_size = 0 for i, base_file in tqdm(enumerate(base_files)): - state_dict = torch.load(base_file) + state_dict = torch.load(base_file, weights_only=True) file_name = f"pytorch_model-{i}.bin" for name, param in state_dict.items(): if name not in delta_state_dict: for delta_file in delta_files: - delta_state_dict = torch.load(delta_file) + delta_state_dict = torch.load(delta_file, weights_only=True) gc.collect() if name in delta_state_dict: break