Skip to content

Commit 3336b0a

Browse files
Diego Urgellfacebook-github-bot
Diego Urgell
authored andcommitted
Add CheckpointPath abstraction in utils/checkpoint.py
Differential Revision: D56260188
1 parent 0159a07 commit 3336b0a

File tree

3 files changed

+324
-0
lines changed

3 files changed

+324
-0
lines changed

tests/utils/test_checkpoint.py

+143
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
import unittest
9+
10+
from torchtnt.utils.checkpoint import CheckpointPath, MetricData
11+
12+
13+
class CheckpointPathTest(unittest.TestCase):
14+
def test_from_str(self) -> None:
15+
# invalid paths
16+
malformed_paths = [
17+
"foo/step_20",
18+
"foo/epoch_50",
19+
"epoch_30",
20+
"foo/epoch_20_step",
21+
"foo/epoch_20_step_30_val_loss=1a",
22+
"foo/epoch_2_step_15_mean=hello",
23+
"foo/epoch_2.6_step_23",
24+
]
25+
for path in malformed_paths:
26+
with self.assertRaisesRegex(
27+
ValueError, f"Attempted to parse malformed checkpoint path: {path}"
28+
):
29+
CheckpointPath.from_str(path)
30+
31+
# valid paths
32+
valid_paths = [
33+
("foo/epoch_0_step_1", CheckpointPath("foo", epoch=0, step=1)),
34+
(
35+
"foo/epoch_14_step_3_mean=15.0",
36+
CheckpointPath(
37+
"foo", epoch=14, step=3, metric_data=MetricData("mean", 15.0)
38+
),
39+
),
40+
(
41+
"foo/epoch_14_step_3_loss=-27.35",
42+
CheckpointPath(
43+
"foo", epoch=14, step=3, metric_data=MetricData("loss", -27.35)
44+
),
45+
),
46+
(
47+
"/foo/epoch_14_step_3_loss=-27.35",
48+
CheckpointPath(
49+
"/foo", epoch=14, step=3, metric_data=MetricData("loss", -27.35)
50+
),
51+
),
52+
(
53+
"foo/bar/epoch_23_step_31_mean_loss_squared=0.0",
54+
CheckpointPath(
55+
"foo/bar/",
56+
epoch=23,
57+
step=31,
58+
metric_data=MetricData("mean_loss_squared", 0.0),
59+
),
60+
),
61+
(
62+
"oss://some/path/checkpoints/0b20e70f-9ad2-4904-b7d6-e8da48087d61/epoch_2_step_1_acc=0.98",
63+
CheckpointPath(
64+
"oss://some/path/checkpoints/0b20e70f-9ad2-4904-b7d6-e8da48087d61",
65+
epoch=2,
66+
step=1,
67+
metric_data=MetricData("acc", 0.98),
68+
),
69+
),
70+
]
71+
for path, expected_ckpt in valid_paths:
72+
parsed_ckpt = CheckpointPath.from_str(path)
73+
self.assertEqual(parsed_ckpt, expected_ckpt)
74+
self.assertEqual(parsed_ckpt.path, path)
75+
76+
# with a trailing slash
77+
ckpt = CheckpointPath.from_str("foo/epoch_0_step_1/")
78+
self.assertEqual(ckpt, CheckpointPath("foo", epoch=0, step=1))
79+
self.assertEqual(ckpt.path, "foo/epoch_0_step_1")
80+
81+
def test_compare_by_recency(self) -> None:
82+
old = CheckpointPath("foo", epoch=0, step=1)
83+
new = CheckpointPath("foo", epoch=1, step=1)
84+
self.assertTrue(new.newer_than(old))
85+
self.assertFalse(old.newer_than(new))
86+
self.assertFalse(new == old)
87+
88+
old = CheckpointPath("foo", epoch=3, step=5)
89+
new = CheckpointPath("foo", epoch=3, step=9)
90+
self.assertTrue(new.newer_than(old))
91+
self.assertFalse(old.newer_than(new))
92+
self.assertFalse(new == old)
93+
94+
twin1 = CheckpointPath(
95+
"foo", epoch=2, step=5, metric_data=MetricData("foo", 1.0)
96+
)
97+
almost_twin = CheckpointPath(
98+
"foo", epoch=2, step=5, metric_data=MetricData("bar", 2.0)
99+
)
100+
101+
self.assertFalse(twin1.newer_than(almost_twin))
102+
self.assertFalse(almost_twin.newer_than(twin1))
103+
self.assertFalse(twin1 == almost_twin)
104+
105+
twin2 = CheckpointPath(
106+
"foo", epoch=2, step=5, metric_data=MetricData("foo", 1.0)
107+
)
108+
self.assertTrue(twin1 == twin2)
109+
110+
def test_compare_by_optimality(self) -> None:
111+
# not both metric aware
112+
ckpt1 = CheckpointPath("foo", epoch=0, step=1)
113+
ckpt2 = CheckpointPath("foo", epoch=1, step=1)
114+
ckpt3 = CheckpointPath(
115+
"foo", epoch=1, step=1, metric_data=MetricData("bar", 1.0)
116+
)
117+
for ckpt in [ckpt2, ckpt3]:
118+
with self.assertRaisesRegex(
119+
AssertionError,
120+
"Attempted to compare optimality of non metric-aware checkpoints",
121+
):
122+
ckpt1.more_optimal_than(ckpt, mode="min")
123+
124+
# tracking different metrics
125+
ckpt4 = CheckpointPath(
126+
"foo", epoch=1, step=1, metric_data=MetricData("baz", 1.0)
127+
)
128+
with self.assertRaisesRegex(
129+
AssertionError,
130+
"Attempted to compare optimality of checkpoints tracking different metrics",
131+
):
132+
ckpt3.more_optimal_than(ckpt4, mode="min")
133+
134+
smaller = CheckpointPath(
135+
"foo", epoch=0, step=1, metric_data=MetricData("foo", 1.0)
136+
)
137+
larger = CheckpointPath(
138+
"foo", epoch=0, step=1, metric_data=MetricData("foo", 2.0)
139+
)
140+
self.assertTrue(larger.more_optimal_than(smaller, mode="max"))
141+
self.assertFalse(smaller.more_optimal_than(larger, mode="max"))
142+
self.assertTrue(smaller.more_optimal_than(larger, mode="min"))
143+
self.assertFalse(larger.more_optimal_than(smaller, mode="min"))

torchtnt/utils/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-strict
88

9+
from .checkpoint import CheckpointPath, MetricData
910
from .device import (
1011
copy_data_to_device,
1112
CPUStats,
@@ -148,4 +149,6 @@
148149
"is_windows",
149150
"get_pet_launch_config",
150151
"spawn_multi_process",
152+
"CheckpointPath",
153+
"MetricData",
151154
]

torchtnt/utils/checkpoint.py

+178
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
import os
9+
import re
10+
from dataclasses import dataclass
11+
from functools import total_ordering
12+
from typing import Literal, Optional, Pattern
13+
14+
from pyre_extensions import none_throws
15+
16+
17+
@dataclass
18+
class MetricData:
19+
"""
20+
Representation of a metric instance. Should provide both a metric name and it's value.
21+
"""
22+
23+
name: str
24+
value: float
25+
26+
27+
@total_ordering
28+
class CheckpointPath:
29+
"""
30+
Representation of a checkpoint path. Handles parsing and serialization of the specific path format.
31+
Currently, the basic compliant path format is: <dirpath>/epoch_<epoch>_step_<step>
32+
If a metric is being tracked, it's added to the name: <dirpath>/epoch_<epoch>_step_<step>_<metric_name>=<metric_value>
33+
34+
This class is well-ordered by checkpoint recency, so any comparisons will operate using the epoch + step. Sorting by
35+
metric can be done by extracting the metric value from the metric_data attribute.
36+
"""
37+
38+
PATH_REGEX: Pattern = re.compile(
39+
r"^(.+)epoch_(\d+)_step_(\d+)(?:_(.+)=(-?\d+\.?\d*))?\/?$"
40+
)
41+
42+
def __init__(
43+
self,
44+
dirpath: str,
45+
epoch: int,
46+
step: int,
47+
metric_data: Optional[MetricData] = None,
48+
) -> None:
49+
"""
50+
Args:
51+
dirpath: The base directory path that checkpoints are saved in.
52+
epoch: The epoch number of this checkpoint.
53+
step: The step number of this checkpoint.
54+
metric_data: Optional data about the metric being tracked. Should contain both metric name and value.
55+
"""
56+
self.dirpath: str = dirpath.rstrip("/")
57+
self.epoch = epoch
58+
self.step = step
59+
self.metric_data = metric_data
60+
61+
@classmethod
62+
def from_str(cls, checkpoint_path: str) -> "CheckpointPath":
63+
"""
64+
Given a directory path, try to parse it and extract the checkpoint data.
65+
The expected format is: <dirpath>/epoch_<epoch>_step_<step>_<metric_name>=<metric_value>,
66+
where the metric name and value are optional.
67+
68+
Args:
69+
checkpoint_path: The path to the checkpoint directory.
70+
71+
Returns:
72+
A CheckpointPath instance if the path is valid, otherwise None.
73+
74+
Raises:
75+
ValueError: If the path is malformed and can't be parsed.
76+
"""
77+
path_match = cls.PATH_REGEX.match(checkpoint_path)
78+
if not path_match:
79+
raise ValueError(
80+
f"Attempted to parse malformed checkpoint path: {checkpoint_path}."
81+
)
82+
83+
dirpath, epoch, step, metric_name, metric_value = path_match.groups()
84+
try:
85+
metric_data: Optional[MetricData] = None
86+
if metric_name:
87+
metric_value_f = float(metric_value)
88+
metric_data = MetricData(name=metric_name, value=metric_value_f)
89+
90+
return CheckpointPath(
91+
dirpath=dirpath,
92+
epoch=int(epoch),
93+
step=int(step),
94+
metric_data=metric_data,
95+
)
96+
97+
except ValueError:
98+
# Should never happen since path matches regex
99+
raise ValueError(
100+
f"Invalid data types found in checkpoint path: {checkpoint_path}."
101+
)
102+
103+
@property
104+
def path(self) -> str:
105+
"""
106+
Returns:
107+
The full path to the checkpoint directory.
108+
"""
109+
name = f"epoch_{self.epoch}_step_{self.step}"
110+
if self.metric_data:
111+
name += f"_{self.metric_data.name}={self.metric_data.value}"
112+
113+
return os.path.join(self.dirpath, name)
114+
115+
def newer_than(self, other: "CheckpointPath") -> bool:
116+
"""
117+
Given another CheckpointPath instance, determine if this checkpoint is strictly newer than the other.
118+
119+
Returns:
120+
True if this checkpoint is newer than the other, otherwise False.
121+
"""
122+
if self.epoch != other.epoch:
123+
return self.epoch > other.epoch
124+
125+
return self.step > other.step
126+
127+
def more_optimal_than(
128+
self, other: "CheckpointPath", mode: Literal["min", "max"]
129+
) -> bool:
130+
"""
131+
Given another CheckpointPath instance, determine if this checkpoint is strictly more optimal than the other.
132+
Optimality is determined by comparing the metric value of the two checkpoints. The mode indicates if the
133+
metric value should be minimized or maximized. This only works for metric-aware checkpoints.
134+
135+
Args:
136+
other: The other checkpoint path to compare against.
137+
mode: The mode to use for comparison.
138+
139+
Returns:
140+
True if this checkpoint is more optimal than the other, otherwise False.
141+
142+
Note: This expects that both checkpoints are metric-aware, and that they are tracking the same metric.
143+
"""
144+
145+
assert (
146+
self.metric_data and other.metric_data
147+
), f"Attempted to compare optimality of non metric-aware checkpoints: {self} and {other}"
148+
149+
assert (
150+
self.metric_data.name == other.metric_data.name
151+
), f"Attempted to compare optimality of checkpoints tracking different metrics: {self} and {other}"
152+
153+
if mode == "min":
154+
return (
155+
none_throws(self.metric_data).value
156+
< none_throws(other.metric_data).value
157+
)
158+
159+
return (
160+
none_throws(self.metric_data).value > none_throws(other.metric_data).value
161+
)
162+
163+
def __str__(self) -> str:
164+
return self.path
165+
166+
def __repr__(self) -> str:
167+
return f"CheckpointPath(dirpath={self.dirpath}, epoch={self.epoch}, step={self.step}, metric_data={self.metric_data})"
168+
169+
def __eq__(self, other: "CheckpointPath") -> bool:
170+
return (
171+
self.dirpath == other.dirpath
172+
and self.epoch == other.epoch
173+
and self.step == other.step
174+
and self.metric_data == other.metric_data
175+
)
176+
177+
def __gt__(self, other: "CheckpointPath") -> bool:
178+
return self.newer_than(other)

0 commit comments

Comments
 (0)