Skip to content

Commit b1df975

Browse files
author
Orbax Authors
committed
#p2p Support CPU memory ops in multislice functions
PiperOrigin-RevId: 869791115
1 parent bc92c67 commit b1df975

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

checkpoint/orbax/checkpoint/_src/multihost/multislice.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
"""Multislice utilities."""
1616

1717
import functools
18+
import math
19+
import os
1820
from typing import Any, Optional, Set, Union
1921

2022
from absl import logging
@@ -139,6 +141,10 @@ def in_replica(
139141
def get_device_memory() -> int:
140142
"""Returns HBM capacity of the device on which the code is running(in bytes)."""
141143
device = jax.devices()[0]
144+
if device.platform == 'cpu':
145+
page_size = os.sysconf('SC_PAGE_SIZE')
146+
phys_pages = os.sysconf('SC_PHYS_PAGES')
147+
return int(page_size * phys_pages)
142148
if device.platform not in ('tpu', 'gpu'):
143149
raise ValueError('Only select TPU and GPU devices are supported.')
144150
hbm_memory = {
@@ -163,6 +169,9 @@ def get_device_memory() -> int:
163169

164170
def get_leaf_memory_per_device(arr: jax.Array) -> int:
165171
"""Returns the memory usage of a sharded array per device (in bytes)."""
172+
if not arr.addressable_shards:
173+
shard_shape = arr.sharding.shard_shape(arr.shape)
174+
return math.prod(shard_shape) * arr.dtype.itemsize
166175
shard = arr.addressable_shards[0]
167176
return shard.data.size * shard.data.itemsize
168177

0 commit comments

Comments
 (0)