Skip to content

Commit 5ce7ff5

Browse files
authored
Implement setup_training_environment (#36)
* Implement ```setup_training_environment``` * Make style
1 parent 9bbe84d commit 5ce7ff5

File tree

2 files changed

+23
-19
lines changed

2 files changed

+23
-19
lines changed

trainer/trainer.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,6 @@
4747

4848
logger = logging.getLogger("trainer")
4949

50-
if platform.system() != "Windows":
51-
multiprocessing.set_start_method("fork")
52-
# https://github.com/pytorch/pytorch/issues/973
53-
import resource
54-
55-
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
56-
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
57-
58-
5950
if is_apex_available():
6051
from apex import amp
6152

@@ -391,15 +382,8 @@ def __init__( # pylint: disable=dangerous-default-value
391382
log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
392383
self._setup_logger_config(log_file)
393384

394-
# set and initialize Pytorch runtime
395-
self.use_cuda, self.num_gpus = setup_torch_training_env(
396-
cudnn_enable=config.cudnn_enable,
397-
cudnn_deterministic=config.cudnn_deterministic,
398-
cudnn_benchmark=config.cudnn_benchmark,
399-
use_ddp=args.use_ddp,
400-
training_seed=config.training_seed,
401-
gpu=gpu if args.gpu is None else args.gpu,
402-
)
385+
# setup training environment
386+
self.use_cuda, self.num_gpus = self.setup_training_environment(args=args, config=config, gpu=gpu)
403387

404388
# init loggers
405389
self.dashboard_logger, self.c_logger = self.init_loggers(
@@ -600,6 +584,27 @@ def init_training(
600584
new_fields["github_branch"] = get_git_branch()
601585
return config, new_fields
602586

587+
@staticmethod
588+
def setup_training_environment(args, config, gpu):
589+
if platform.system() != "Windows":
590+
multiprocessing.set_start_method("fork")
591+
# https://github.com/pytorch/pytorch/issues/973
592+
import resource # pylint: disable=import-outside-toplevel
593+
594+
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
595+
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
596+
597+
# set and initialize Pytorch runtime
598+
use_cuda, num_gpus = setup_torch_training_env(
599+
cudnn_enable=config.cudnn_enable,
600+
cudnn_deterministic=config.cudnn_deterministic,
601+
cudnn_benchmark=config.cudnn_benchmark,
602+
use_ddp=args.use_ddp,
603+
training_seed=config.training_seed,
604+
gpu=gpu if args.gpu is None else args.gpu,
605+
)
606+
return use_cuda, num_gpus
607+
603608
@staticmethod
604609
def run_get_model(config: Coqpit, get_model: Callable) -> nn.Module:
605610
"""Run the `get_model` function and return the model.

trainer/trainer_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ def setup_torch_training_env(
5252
Returns:
5353
Tuple[bool, int]: is cuda on or off and number of GPUs in the environment.
5454
"""
55-
5655
# clear cache before training
5756
torch.cuda.empty_cache()
5857

0 commit comments

Comments
 (0)