Skip to content

Commit b0a604a

Browse files
author
Orbax Authors
committed
Internal change.
PiperOrigin-RevId: 872326871
1 parent 0c8d5ab commit b0a604a

File tree

2 files changed

+119
-11
lines changed

2 files changed

+119
-11
lines changed

checkpoint/orbax/checkpoint/_src/arrays/fragments.py

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,15 @@ class methods:
3131
index in `indices`. If FS is concrete then the fragment values will be
3232
slices of x. If FS is `AbstractFragments` then x may additionally be
3333
a `jax.ShapeDtypeStruct`.
34+
- `like(fragments, x)` takes an existing Fragments instance and returns a new
35+
instance of FS, with the same shape and dtype and fragment indices, but
36+
with the fragment values replaced by slices of x. If FS is
37+
`AbstractFragments` then x must be `None` and may be omitted.
38+
39+
Use `FS.of()` and friends to make instances out of arraylike things.
40+
Use `FS.like()` to convert between different kinds of Fragments.
41+
Use `{np, jnp}.asarray(fragments)` to make arraylike things out of (full)
42+
Fragments.
3443
"""
3544
# TODO(b/465196209): Remove when support for Python 3.10 is dropped.
3645
from __future__ import annotations
@@ -478,11 +487,26 @@ def _of(cls: type[FS], x: Any, *, indices: Sequence[Index]) -> FS:
478487
fragments = [cls.FRAGMENT_T(index=index) for index in indices]
479488
return cls(x.shape, x.dtype, fragments)
480489

490+
@classmethod
491+
def like(
492+
cls: type[FS],
493+
fragments: _GenericFragments[Any],
494+
value: Literal[None] = None,
495+
) -> FS:
496+
del value
497+
return cls(
498+
shape=fragments.shape,
499+
dtype=fragments.dtype,
500+
fragments=[
501+
cls.FRAGMENT_T(index=f.index) for f in fragments.fragments
502+
],
503+
)
504+
481505

482506
@dataclasses.dataclass(frozen=True, init=False)
483-
class NpFragments(_GenericFragments[NpFragment]):
484-
"""A collection of fragments whose values are of type `np.ndarray`."""
485-
FRAGMENT_T = NpFragment
507+
class _ConcreteFragments(_GenericFragments[Fconcrete]):
508+
"""A collection of concrete fragments."""
509+
FRAGMENT_T: ClassVar[type[Fconcrete]] # The type of fragment values.
486510

487511
@classmethod
488512
def _of(cls: type[FS], x: Any, *, indices: Sequence[Index]) -> FS:
@@ -491,19 +515,37 @@ def _of(cls: type[FS], x: Any, *, indices: Sequence[Index]) -> FS:
491515
fragments = [cls.FRAGMENT_T(index=i, value=x[i]) for i in indices]
492516
return cls(x.shape, x.dtype, fragments)
493517

518+
@classmethod
519+
def like(
520+
cls: type[FS], fragments: _GenericFragments[Any], value: Aconcrete
521+
) -> FS:
522+
_check_fragment_value_type(value, cls.FRAGMENT_T.ARRAY_T)
523+
if fragments.shape != value.shape or fragments.dtype != value.dtype:
524+
raise ValueError(
525+
f'Fragments type {fragments.dtype}[{fragments.shape}] does'
526+
f' not match value type {value.dtype}[{value.shape}].'
527+
)
528+
return cls(
529+
shape=fragments.shape,
530+
dtype=fragments.dtype,
531+
fragments=[
532+
cls.FRAGMENT_T(index=f.index, value=value[f.index])
533+
for f in fragments.fragments
534+
],
535+
)
536+
494537

495538
@dataclasses.dataclass(frozen=True, init=False)
496-
class JaxFragments(_GenericFragments[JaxFragment]):
539+
class NpFragments(_ConcreteFragments[NpFragment]):
540+
"""A collection of fragments whose values are of type `np.ndarray`."""
541+
FRAGMENT_T = NpFragment
542+
543+
544+
@dataclasses.dataclass(frozen=True, init=False)
545+
class JaxFragments(_ConcreteFragments[JaxFragment]):
497546
"""A collection of fragments whose values are of type `jax.Array`."""
498547
FRAGMENT_T = JaxFragment
499548

500-
@classmethod
501-
def _of(cls: type[FS], x: Any, *, indices: Sequence[Index]) -> FS:
502-
"""Returns a Fragments with one fragment for each index."""
503-
_check_fragment_value_type(x, cls.FRAGMENT_T.ARRAY_T)
504-
fragments = [cls.FRAGMENT_T(index=i, value=x[i]) for i in indices]
505-
return cls(x.shape, x.dtype, fragments)
506-
507549

508550
# Extra names for backwards compatibility. Most loading and saving code still
509551
# wants to deal with NumPy arrays so that views and operations on them

checkpoint/orbax/checkpoint/_src/arrays/fragments_test.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -927,6 +927,72 @@ def test_abstract_fragments_of(self, value_fn):
927927
fragment_t(index=np.s_[0:2:1, 2:3:1]),
928928
])
929929

930+
@parameterized.named_parameters(
931+
('np_array', NpFragments),
932+
('jnp_array', JaxFragments),
933+
)
934+
def test_abstract_fragments_like_concrete_fragments(
935+
self, template_fragments_t: FragmentsT
936+
):
937+
fragments_t = AbstractFragments
938+
shape = (2, 3)
939+
dtype = np.dtype(np.float32)
940+
fragment_t = fragments_t.FRAGMENT_T
941+
942+
template_np_api = template_fragments_t.FRAGMENT_T.NP_API
943+
template = template_fragments_t(
944+
shape=shape,
945+
dtype=dtype,
946+
fragments=[
947+
template_fragments_t.FRAGMENT_T(
948+
index=np.s_[0:2:1, 0:1:1], value=template_np_api.ones((2, 1))
949+
),
950+
template_fragments_t.FRAGMENT_T(
951+
index=np.s_[0:2:1, 2:3:1], value=template_np_api.ones((2, 1))
952+
),
953+
],
954+
)
955+
956+
fs = fragments_t.like(template)
957+
958+
self.assertEqual(fs.shape, shape)
959+
self.assertEqual(fs.dtype, dtype)
960+
self.assertEqual(fs.fragments, [
961+
fragment_t(index=np.s_[0:2:1, 0:1:1], value=None),
962+
fragment_t(index=np.s_[0:2:1, 2:3:1], value=None),
963+
])
964+
965+
@parameterized.named_parameters(
966+
('np_array', NpFragments),
967+
('jnp_array', JaxFragments),
968+
)
969+
def test_concrete_fragments_like_abstract_fragments(
970+
self, fragments_t: FragmentsT
971+
):
972+
shape = (2, 3)
973+
dtype = np.dtype(np.float32)
974+
fragment_t = fragments_t.FRAGMENT_T
975+
np_api = fragment_t.NP_API
976+
977+
template = AbstractFragments(
978+
shape=shape,
979+
dtype=dtype,
980+
fragments=[
981+
AbstractFragment(index=np.s_[0:2:1, 0:1:1]),
982+
AbstractFragment(index=np.s_[0:2:1, 2:3:1]),
983+
],
984+
)
985+
value = np_api.ones(shape, dtype=dtype)
986+
987+
fs = fragments_t.like(template, value)
988+
989+
self.assertEqual(fs.shape, shape)
990+
self.assertEqual(fs.dtype, dtype)
991+
self.assertEqual(fs.fragments, [
992+
fragment_t(index=np.s_[0:2:1, 0:1:1], value=value[0:2:1, 0:1:1]),
993+
fragment_t(index=np.s_[0:2:1, 2:3:1], value=value[0:2:1, 2:3:1]),
994+
])
995+
930996

931997
@parameterized.named_parameters(
932998
('abstract_fragments', AbstractFragments),

0 commit comments

Comments
 (0)