|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +# pyre-strict |
| 8 | +import unittest |
| 9 | + |
| 10 | +from torchtnt.utils.checkpoint import CheckpointPath, MetricData |
| 11 | + |
| 12 | + |
| 13 | +class CheckpointPathTest(unittest.TestCase): |
| 14 | + def test_from_str(self) -> None: |
| 15 | + # invalid paths |
| 16 | + malformed_paths = [ |
| 17 | + "foo/step_20", |
| 18 | + "foo/epoch_50", |
| 19 | + "epoch_30", |
| 20 | + "foo/epoch_20_step", |
| 21 | + "foo/epoch_20_step_30_val_loss=1a", |
| 22 | + "foo/epoch_2_step_15_mean=hello", |
| 23 | + "foo/epoch_2.6_step_23", |
| 24 | + ] |
| 25 | + for path in malformed_paths: |
| 26 | + with self.assertRaisesRegex( |
| 27 | + ValueError, f"Attempted to parse malformed checkpoint path: {path}" |
| 28 | + ): |
| 29 | + CheckpointPath.from_str(path) |
| 30 | + |
| 31 | + # valid paths |
| 32 | + valid_paths = [ |
| 33 | + ("foo/epoch_0_step_1", CheckpointPath("foo", epoch=0, step=1)), |
| 34 | + ( |
| 35 | + "foo/epoch_14_step_3_mean=15.0", |
| 36 | + CheckpointPath( |
| 37 | + "foo", epoch=14, step=3, metric_data=MetricData("mean", 15.0) |
| 38 | + ), |
| 39 | + ), |
| 40 | + ( |
| 41 | + "foo/epoch_14_step_3_loss=-27.35", |
| 42 | + CheckpointPath( |
| 43 | + "foo", epoch=14, step=3, metric_data=MetricData("loss", -27.35) |
| 44 | + ), |
| 45 | + ), |
| 46 | + ( |
| 47 | + "/foo/epoch_14_step_3_loss=-27.35", |
| 48 | + CheckpointPath( |
| 49 | + "/foo", epoch=14, step=3, metric_data=MetricData("loss", -27.35) |
| 50 | + ), |
| 51 | + ), |
| 52 | + ( |
| 53 | + "foo/bar/epoch_23_step_31_mean_loss_squared=0.0", |
| 54 | + CheckpointPath( |
| 55 | + "foo/bar/", |
| 56 | + epoch=23, |
| 57 | + step=31, |
| 58 | + metric_data=MetricData("mean_loss_squared", 0.0), |
| 59 | + ), |
| 60 | + ), |
| 61 | + ( |
| 62 | + "oss://some/path/checkpoints/0b20e70f-9ad2-4904-b7d6-e8da48087d61/epoch_2_step_1_acc=0.98", |
| 63 | + CheckpointPath( |
| 64 | + "oss://some/path/checkpoints/0b20e70f-9ad2-4904-b7d6-e8da48087d61", |
| 65 | + epoch=2, |
| 66 | + step=1, |
| 67 | + metric_data=MetricData("acc", 0.98), |
| 68 | + ), |
| 69 | + ), |
| 70 | + ] |
| 71 | + for path, expected_ckpt in valid_paths: |
| 72 | + parsed_ckpt = CheckpointPath.from_str(path) |
| 73 | + self.assertEqual(parsed_ckpt, expected_ckpt) |
| 74 | + self.assertEqual(parsed_ckpt.path, path) |
| 75 | + |
| 76 | + # with a trailing slash |
| 77 | + ckpt = CheckpointPath.from_str("foo/epoch_0_step_1/") |
| 78 | + self.assertEqual(ckpt, CheckpointPath("foo", epoch=0, step=1)) |
| 79 | + self.assertEqual(ckpt.path, "foo/epoch_0_step_1") |
| 80 | + |
| 81 | + def test_compare_by_recency(self) -> None: |
| 82 | + old = CheckpointPath("foo", epoch=0, step=1) |
| 83 | + new = CheckpointPath("foo", epoch=1, step=1) |
| 84 | + self.assertTrue(new.newer_than(old)) |
| 85 | + self.assertFalse(old.newer_than(new)) |
| 86 | + self.assertFalse(new == old) |
| 87 | + |
| 88 | + old = CheckpointPath("foo", epoch=3, step=5) |
| 89 | + new = CheckpointPath("foo", epoch=3, step=9) |
| 90 | + self.assertTrue(new.newer_than(old)) |
| 91 | + self.assertFalse(old.newer_than(new)) |
| 92 | + self.assertFalse(new == old) |
| 93 | + |
| 94 | + twin1 = CheckpointPath( |
| 95 | + "foo", epoch=2, step=5, metric_data=MetricData("foo", 1.0) |
| 96 | + ) |
| 97 | + almost_twin = CheckpointPath( |
| 98 | + "foo", epoch=2, step=5, metric_data=MetricData("bar", 2.0) |
| 99 | + ) |
| 100 | + |
| 101 | + self.assertFalse(twin1.newer_than(almost_twin)) |
| 102 | + self.assertFalse(almost_twin.newer_than(twin1)) |
| 103 | + self.assertFalse(twin1 == almost_twin) |
| 104 | + |
| 105 | + twin2 = CheckpointPath( |
| 106 | + "foo", epoch=2, step=5, metric_data=MetricData("foo", 1.0) |
| 107 | + ) |
| 108 | + self.assertTrue(twin1 == twin2) |
| 109 | + |
| 110 | + def test_compare_by_optimality(self) -> None: |
| 111 | + # not both metric aware |
| 112 | + ckpt1 = CheckpointPath("foo", epoch=0, step=1) |
| 113 | + ckpt2 = CheckpointPath("foo", epoch=1, step=1) |
| 114 | + ckpt3 = CheckpointPath( |
| 115 | + "foo", epoch=1, step=1, metric_data=MetricData("bar", 1.0) |
| 116 | + ) |
| 117 | + for ckpt in [ckpt2, ckpt3]: |
| 118 | + with self.assertRaisesRegex( |
| 119 | + AssertionError, |
| 120 | + "Attempted to compare optimality of non metric-aware checkpoints", |
| 121 | + ): |
| 122 | + ckpt1.more_optimal_than(ckpt, mode="min") |
| 123 | + |
| 124 | + # tracking different metrics |
| 125 | + ckpt4 = CheckpointPath( |
| 126 | + "foo", epoch=1, step=1, metric_data=MetricData("baz", 1.0) |
| 127 | + ) |
| 128 | + with self.assertRaisesRegex( |
| 129 | + AssertionError, |
| 130 | + "Attempted to compare optimality of checkpoints tracking different metrics", |
| 131 | + ): |
| 132 | + ckpt3.more_optimal_than(ckpt4, mode="min") |
| 133 | + |
| 134 | + smaller = CheckpointPath( |
| 135 | + "foo", epoch=0, step=1, metric_data=MetricData("foo", 1.0) |
| 136 | + ) |
| 137 | + larger = CheckpointPath( |
| 138 | + "foo", epoch=0, step=1, metric_data=MetricData("foo", 2.0) |
| 139 | + ) |
| 140 | + self.assertTrue(larger.more_optimal_than(smaller, mode="max")) |
| 141 | + self.assertFalse(smaller.more_optimal_than(larger, mode="max")) |
| 142 | + self.assertTrue(smaller.more_optimal_than(larger, mode="min")) |
| 143 | + self.assertFalse(larger.more_optimal_than(smaller, mode="min")) |
0 commit comments