@@ -89,20 +89,25 @@ def objective(trial):
8989 lr = trial .suggest_float ("lr" , 1e-5 , 1e-1 , log = True )
9090 optimizer = getattr (optim , optimizer_name )(model .parameters (), lr = lr )
9191
92- trial_number = RetryFailedTrialCallback .retried_trial_number (trial )
93-
94- artifact_id = trial_number and trial .study .trials [trial_number ].user_attrs .get ("artifact_id" )
95- if trial_number is not None and artifact_id is not None :
92+ artifact_id = None
93+ retry_history = RetryFailedTrialCallback .retry_history (trial )
94+ for trial_number in reversed (retry_history ):
95+ artifact_id = trial .study .trials [trial_number ].user_attrs .get ("artifact_id" )
96+ if artifact_id is not None :
97+ retry_trial_number = trial_number
98+ break
99+
100+ if artifact_id is not None :
96101 download_artifact (
97102 artifact_store = artifact_store ,
98- file_path = f"./tmp_model_{ trial_number } .pt" ,
103+ file_path = f"./tmp_model_{ trial . number } .pt" ,
99104 artifact_id = artifact_id ,
100105 )
101- checkpoint = torch .load (f"./tmp_model_{ trial_number } .pt" )
106+ checkpoint = torch .load (f"./tmp_model_{ trial . number } .pt" )
102107 epoch = checkpoint ["epoch" ]
103108 epoch_begin = epoch + 1
104109
105- print (f"Loading a checkpoint from trial { trial_number } in epoch { epoch } ." )
110+ print (f"Loading a checkpoint from trial { retry_trial_number } in epoch { epoch } ." )
106111
107112 model .load_state_dict (checkpoint ["model_state_dict" ])
108113 optimizer .load_state_dict (checkpoint ["optimizer_state_dict" ])
@@ -159,15 +164,15 @@ def objective(trial):
159164 "optimizer_state_dict" : optimizer .state_dict (),
160165 "accuracy" : accuracy ,
161166 },
162- f"./tmp_model_{ trial_number } .pt" ,
167+ f"./tmp_model_{ trial . number } .pt" ,
163168 )
164169 artifact_id = upload_artifact (
165170 artifact_store = artifact_store ,
166- file_path = f"./tmp_model_{ trial_number } .pt" ,
171+ file_path = f"./tmp_model_{ trial . number } .pt" ,
167172 study_or_trial = trial ,
168173 )
169174 trial .set_user_attr ("artifact_id" , artifact_id )
170- os .remove (f"./tmp_model_{ trial_number } .pt" )
175+ os .remove (f"./tmp_model_{ trial . number } .pt" )
171176
172177 # Handle pruning based on the intermediate value.
173178 if trial .should_prune ():
0 commit comments