@@ -31,6 +31,15 @@ class methods:
3131 index in `indices`. If FS is concrete then the fragment values will be
3232 slices of x. If FS is `AbstractFragments` then x may additionally be
3333 a `jax.ShapeDtypeStruct`.
34+ - `like(fragments, x)` takes an existing Fragments instance and returns a new
35+ instance of FS, with the same shape and dtype and fragment indices, but
36+ with the fragment values replaced by slices of x. If FS is
37+ `AbstractFragments` then x must be `None` and may be omitted.
38+
39+ Use `FS.of()` and friends to make instances out of arraylike things.
40+ Use `FS.like()` to convert between different kinds of Fragments.
41+ Use `{np, jnp}.asarray(fragments)` to make arraylike things out of (full)
42+ Fragments.
3443"""
3544# TODO(b/465196209): Remove when support for Python 3.10 is dropped.
3645from __future__ import annotations
@@ -478,11 +487,26 @@ def _of(cls: type[FS], x: Any, *, indices: Sequence[Index]) -> FS:
478487 fragments = [cls .FRAGMENT_T (index = index ) for index in indices ]
479488 return cls (x .shape , x .dtype , fragments )
480489
490+ @classmethod
491+ def like (
492+ cls : type [FS ],
493+ fragments : _GenericFragments [Any ],
494+ value : Literal [None ] = None ,
495+ ) -> FS :
496+ del value
497+ return cls (
498+ shape = fragments .shape ,
499+ dtype = fragments .dtype ,
500+ fragments = [
501+ cls .FRAGMENT_T (index = f .index ) for f in fragments .fragments
502+ ],
503+ )
504+
481505
482506@dataclasses .dataclass (frozen = True , init = False )
483- class NpFragments (_GenericFragments [NpFragment ]):
484- """A collection of fragments whose values are of type `np.ndarray` ."""
485- FRAGMENT_T = NpFragment
507+ class _ConcreteFragments (_GenericFragments [Fconcrete ]):
508+ """A collection of concrete fragments ."""
509+ FRAGMENT_T : ClassVar [ type [ Fconcrete ]] # The type of fragment values.
486510
487511 @classmethod
488512 def _of (cls : type [FS ], x : Any , * , indices : Sequence [Index ]) -> FS :
@@ -491,19 +515,37 @@ def _of(cls: type[FS], x: Any, *, indices: Sequence[Index]) -> FS:
491515 fragments = [cls .FRAGMENT_T (index = i , value = x [i ]) for i in indices ]
492516 return cls (x .shape , x .dtype , fragments )
493517
518+ @classmethod
519+ def like (
520+ cls : type [FS ], fragments : _GenericFragments [Any ], value : Aconcrete
521+ ) -> FS :
522+ _check_fragment_value_type (value , cls .FRAGMENT_T .ARRAY_T )
523+ if fragments .shape != value .shape or fragments .dtype != value .dtype :
524+ raise ValueError (
525+ f'Fragments type { fragments .dtype } [{ fragments .shape } ] does'
526+ f' not match value type { value .dtype } [{ value .shape } ].'
527+ )
528+ return cls (
529+ shape = fragments .shape ,
530+ dtype = fragments .dtype ,
531+ fragments = [
532+ cls .FRAGMENT_T (index = f .index , value = value [f .index ])
533+ for f in fragments .fragments
534+ ],
535+ )
536+
494537
495538@dataclasses .dataclass (frozen = True , init = False )
496- class JaxFragments (_GenericFragments [JaxFragment ]):
539+ class NpFragments (_ConcreteFragments [NpFragment ]):
540+ """A collection of fragments whose values are of type `np.ndarray`."""
541+ FRAGMENT_T = NpFragment
542+
543+
544+ @dataclasses .dataclass (frozen = True , init = False )
545+ class JaxFragments (_ConcreteFragments [JaxFragment ]):
497546 """A collection of fragments whose values are of type `jax.Array`."""
498547 FRAGMENT_T = JaxFragment
499548
500- @classmethod
501- def _of (cls : type [FS ], x : Any , * , indices : Sequence [Index ]) -> FS :
502- """Returns a Fragments with one fragment for each index."""
503- _check_fragment_value_type (x , cls .FRAGMENT_T .ARRAY_T )
504- fragments = [cls .FRAGMENT_T (index = i , value = x [i ]) for i in indices ]
505- return cls (x .shape , x .dtype , fragments )
506-
507549
508550# Extra names for backwards compatibility. Most loading and saving code still
509551# wants to deal with NumPy arrays so that views and operations on them
0 commit comments