Skip to content

Commit 8823ce5

Browse files
author
Orbax Authors
committed
Use device.memory_stats for accelerator HBM size.
PiperOrigin-RevId: 871943828
1 parent fb9aed3 commit 8823ce5

File tree

1 file changed

+4
-22
lines changed

1 file changed

+4
-22
lines changed

checkpoint/orbax/checkpoint/_src/multihost/multislice.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -142,34 +142,16 @@ def in_replica(
142142

143143
def get_device_memory() -> int:
144144
"""Returns HBM capacity of the device on which the code is running(in bytes)."""
145-
device = jax.devices()[0]
145+
device = jax.local_devices()[0]
146146
if device.platform == 'cpu':
147147
page_size = os.sysconf('SC_PAGE_SIZE')
148148
phys_pages = os.sysconf('SC_PHYS_PAGES')
149149
return int(page_size * phys_pages)
150+
150151
if device.platform not in ('tpu', 'gpu'):
151152
raise ValueError('Only select TPU and GPU devices are supported.')
152-
hbm_memory = {
153-
'TPU v3': int(16e9), # two cores per chip each with 16 GB HBM
154-
'TPU v4': int(32e9), # one megacore per chip with 32 GB HBM
155-
'TPU v5 lite': int(16e9), # one core per chip with 16 GB HBM
156-
'TPU v5': int(96e9), # one megacore per chip with 96 GB HBM
157-
'TPU v6 lite': int(32e9), # one core per chip with 32 GB HBM
158-
'TPU 7x': int(96e9), # two cores per chip each with 96 GB HBM
159-
'NVIDIA H100 80GB HBM3': int(80e9),
160-
'NVIDIA H200': int(144e9),
161-
'NVIDIA B200': int(183e9),
162-
'NVIDIA B300 SXM6 AC': int(275e9),
163-
}
164-
# Remove spaces from the device kind to make the lookup robust.
165-
# For example, "TPU 7x" and "TPU7x" should both map to the same value.
166-
normalized_hbm_memory = {k.replace(' ', ''): v for k, v in hbm_memory.items()}
167-
memory = normalized_hbm_memory.get(device.device_kind.replace(' ', ''), None)
168-
if memory is None:
169-
raise ValueError(
170-
f'get_device_memory is not supported for {device.device_kind}.'
171-
)
172-
return memory
153+
154+
return device.memory_stats()['bytes_limit']
173155

174156

175157
def get_leaf_memory_per_device(arr: jax.Array) -> int:

0 commit comments

Comments
 (0)