Skip to content

Commit 2520b09

Browse files
author
Orbax Authors
committed
#p2p Support CPU memory ops in multislice functions
PiperOrigin-RevId: 869818598
1 parent bc92c67 commit 2520b09

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

checkpoint/orbax/checkpoint/_src/multihost/multislice.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
"""Multislice utilities."""
1616

1717
import functools
18+
import math
19+
import os
1820
from typing import Any, Optional, Set, Union
1921

2022
from absl import logging
@@ -139,6 +141,10 @@ def in_replica(
139141
def get_device_memory() -> int:
140142
"""Returns HBM capacity of the device on which the code is running(in bytes)."""
141143
device = jax.devices()[0]
144+
if device.platform == 'cpu':
145+
page_size = os.sysconf('SC_PAGE_SIZE')
146+
phys_pages = os.sysconf('SC_PHYS_PAGES')
147+
return int(page_size * phys_pages)
142148
if device.platform not in ('tpu', 'gpu'):
143149
raise ValueError('Only select TPU and GPU devices are supported.')
144150
hbm_memory = {
@@ -163,8 +169,8 @@ def get_device_memory() -> int:
163169

164170
def get_leaf_memory_per_device(arr: jax.Array) -> int:
165171
"""Returns the memory usage of a sharded array per device (in bytes)."""
166-
shard = arr.addressable_shards[0]
167-
return shard.data.size * shard.data.itemsize
172+
shard_shape = arr.sharding.shard_shape(arr.shape)
173+
return math.prod(shard_shape) * arr.dtype.itemsize
168174

169175

170176
def tree_memory_per_device(tree: tuple[jax.Array, ...] | jax.Array) -> int:

0 commit comments

Comments
 (0)