@@ -91,13 +91,12 @@ def objective(trial):
9191
9292 trial_number = RetryFailedTrialCallback .retried_trial_number (trial )
9393
94- if trial_number is not None :
95- study = trial .study
96- artifact_id = study .trials [trial_number ].user_attrs ["artifact_id" ]
94+ artifact_id = trial .study .trials [trial_number ].user_attrs .get ("artifact_id" )
95+ if trial_number is not None and artifact_id is not None :
9796 download_artifact (
98- artifact_store = artifact_store , file_path = "./tmp_model .pt" , artifact_id = artifact_id
97+ artifact_store = artifact_store , file_path = f "./tmp_model_ { trial_number } .pt" , artifact_id = artifact_id
9998 )
100- checkpoint = torch .load ("./tmp_model .pt" )
99+ checkpoint = torch .load (f "./tmp_model_ { trial_number } .pt" )
101100 epoch = checkpoint ["epoch" ]
102101 epoch_begin = epoch + 1
103102
@@ -158,15 +157,15 @@ def objective(trial):
158157 "optimizer_state_dict" : optimizer .state_dict (),
159158 "accuracy" : accuracy ,
160159 },
161- "./tmp_model .pt" ,
160+ f "./tmp_model_ { trial_number } .pt" ,
162161 )
163162 artifact_id = upload_artifact (
164163 artifact_store = artifact_store ,
165- file_path = "./tmp_model .pt" ,
164+ file_path = f "./tmp_model_ { trial_number } .pt" ,
166165 study_or_trial = trial ,
167166 )
168167 trial .set_user_attr ("artifact_id" , artifact_id )
169- os .remove ("./tmp_model .pt" )
168+ os .remove (f "./tmp_model_ { trial_number } .pt" )
170169
171170 # Handle pruning based on the intermediate value.
172171 if trial .should_prune ():
0 commit comments