diff --git a/allenact/algorithms/onpolicy_sync/engine.py b/allenact/algorithms/onpolicy_sync/engine.py
index d032d190..3f98a913 100644
--- a/allenact/algorithms/onpolicy_sync/engine.py
+++ b/allenact/algorithms/onpolicy_sync/engine.py
@@ -46,6 +46,7 @@ from allenact.algorithms.onpolicy_sync.vector_sampled_tasks import (
     COMPLETE_TASK_CALLBACK_KEY,
     COMPLETE_TASK_METRICS_KEY,
+    COMPLETE_TASK_TIMEOUT_CORRECTION_KEY,
     SingleProcessVectorSampledTasks,
     VectorSampledTasks,
 )
 
@@ -681,7 +682,7 @@
         )
 
         # Save after task completion metrics
-        for step_result in outputs:
+        for index, step_result in enumerate(outputs):
             if step_result.info is not None:
                 if COMPLETE_TASK_METRICS_KEY in step_result.info:
                     self.single_process_metrics.append(
@@ -693,6 +694,10 @@
                         step_result.info[COMPLETE_TASK_CALLBACK_KEY]
                     )
                     del step_result.info[COMPLETE_TASK_CALLBACK_KEY]
+                if COMPLETE_TASK_TIMEOUT_CORRECTION_KEY in step_result.info:
+                    flat_actions[0, index, 0] = torch.tensor(
+                        step_result.info[COMPLETE_TASK_TIMEOUT_CORRECTION_KEY]
+                    )
 
         rewards: Union[List, torch.Tensor]
         observations, rewards, dones, infos = [list(x) for x in zip(*outputs)]
@@ -1059,7 +1064,11 @@ def single_batch_generator(streaming_storage: StreamingStorageMixin):
                 if training:
                     aggregate_bsize = self.distributed_weighted_sum(bsize, 1)
                     to_track["global_batch_size"] = aggregate_bsize
-                    to_track["lr"] = self.optimizer.param_groups[0]["lr"]
+                    if len(self.optimizer.param_groups) >= 2:
+                        for i, param_group in enumerate(self.optimizer.param_groups):
+                            to_track[f"lr_group_{i}"] = param_group["lr"]
+                    else:
+                        to_track["lr"] = self.optimizer.param_groups[0]["lr"]
 
                 if training_settings.num_mini_batch is not None:
                     to_track["rollout_num_mini_batch"] = (
@@ -1217,9 +1226,13 @@ def __init__(
                     " feature and we'll be happy to review it."
                 )
 
+        if not hasattr(self.actor_critic, "set_learning_rate_for_specific_parameters"):
+            params = [p for p in self.actor_critic.parameters() if p.requires_grad]
+        else:
+            params = self.actor_critic.set_learning_rate_for_specific_parameters()
         self.optimizer: optim.optimizer.Optimizer = (
             self.training_pipeline.optimizer_builder(
-                params=[p for p in self.actor_critic.parameters() if p.requires_grad]
+                params=params
             )
         )
 
diff --git a/allenact/algorithms/onpolicy_sync/vector_sampled_tasks.py b/allenact/algorithms/onpolicy_sync/vector_sampled_tasks.py
index 0b17e28f..21ecc70d 100644
--- a/allenact/algorithms/onpolicy_sync/vector_sampled_tasks.py
+++ b/allenact/algorithms/onpolicy_sync/vector_sampled_tasks.py
@@ -48,6 +48,7 @@ DEFAULT_MP_CONTEXT_TYPE = "forkserver"
 
 COMPLETE_TASK_METRICS_KEY = "__AFTER_TASK_METRICS__"
 COMPLETE_TASK_CALLBACK_KEY = "__AFTER_TASK_CALLBACK__"
+COMPLETE_TASK_TIMEOUT_CORRECTION_KEY = "__AFTER_TASK_TIMEOUT_CORRECTION__"
 
 STEP_COMMAND = "step"
 NEXT_TASK_COMMAND = "next_task"
diff --git a/allenact/main.py b/allenact/main.py
index 8fcce3fe..827d5205 100755
--- a/allenact/main.py
+++ b/allenact/main.py
@@ -494,7 +494,7 @@ def main():
             collect_valid_results=args.collect_valid_results,
             valid_on_initial_weights=args.valid_on_initial_weights,
             try_restart_after_task_error=args.enable_crash_recovery,
-            save_ckpt_at_every_host=save_ckpt_at_every_host,
+            save_ckpt_at_every_host=args.save_ckpt_at_every_host,
         )
     else:
         OnPolicyRunner(
diff --git a/allenact/utils/experiment_utils.py b/allenact/utils/experiment_utils.py
index f123e87a..a35b585c 100644
--- a/allenact/utils/experiment_utils.py
+++ b/allenact/utils/experiment_utils.py
@@ -1214,9 +1214,16 @@ def download_checkpoint_from_wandb(
     else:
         assert len(ckpt_steps) == 1
         step = ckpt_steps[0]
-        ckpt_fn = "{}-step-{}:latest".format(run_token, step)
-        artifact = api.artifact(ckpt_fn)
-        _ = artifact.download(all_ckpt_dir)
-        ckpt_dir = "{}/ckpt-{}.pt".format(all_ckpt_dir, step)
-        shutil.move("{}/ckpt.pt".format(all_ckpt_dir), ckpt_dir)
+        try:
+            ckpt_fn = "{}-step-{}:latest".format(run_token, step)
+            artifact = api.artifact(ckpt_fn)
+            _ = artifact.download(all_ckpt_dir)
+            ckpt_dir = "{}/ckpt-{}.pt".format(all_ckpt_dir, step)
+            shutil.move("{}/ckpt.pt".format(all_ckpt_dir), ckpt_dir)
+        except Exception:
+            ckpt_fn = "{}-{}:latest".format(run_token, step)
+            artifact = api.artifact(ckpt_fn)
+            _ = artifact.download(all_ckpt_dir)
+            ckpt_dir = "{}/ckpt-{}.pt".format(all_ckpt_dir, step)
+            shutil.move("{}/model.ckpt".format(all_ckpt_dir), ckpt_dir)
         return ckpt_dir
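
A note on the `COMPLETE_TASK_TIMEOUT_CORRECTION_KEY` plumbing: the `engine.py` hunk above only consumes the key, overwriting the action recorded in `flat_actions[0, index, 0]` for the sampler that reported it; the side that produces the key is not part of this diff. Below is a minimal sketch of what a producer could look like, assuming a discrete action space in which a timed-out episode should be stored as if the agent had issued a STOP action. `STOP_ACTION_INDEX` and `build_final_step_info` are hypothetical names, not part of the patch:

```python
from allenact.algorithms.onpolicy_sync.vector_sampled_tasks import (
    COMPLETE_TASK_METRICS_KEY,
    COMPLETE_TASK_TIMEOUT_CORRECTION_KEY,
)

STOP_ACTION_INDEX = 0  # hypothetical index of the STOP/end action


def build_final_step_info(metrics: dict, timed_out: bool) -> dict:
    # Info dict attached to the final step of an episode. When the episode
    # ended because of a timeout rather than the agent's own action, also
    # report the action index that the engine should record in place of
    # the one the policy actually sampled.
    info = {COMPLETE_TASK_METRICS_KEY: metrics}
    if timed_out:
        info[COMPLETE_TASK_TIMEOUT_CORRECTION_KEY] = STOP_ACTION_INDEX
    return info
```

Since the engine writes the correction into `flat_actions[0, index, 0]` (the single step slice, sampler `index`, first action dimension), the mechanism as written appears limited to single-agent, flat discrete action spaces.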
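
The optimizer change in `engine.py` checks for an optional `set_learning_rate_for_specific_parameters` hook on the actor-critic and, when present, passes its return value to `optimizer_builder(params=...)` in place of the flat parameter list. The hook itself is never defined in this diff; for it to interact correctly with the new `lr_group_{i}` logging, it should return PyTorch-style parameter groups. A minimal sketch of such a hook, assuming a model with a `visual_encoder` submodule; the module structure and learning rates are illustrative only:

```python
import torch.nn as nn


class MyActorCritic(nn.Module):
    """Hypothetical model implementing the optional hook checked for in
    the engine's __init__ above; names and rates are examples."""

    def __init__(self):
        super().__init__()
        self.visual_encoder = nn.Sequential(nn.Conv2d(3, 32, 3), nn.ReLU())
        self.policy_head = nn.Linear(32, 6)

    def set_learning_rate_for_specific_parameters(self):
        # Return PyTorch-style parameter groups so that optimizer_builder
        # can assign a distinct learning rate to each group.
        encoder_params = [
            p for p in self.visual_encoder.parameters() if p.requires_grad
        ]
        encoder_ids = {id(p) for p in encoder_params}
        other_params = [
            p
            for p in self.parameters()
            if p.requires_grad and id(p) not in encoder_ids
        ]
        return [
            {"params": encoder_params, "lr": 1e-5},  # fine-tune the encoder slowly
            {"params": other_params},  # use the optimizer's default lr
        ]
```

With two or more groups (e.g. `optim.Adam(model.set_learning_rate_for_specific_parameters(), lr=3e-4)`), the logging hunk above then records `lr_group_0`, `lr_group_1`, ... instead of a single `lr`.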