|
47 | 47 |
|
48 | 48 | logger = logging.getLogger("trainer") |
49 | 49 |
|
50 | | -if platform.system() != "Windows": |
51 | | - multiprocessing.set_start_method("fork") |
52 | | - # https://github.com/pytorch/pytorch/issues/973 |
53 | | - import resource |
54 | | - |
55 | | - rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) |
56 | | - resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) |
57 | | - |
58 | | - |
59 | 50 | if is_apex_available(): |
60 | 51 | from apex import amp |
61 | 52 |
|
@@ -391,15 +382,8 @@ def __init__( # pylint: disable=dangerous-default-value |
391 | 382 | log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt") |
392 | 383 | self._setup_logger_config(log_file) |
393 | 384 |
|
394 | | - # set and initialize Pytorch runtime |
395 | | - self.use_cuda, self.num_gpus = setup_torch_training_env( |
396 | | - cudnn_enable=config.cudnn_enable, |
397 | | - cudnn_deterministic=config.cudnn_deterministic, |
398 | | - cudnn_benchmark=config.cudnn_benchmark, |
399 | | - use_ddp=args.use_ddp, |
400 | | - training_seed=config.training_seed, |
401 | | - gpu=gpu if args.gpu is None else args.gpu, |
402 | | - ) |
| 385 | + # setup training environment |
| 386 | + self.use_cuda, self.num_gpus = self.setup_training_environment(args=args, config=config, gpu=gpu) |
403 | 387 |
|
404 | 388 | # init loggers |
405 | 389 | self.dashboard_logger, self.c_logger = self.init_loggers( |
@@ -600,6 +584,27 @@ def init_training( |
600 | 584 | new_fields["github_branch"] = get_git_branch() |
601 | 585 | return config, new_fields |
602 | 586 |
|
| 587 | + @staticmethod |
| 588 | + def setup_training_environment(args, config, gpu): |
| 589 | + if platform.system() != "Windows": |
| 590 | + multiprocessing.set_start_method("fork") |
| 591 | + # https://github.com/pytorch/pytorch/issues/973 |
| 592 | + import resource # pylint: disable=import-outside-toplevel |
| 593 | + |
| 594 | + rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) |
| 595 | + resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) |
| 596 | + |
| 597 | + # set and initialize Pytorch runtime |
| 598 | + use_cuda, num_gpus = setup_torch_training_env( |
| 599 | + cudnn_enable=config.cudnn_enable, |
| 600 | + cudnn_deterministic=config.cudnn_deterministic, |
| 601 | + cudnn_benchmark=config.cudnn_benchmark, |
| 602 | + use_ddp=args.use_ddp, |
| 603 | + training_seed=config.training_seed, |
| 604 | + gpu=gpu if args.gpu is None else args.gpu, |
| 605 | + ) |
| 606 | + return use_cuda, num_gpus |
| 607 | + |
603 | 608 | @staticmethod |
604 | 609 | def run_get_model(config: Coqpit, get_model: Callable) -> nn.Module: |
605 | 610 | """Run the `get_model` function and return the model. |
|
0 commit comments