Skip to content

Commit bfc6963

Browse files
author
Orbax Authors
committed
Internal change
PiperOrigin-RevId: 738393752
1 parent 8926361 commit bfc6963

File tree

1 file changed

+20
-12
lines changed

1 file changed

+20
-12
lines changed

checkpoint/orbax/checkpoint/_src/arrays/fragments.py

+20-12
Original file line numberDiff line numberDiff line change
@@ -178,18 +178,26 @@ def slice(
178178
stop = out.stop[:] = np.minimum(out.stop, slice_shape)
179179
if not (start < stop).all():
180180
return None
181-
if (value := self.value) is None:
182-
return out
183-
else:
184-
value_fragment = Fragment(
185-
np_index=np.stack([
186-
np.maximum(self.start, np_index[:, 0]),
187-
np.minimum(self.stop, np_index[:, 1]),
188-
np_index[:, 2],
189-
], axis=1)
190-
).offset_by(-self.start)
191-
out_value = value[value_fragment.index or ...]
192-
return dataclasses.replace(out, value=out_value)
181+
return dataclasses.replace(
182+
out, value=self.slice_of_value(np_index)
183+
) if self.value is not None else out
184+
185+
def slice_of_value(
186+
self,
187+
new_np_idx: NpIndex,
188+
) -> np.ndarray:
189+
"""Returns a slice of `value`."""
190+
start = self.start
191+
stop = self.stop
192+
# This is just a convenient way to construct the required tuple of slices.
193+
f = Fragment(
194+
np_index=np.stack([
195+
np.maximum(start, new_np_idx[:, 0]),
196+
np.minimum(stop, new_np_idx[:, 1]),
197+
new_np_idx[:, 2],
198+
], axis=1)
199+
).offset_by(-start)
200+
return self.value[f.index or ...]
193201

194202

195203
@dataclasses.dataclass(frozen=True)

0 commit comments

Comments
 (0)