@@ -142,34 +142,16 @@ def in_replica(
142142
143143def get_device_memory () -> int :
144144 """Returns HBM capacity of the device on which the code is running(in bytes)."""
145- device = jax .devices ()[0 ]
145+ device = jax .local_devices ()[0 ]
146146 if device .platform == 'cpu' :
147147 page_size = os .sysconf ('SC_PAGE_SIZE' )
148148 phys_pages = os .sysconf ('SC_PHYS_PAGES' )
149149 return int (page_size * phys_pages )
150+
150151 if device .platform not in ('tpu' , 'gpu' ):
151152 raise ValueError ('Only select TPU and GPU devices are supported.' )
152- hbm_memory = {
153- 'TPU v3' : int (16e9 ), # two cores per chip each with 16 GB HBM
154- 'TPU v4' : int (32e9 ), # one megacore per chip with 32 GB HBM
155- 'TPU v5 lite' : int (16e9 ), # one core per chip with 16 GB HBM
156- 'TPU v5' : int (96e9 ), # one megacore per chip with 96 GB HBM
157- 'TPU v6 lite' : int (32e9 ), # one core per chip with 32 GB HBM
158- 'TPU 7x' : int (96e9 ), # two cores per chip each with 96 GB HBM
159- 'NVIDIA H100 80GB HBM3' : int (80e9 ),
160- 'NVIDIA H200' : int (144e9 ),
161- 'NVIDIA B200' : int (183e9 ),
162- 'NVIDIA B300 SXM6 AC' : int (275e9 ),
163- }
164- # Remove spaces from the device kind to make the lookup robust.
165- # For example, "TPU 7x" and "TPU7x" should both map to the same value.
166- normalized_hbm_memory = {k .replace (' ' , '' ): v for k , v in hbm_memory .items ()}
167- memory = normalized_hbm_memory .get (device .device_kind .replace (' ' , '' ), None )
168- if memory is None :
169- raise ValueError (
170- f'get_device_memory is not supported for { device .device_kind } .'
171- )
172- return memory
153+
154+ return device .memory_stats ()['bytes_limit' ]
173155
174156
175157def get_leaf_memory_per_device (arr : jax .Array ) -> int :
0 commit comments