
Commit 18ce42e

Add start with eval option (#84)
* Add start with eval option
* Ping for training run
* Drop p3.6 from CI
* Turn off telemetry on CI
* Bump up to v0.0.18
1 parent 0a23eba commit 18ce42e

File tree: 6 files changed, +33 -9 lines changed


.github/workflows/tests.yml

Lines changed: 4 additions & 1 deletion
@@ -18,7 +18,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: [3.6, 3.7, 3.8, 3.9, "3.10"]
+        python-version: [3.7, 3.8, 3.9, "3.10"]
         experimental: [false]
     steps:
       - uses: actions/checkout@v2
@@ -31,6 +31,9 @@ jobs:
           cache-dependency-path: 'requirements*'
       - name: check OS
        run: cat /etc/os-release
+      - name: Telemetry off
+        run: |
+          export TRAINER_TELEMETRY=0
      - name: Install dependencies
        run: |
          sudo apt-get update
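
The CI change drops Python 3.6 from the test matrix and sets TRAINER_TELEMETRY=0 so test runs do not ping the telemetry endpoint. A minimal sketch of the same opt-out outside CI: the new trainer.analytics module (shown below) reads the variable once at import time, so it has to be set before anything from the trainer package is imported.

import os

# Opt out of telemetry before the trainer package is imported;
# trainer.analytics reads TRAINER_TELEMETRY at module import time.
os.environ["TRAINER_TELEMETRY"] = "0"

from trainer.analytics import ping_training_run  # noqa: E402  (import after env setup)

ping_training_run()  # returns early; no HTTP request is sent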

trainer/VERSION

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-v0.0.17
+v0.0.18

trainer/analytics.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+import os
+
+import requests
+
+telemetry = os.environ.get("TRAINER_TELEMETRY")
+
+
+def ping_training_run():
+    if telemetry == "0":
+        return
+    URL = "https://coqui.gateway.scarf.sh/trainer/training_run"
+    _ = requests.get(URL)
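
A hypothetical test sketch (not part of this commit) for the new module: it patches requests.get so no real request reaches the Scarf gateway, and patches the module-level telemetry flag to cover both the opt-out and default paths.

from unittest import mock

import trainer.analytics as analytics


def test_ping_skipped_when_opted_out():
    # telemetry is read at import time, so patch the module attribute directly.
    with mock.patch.object(analytics, "telemetry", "0"), mock.patch(
        "trainer.analytics.requests.get"
    ) as fake_get:
        analytics.ping_training_run()
        fake_get.assert_not_called()


def test_ping_sent_by_default():
    with mock.patch.object(analytics, "telemetry", None), mock.patch(
        "trainer.analytics.requests.get"
    ) as fake_get:
        analytics.ping_training_run()
        fake_get.assert_called_once()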

trainer/callbacks.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from typing import Dict, Callable
+from typing import Callable, Dict
 
 
 class TrainerCallback:

trainer/trainer.py

Lines changed: 10 additions & 2 deletions
@@ -19,6 +19,7 @@
 from torch.nn.parallel import DistributedDataParallel as DDP_th
 from torch.utils.data import DataLoader
 
+from trainer.analytics import ping_training_run
 from trainer.callbacks import TrainerCallback
 from trainer.generic_utils import (
     KeepAverage,
@@ -241,6 +242,10 @@ class TrainerArgs(Coqpit):
         default=False,
         metadata={"help": "Skip training and only run evaluation and test."},
     )
+    start_with_eval: bool = field(
+        default=False,
+        metadata={"help": "Start with evaluation and test."},
+    )
     small_run: int = field(
         default=None,
         metadata={
@@ -388,6 +393,7 @@ def __init__(  # pylint: disable=dangerous-default-value
         self.grad_accum_steps = args.grad_accum_steps
         self.overfit_batch = args.overfit_batch
         self.skip_train_epoch = args.skip_train_epoch
+        self.start_with_eval = args.start_with_eval
 
         assert self.grad_accum_steps > 0, " [!] grad_accum_steps must be greater than 0."
 
@@ -519,6 +525,7 @@ def __init__(  # pylint: disable=dangerous-default-value
         self.callbacks.on_init_end(self)
         self.dashboard_logger.add_config(config)
         self.save_training_script()
+        ping_training_run()
 
     def save_training_script(self):
         """Save the training script to tracking dashboard and output path."""
@@ -1519,7 +1526,7 @@ def _fit(self) -> None:
             self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
             self.epochs_done = epoch
             self.c_logger.print_epoch_start(epoch, self.config.epochs, self.output_path)
-            if not self.skip_train_epoch:
+            if not self.skip_train_epoch and not self.start_with_eval:
                 self.train_epoch()
             if self.config.run_eval:
                 self.eval_epoch()
@@ -1532,6 +1539,7 @@ def _fit(self) -> None:
            if self.args.rank in [None, 0]:
                self.save_best_model()
            self.callbacks.on_epoch_end(self)
+            self.start_with_eval = False
 
    def fit_with_largest_batch_size(self, starting_batch_size=2048) -> None:
        cuda_meminfo()
@@ -1552,7 +1560,7 @@ def fit_with_largest_batch_size(self, starting_batch_size=2048) -> None:
                torch.cuda.empty_cache()
            else:
                raise
-        except Exception as exception:  #pylint: disable=broad-except
+        except Exception as exception:  # pylint: disable=broad-except
            # catches the torch.cuda.OutOfMemoryError
            if bs > 1 and should_reduce_batch_size(exception):
                bs //= 2
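
A minimal usage sketch of the new flag. The config, output path, model, and sample lists are assumed to be defined elsewhere in the user's project, and the constructor call follows the Trainer API as commonly used with this package; treat the exact signature as an assumption. With start_with_eval=True, the first pass through the _fit loop skips train_epoch() and goes straight to evaluation and test; the flag is then cleared at the end of that epoch, so later epochs train normally.

from trainer import Trainer, TrainerArgs

# start_with_eval=True: the first epoch skips training and runs eval/test only.
args = TrainerArgs(start_with_eval=True)

trainer = Trainer(
    args,
    config,             # assumed: a Coqpit training config defined elsewhere
    output_path,        # assumed: where checkpoints and logs are written
    model=model,        # assumed: a model implementing the trainer's interface
    train_samples=train_samples,
    eval_samples=eval_samples,
)
trainer.fit()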

trainer/utils/cpu_memory.py

Lines changed: 5 additions & 4 deletions
@@ -8,8 +8,9 @@ def get_available_cpu_memory():
     available_memory = psutil.virtual_memory().available
 
     try:
-        import resource # pylint: disable=import-outside-toplevel
-        _, hard_mem_limit = resource.getrlimit(resource.RLIMIT_AS) #pylint: disable=unused-variable
+        import resource  # pylint: disable=import-outside-toplevel
+
+        _, hard_mem_limit = resource.getrlimit(resource.RLIMIT_AS)  # pylint: disable=unused-variable
         if hard_mem_limit != resource.RLIM_INFINITY:
             used_memory = this_process.memory_info().vms
             available_memory = min(hard_mem_limit - used_memory, available_memory)
@@ -21,9 +22,9 @@ def get_available_cpu_memory():
 
 def set_cpu_memory_limit(num_gigabytes):
     try:
-        import resource # pylint: disable=import-outside-toplevel
+        import resource  # pylint: disable=import-outside-toplevel
 
-        num_bytes = int(num_gigabytes * 2 ** 30)
+        num_bytes = int(num_gigabytes * 2**30)
         _, hard_limit = resource.getrlimit(resource.RLIMIT_AS)
         if hard_limit != resource.RLIM_INFINITY:
             hard_limit = min(num_bytes, hard_limit)
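
The cpu_memory changes are formatting only (Black-style spacing); the helpers behave as before. A small usage sketch, assuming a POSIX system where the resource module is available: one gigabyte is taken as 2**30 bytes, so set_cpu_memory_limit(4) requests a 4 * 2**30 = 4,294,967,296-byte cap on RLIMIT_AS.

from trainer.utils.cpu_memory import get_available_cpu_memory, set_cpu_memory_limit

# Cap this process's address space at 4 GiB (the existing hard limit is kept
# if it is already lower), then report what is still available.
set_cpu_memory_limit(4)
print(f"available CPU memory: {get_available_cpu_memory() / 2**30:.2f} GiB")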
