File tree Expand file tree Collapse file tree 1 file changed +9
-0
lines changed
checkpoint/orbax/checkpoint/_src/multihost Expand file tree Collapse file tree 1 file changed +9
-0
lines changed Original file line number Diff line number Diff line change 1515"""Multislice utilities."""
1616
1717import functools
18+ import math
19+ import os
1820from typing import Any , Optional , Set , Union
1921
2022from absl import logging
@@ -139,6 +141,10 @@ def in_replica(
139141def get_device_memory () -> int :
140142 """Returns HBM capacity of the device on which the code is running(in bytes)."""
141143 device = jax .devices ()[0 ]
144+ if device .platform == 'cpu' :
145+ page_size = os .sysconf ('SC_PAGE_SIZE' )
146+ phys_pages = os .sysconf ('SC_PHYS_PAGES' )
147+ return int (page_size * phys_pages )
142148 if device .platform not in ('tpu' , 'gpu' ):
143149 raise ValueError ('Only select TPU and GPU devices are supported.' )
144150 hbm_memory = {
@@ -163,6 +169,9 @@ def get_device_memory() -> int:
163169
164170def get_leaf_memory_per_device (arr : jax .Array ) -> int :
165171 """Returns the memory usage of a sharded array per device (in bytes)."""
172+ if not arr .addressable_shards :
173+ shard_shape = arr .sharding .shard_shape (arr .shape )
174+ return math .prod (shard_shape ) * arr .dtype .itemsize
166175 shard = arr .addressable_shards [0 ]
167176 return shard .data .size * shard .data .itemsize
168177
You can’t perform that action at this time.
0 commit comments