Skip to content

Commit 61d186a

Browse files
authored
feat: efficiency improvements (#25)
1 parent 9a62fe2 commit 61d186a

File tree

5 files changed

+23
-8
lines changed

5 files changed

+23
-8
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,6 @@ dmypy.json
133133

134134
# Pyre type checker
135135
.pyre/
136+
137+
# MacOS files
138+
*.DS_Store

src/mblm/train/core/trainer.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,15 @@
5454
TTrainConfig,
5555
)
5656
from mblm.train.core.iter import epoch_cycler
57-
from mblm.utils.cuda import cuda_memory_snapshot, cuda_properties
57+
from mblm.utils.cuda import IS_BF16_AVAILABLE, cuda_memory_snapshot, cuda_properties
5858
from mblm.utils.distributed import ElasticRunVars
59-
from mblm.utils.io import CSVWriter, StateDict, dump_yml, load_model_state, save_model_state
59+
from mblm.utils.io import (
60+
CSVWriter,
61+
StateDict,
62+
dump_yml,
63+
load_model_state,
64+
save_model_state,
65+
)
6066
from mblm.utils.logging import create_logger
6167
from mblm.utils.misc import retry
6268
from mblm.utils.top_n import TopN
@@ -76,7 +82,7 @@ class CoreTrainerOptions:
7682
train_prog_min_interval_seconds: int = 1
7783
valid_prog_min_interval_seconds: int = 1
7884
track_first_fw_bw_exec_times: int | None = 30 # for 30 first passes, track fw/bw time
79-
amp_dtype: torch.dtype = torch.half # may use bfloat16
85+
amp_dtype: torch.dtype = torch.bfloat16 if IS_BF16_AVAILABLE else torch.half
8086

8187

8288
class CoreTrainer(ABC, Generic[TModel, TBatch, TModelParams, TTrainConfig, TIoConfig]):
@@ -159,7 +165,9 @@ def __init__(
159165
)
160166

161167
assert config.io.validate_amount > 0, "Validate amount must be strictly positive"
162-
assert config.io.num_models_to_save > 0, "Must save at least 1 model"
168+
assert config.io.num_models_to_save >= 0, "num_models_to_save cant be negative"
169+
if config.io.num_models_to_save == 0:
170+
self._log.warning("No model of this training will be saved!")
163171

164172
if config.io.validate_amount < config.io.num_models_to_save:
165173
self._log.warning(
@@ -963,7 +971,7 @@ def before_new_epoch(epoch: int) -> None:
963971

964972
best_model = self._unpack_distributed_model(self._model_dist)
965973

966-
if self._is_main_worker:
974+
if self._is_main_worker and self.config.io.num_models_to_save > 0:
967975
# if, on the main worker, populate the model with the best state
968976
# non-main workers will simply return the latest model, which won't
969977
# be used anyway because testing happens only on the main worker
@@ -1003,4 +1011,3 @@ def test(
10031011
avg_grad_clipped=-1,
10041012
)
10051013
self._log.info("Finished testing")
1006-
return None

src/mblm/utils/cuda.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
import torch.version
2929
from torch.types import Device
3030

31+
IS_CUDA_AVAILABLE = torch.cuda.is_available()
32+
IS_BF16_AVAILABLE = IS_CUDA_AVAILABLE and torch.cuda.is_bf16_supported()
33+
3134

3235
@dataclass
3336
class CudaProperties:

src/mblm/utils/top_n.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class TopN(Generic[_T]):
4040
with the first element being the smallest. A max heap can be specified via `top_largest`.
4141
4242
Args:
43-
n (int): Max number of items to store
43+
n (int): Max number of items to store. If zero this class is a no-op
4444
deep_copy (bool = `False`): Create a deep copy of elements
4545
top_largest (bool = `False`): If true, store the `n` largest items instead of
4646
the smallest items
@@ -77,6 +77,8 @@ def add(self, item: tuple[SupportsFloat, _T]) -> None:
7777
Add an item to the queue. The first tuple entry is used to
7878
determine the position of the newly added element
7979
"""
80+
if self._max_heap_items == 0:
81+
return
8082
val, data = item
8183
if self._deep_copy:
8284
data = copy.deepcopy(data)

tests/e2e/trainer/sample-config-grad-acc-1.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
io:
22
name_model: my-model
33
output_dir: tests/e2e/trainer/outputs # static
4-
num_models_to_save: 2
4+
num_models_to_save: 0
55
validate_amount: 10
66
log_train_loss_amount: 20
77
description: >-

0 commit comments

Comments
 (0)