Skip to content

Commit 60e9f84

Browse files
mxberlotOrbax Authors
authored andcommitted
Internal change
PiperOrigin-RevId: 869217857
1 parent 60b50ba commit 60e9f84

File tree

1 file changed

+146
-0
lines changed

1 file changed

+146
-0
lines changed
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Copyright 2026 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Checkpoint storage implementations."""
16+
17+
import abc
18+
import dataclasses
19+
import enum
20+
21+
from absl import logging
22+
from etils import epath
23+
from orbax.checkpoint._src.path import atomicity_types
24+
25+
26+
@dataclasses.dataclass(frozen=True)
27+
class CheckpointPathMetadata:
28+
"""Internal representation of checkpoint path metadata.
29+
30+
Attributes:
31+
path: The file system path of the checkpoint.
32+
status: The status of the checkpoint.
33+
version: The version of the checkpoint with an index and step number. (e.g.
34+
'1.step_1')
35+
tags: A list of tags associated with the checkpoint. Currently only
36+
supported for TFHub paths, for other paths this field will be `None`.
37+
"""
38+
39+
class Status(enum.Enum):
40+
COMMITTED = 1
41+
UNCOMMITTED = 2
42+
43+
path: str
44+
status: Status
45+
version: str | None
46+
tags: set[str] | None = None
47+
48+
49+
@dataclasses.dataclass(frozen=True)
50+
class CheckpointFilter:
51+
"""Criteria for filtering checkpoints.
52+
53+
TODO: b/466312058 This class will contain fields for filtering checkpoints by
54+
various criteria.
55+
"""
56+
57+
58+
@dataclasses.dataclass(frozen=True)
59+
class CheckpointReadOptions:
60+
"""Options for reading checkpoints.
61+
62+
Attributes:
63+
filter: Optional filter criteria for selecting checkpoints.
64+
enable_strong_reads: If True, enables strong read consistency when querying
65+
checkpoints. This may have performance implications but ensures the most
66+
up-to-date results.
67+
"""
68+
69+
filter: CheckpointFilter | None = None
70+
enable_strong_reads: bool = False
71+
72+
73+
class StorageBackend(abc.ABC):
74+
"""An abstract base class for a storage backend.
75+
76+
This class defines a common interface for managing checkpoint paths in
77+
different file systems.
78+
"""
79+
80+
@abc.abstractmethod
81+
def list_checkpoints(self, base_path: str) -> list[CheckpointPathMetadata]:
82+
"""Lists checkpoints for a given base path and version pattern."""
83+
raise NotImplementedError('Subclasses must provide implementation')
84+
85+
@abc.abstractmethod
86+
def get_temporary_path_class(self) -> type[atomicity_types.TemporaryPath]:
87+
"""Returns a TemporaryPath class for the storage backend."""
88+
raise NotImplementedError('Subclasses must provide implementation')
89+
90+
@abc.abstractmethod
91+
def delete_checkpoint(self, checkpoint_path: str | epath.PathLike) -> None:
92+
"""Deletes a checkpoint from the storage backend."""
93+
raise NotImplementedError('Subclasses must provide implementation')
94+
95+
96+
class GCSStorageBackend(StorageBackend):
97+
"""A StorageBackend implementation for GCS (Google Cloud Storage).
98+
99+
# TODO(b/425293362): Implement this class.
100+
"""
101+
102+
def get_temporary_path_class(self) -> type[atomicity_types.TemporaryPath]:
103+
"""Returns the final checkpoint path directly."""
104+
raise NotImplementedError(
105+
'get_temporary_path_class is not yet implemented for GCSStorageBackend.'
106+
)
107+
108+
def list_checkpoints(self, base_path: str) -> list[CheckpointPathMetadata]:
109+
"""Lists checkpoints for a given base path and version pattern."""
110+
raise NotImplementedError(
111+
'list_checkpoints is not yet implemented for GCSStorageBackend.'
112+
)
113+
114+
def delete_checkpoint(self, checkpoint_path: str | epath.PathLike) -> None:
115+
"""Deletes the checkpoint at the given path."""
116+
raise NotImplementedError(
117+
'delete_checkpoint is not yet implemented for GCSStorageBackend.'
118+
)
119+
120+
121+
class LocalStorageBackend(StorageBackend):
122+
"""A LocalStorageBackend implementation for local file systems.
123+
124+
# TODO(b/425293362): Implement this class.
125+
"""
126+
127+
def get_temporary_path_class(self) -> type[atomicity_types.TemporaryPath]:
128+
"""Returns the final checkpoint path directly."""
129+
raise NotImplementedError(
130+
'get_temporary_path_class is not yet implemented for'
131+
' LocalStorageBackend.'
132+
)
133+
134+
def list_checkpoints(self, base_path: str) -> list[CheckpointPathMetadata]:
135+
"""Lists checkpoints for a given base path and version pattern."""
136+
raise NotImplementedError(
137+
'list_checkpoints is not yet implemented for LocalStorageBackend.'
138+
)
139+
140+
def delete_checkpoint(self, checkpoint_path: str | epath.PathLike) -> None:
141+
"""Deletes the checkpoint at the given path."""
142+
try:
143+
epath.Path(checkpoint_path).rmtree()
144+
logging.info('Removed old checkpoint (%s)', checkpoint_path)
145+
except OSError:
146+
logging.exception('Failed to remove checkpoint (%s)', checkpoint_path)

0 commit comments

Comments
 (0)