File tree Expand file tree Collapse file tree 1 file changed +8
-2
lines changed
checkpoint/orbax/checkpoint/_src/multihost Expand file tree Collapse file tree 1 file changed +8
-2
lines changed Original file line number Diff line number Diff line change 1515"""Multislice utilities."""
1616
1717import functools
18+ import math
19+ import os
1820from typing import Any , Optional , Set , Union
1921
2022from absl import logging
@@ -139,6 +141,10 @@ def in_replica(
139141def get_device_memory () -> int :
140142 """Returns HBM capacity of the device on which the code is running(in bytes)."""
141143 device = jax .devices ()[0 ]
144+ if device .platform == 'cpu' :
145+ page_size = os .sysconf ('SC_PAGE_SIZE' )
146+ phys_pages = os .sysconf ('SC_PHYS_PAGES' )
147+ return int (page_size * phys_pages )
142148 if device .platform not in ('tpu' , 'gpu' ):
143149 raise ValueError ('Only select TPU and GPU devices are supported.' )
144150 hbm_memory = {
@@ -163,8 +169,8 @@ def get_device_memory() -> int:
163169
164170def get_leaf_memory_per_device (arr : jax .Array ) -> int :
165171 """Returns the memory usage of a sharded array per device (in bytes)."""
166- shard = arr .addressable_shards [ 0 ]
167- return shard . data . size * shard . data .itemsize
172+ shard_shape = arr .sharding . shard_shape ( arr . shape )
173+ return math . prod ( shard_shape ) * arr . dtype .itemsize
168174
169175
170176def tree_memory_per_device (tree : tuple [jax .Array , ...] | jax .Array ) -> int :
You can’t perform that action at this time.
0 commit comments