Skip to content

Commit 36234e0

Browse files
diego-urgell authored and
facebook-github-bot committed
Move BestCheckpointConfig to utils/checkpoint.py
Differential Revision: D56455020
1 parent 04dc440 commit 36234e0

File tree

7 files changed

+25
-33
lines changed

7 files changed

+25
-33
lines changed

tests/framework/callbacks/test_base_checkpointer.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,14 @@
3131
from torchtnt.framework.callbacks.base_checkpointer import (
3232
BaseCheckpointer as BaseCheckpointer,
3333
)
34-
from torchtnt.framework.callbacks.checkpointer_types import (
35-
BestCheckpointConfig,
36-
RestoreOptions,
37-
)
34+
from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
3835
from torchtnt.framework.callbacks.lambda_callback import Lambda
3936
from torchtnt.framework.fit import fit
4037
from torchtnt.framework.state import State
4138

4239
from torchtnt.framework.train import train
4340
from torchtnt.framework.unit import AppStateMixin, TrainUnit, TTrainData
41+
from torchtnt.utils.checkpoint import BestCheckpointConfig
4442
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
4543
from torchtnt.utils.env import init_from_env
4644
from torchtnt.utils.test_utils import skip_if_not_distributed

torchtnt/framework/callbacks/base_checkpointer.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,7 @@
1616
import torch.distributed as dist
1717
from pyre_extensions import none_throws
1818
from torchtnt.framework.callback import Callback
19-
from torchtnt.framework.callbacks.checkpointer_types import (
20-
BestCheckpointConfig,
21-
RestoreOptions,
22-
)
19+
from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
2320
from torchtnt.framework.state import EntryPoint, State
2421
from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit
2522
from torchtnt.framework.utils import get_timing_context
@@ -28,6 +25,7 @@
2825
_metadata_exists,
2926
_sort_by_metric_value,
3027
_sort_by_recency,
28+
BestCheckpointConfig,
3129
get_best_checkpoint_path,
3230
get_checkpoint_dirpaths,
3331
get_latest_checkpoint_path,

torchtnt/framework/callbacks/checkpointer_types.py

+1-15
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# pyre-strict
88

99
from dataclasses import dataclass
10-
from typing import Literal, Optional
10+
from typing import Optional
1111

1212

1313
# TODO: eventually support overriding all knobs
@@ -39,17 +39,3 @@ class RestoreOptions:
3939
restore_eval_progress: bool = True
4040
restore_optimizers: bool = True
4141
restore_lr_schedulers: bool = True
42-
43-
44-
@dataclass
45-
class BestCheckpointConfig:
46-
"""
47-
Config for saving the best checkpoints.
48-
49-
Args:
50-
monitored_metric: Metric to monitor for saving best checkpoints. Must be a numerical or tensor attribute on the unit.
51-
mode: One of `min` or `max`. The save file is overwritten based on the max or min of the monitored metric.
52-
"""
53-
54-
monitored_metric: str
55-
mode: Literal["min", "max"] = "min"

torchtnt/framework/callbacks/dcp_saver.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,7 @@
2323
)
2424

2525
from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer
26-
from torchtnt.framework.callbacks.checkpointer_types import (
27-
BestCheckpointConfig,
28-
KnobOptions,
29-
RestoreOptions,
30-
)
26+
from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
3127
from torchtnt.framework.state import State
3228
from torchtnt.framework.unit import (
3329
AppStateMixin,
@@ -37,6 +33,7 @@
3733
TTrainUnit,
3834
)
3935
from torchtnt.framework.utils import get_timing_context
36+
from torchtnt.utils.checkpoint import BestCheckpointConfig
4037
from torchtnt.utils.optimizer import init_optim_state
4138
from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
4239
from torchtnt.utils.stateful import MultiStateful, Stateful

torchtnt/framework/callbacks/torchsnapshot_saver.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,7 @@
2222
)
2323

2424
from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer
25-
from torchtnt.framework.callbacks.checkpointer_types import (
26-
BestCheckpointConfig,
27-
KnobOptions,
28-
RestoreOptions,
29-
)
25+
from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
3026
from torchtnt.framework.state import State
3127
from torchtnt.framework.unit import (
3228
AppStateMixin,
@@ -36,6 +32,7 @@
3632
TTrainUnit,
3733
)
3834
from torchtnt.framework.utils import get_timing_context
35+
from torchtnt.utils.checkpoint import BestCheckpointConfig
3936
from torchtnt.utils.optimizer import init_optim_state
4037
from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
4138
from torchtnt.utils.stateful import Stateful

torchtnt/utils/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# pyre-strict
88

99
from .checkpoint import (
10+
BestCheckpointConfig,
1011
CheckpointPath,
1112
get_best_checkpoint_path,
1213
get_checkpoint_dirpaths,
@@ -160,4 +161,5 @@
160161
"get_best_checkpoint_path",
161162
"get_checkpoint_dirpaths",
162163
"get_latest_checkpoint_path",
164+
"BestCheckpointConfig",
163165
]

torchtnt/utils/checkpoint.py

+14
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,20 @@ class MetricData:
3131
value: float
3232

3333

34+
@dataclass
35+
class BestCheckpointConfig:
36+
"""
37+
Config for saving the best checkpoints.
38+
39+
Args:
40+
monitored_metric: Metric to monitor for saving best checkpoints. Must be a numerical or tensor attribute on the unit.
41+
mode: One of `min` or `max`. The save file is overwritten based on the max or min of the monitored metric.
42+
"""
43+
44+
monitored_metric: str
45+
mode: Literal["min", "max"] = "min"
46+
47+
3448
@total_ordering
3549
class CheckpointPath:
3650
"""

0 commit comments

Comments
 (0)