diff --git a/ray_lightning/launchers/ray_launcher.py b/ray_lightning/launchers/ray_launcher.py
index 8802e59..a452162 100644
--- a/ray_lightning/launchers/ray_launcher.py
+++ b/ray_lightning/launchers/ray_launcher.py
@@ -298,16 +298,20 @@ def _wrapping_function(
         trainer.strategy.local_rank = self._strategy.local_rank
         set_cuda_device_if_used(trainer.strategy)
 
+        # Set operations to deterministic in this worker when required
+        if trainer._accelerator_connector.deterministic:
+            trainer._accelerator_connector._init_deterministic(True)
+
         results = function(*args, **kwargs)
 
         if trainer is not None:
-            return self._collect_rank_zero_results(trainer, results)
-        else:
-            return None
+            results = self._collect_rank_zero_results(trainer, results)
+
+        if results is None:
+            trainer._teardown()
+            trainer._call_teardown_hook()
 
-        trainer._teardown()
-        trainer._call_teardown_hook()
-        return None
+        return results
 
     def _collect_rank_zero_results(self, trainer: "pl.Trainer",
                                    results: Any) -> Optional["_RayOutput"]:
@@ -316,18 +320,20 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer",
         This function is run on the worker process.
         """
         rank_zero_debug("Finalizing the Ray launcher environment.")
+        if trainer.strategy.global_rank != 0:
+            return None
+
+        if trainer.strategy.local_rank != 0:
+            return None
+
         checkpoint_callback = trainer.checkpoint_callback
         best_model_path = checkpoint_callback.best_model_path \
             if checkpoint_callback else None
 
         state_dict = trainer.lightning_module.state_dict()
 
-        if self._strategy.global_rank != 0:
-            return None
-
         # Move state_dict to cpu before converting it to model state stream
-        if trainer.strategy.local_rank == 0:
-            state_dict = move_data_to_device(state_dict, "cpu")
+        state_dict = move_data_to_device(state_dict, "cpu")
 
         # PyTorch Lightning saves the model weights in a temp file and
         # loads it back on the driver.
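
For context, a minimal sketch of the control-flow fix in _wrapping_function: before the patch, the teardown calls were unreachable because every branch returned first, so non-rank-zero workers never tore down; after the patch, any worker whose _collect_rank_zero_results returns None runs teardown before returning. This is not the actual ray_lightning code; FakeTrainer, collect_rank_zero_results, and wrapping_function below are hypothetical stand-ins used only to demonstrate the patched flow.

from typing import Any, Optional


class FakeTrainer:
    """Hypothetical stand-in for pl.Trainer, for illustration only."""

    def __init__(self, global_rank: int):
        self.global_rank = global_rank
        self.torn_down = False

    def _teardown(self) -> None:
        self.torn_down = True

    def _call_teardown_hook(self) -> None:
        pass


def collect_rank_zero_results(trainer: FakeTrainer,
                              results: Any) -> Optional[Any]:
    # Mirrors the patched guard: only rank zero keeps its results.
    if trainer.global_rank != 0:
        return None
    return results


def wrapping_function(trainer: FakeTrainer, results: Any) -> Optional[Any]:
    # Patched flow: collect results first, then tear down on workers
    # that have nothing to return, and only then return.
    results = collect_rank_zero_results(trainer, results)
    if results is None:
        trainer._teardown()
        trainer._call_teardown_hook()
    return results


if __name__ == "__main__":
    rank0, rank1 = FakeTrainer(0), FakeTrainer(1)
    # Rank zero returns its results and skips teardown here.
    assert wrapping_function(rank0, "output") == "output"
    assert not rank0.torn_down
    # Non-zero ranks now actually reach the teardown calls.
    assert wrapping_function(rank1, "output") is None
    assert rank1.torn_down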