Commit daec61e

cpgaffney1 authored and Orbax Authors committed
Normalize device kind strings when looking up HBM memory.
This change removes spaces from device kind strings before looking up HBM memory values. This ensures that device kinds like "TPU 7x" and "TPU7x" are treated the same, preventing lookup failures due to inconsistent spacing.

PiperOrigin-RevId: 871477934
1 parent 46eea87 commit daec61e

2 files changed: +9 -1 lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -22,6 +22,11 @@ handlers using `StepMetadata.item_handlers` and the global `HandlerTypeRegistry`
 if no args are provided.
 - `CompositeCheckpointHandler.metadata()` now returns `StepMetadata`.
 
+### Fixed
+
+- Fixed `get_device_memory` issue on TPU 7x devices where the device kind string
+was consistently reported without a space, causing a ValueError.
+
 ## [0.1.7] - 2022-03-29
 
 ### Added

checkpoint/orbax/checkpoint/_src/multihost/multislice.py

Lines changed: 4 additions & 1 deletion
@@ -159,7 +159,10 @@ def get_device_memory() -> int:
       'NVIDIA B200': int(183e9),
       'NVIDIA B300 SXM6 AC': int(275e9),
   }
-  memory = hbm_memory.get(device.device_kind, None)
+  # Remove spaces from the device kind to make the lookup robust.
+  # For example, "TPU 7x" and "TPU7x" should both map to the same value.
+  normalized_hbm_memory = {k.replace(' ', ''): v for k, v in hbm_memory.items()}
+  memory = normalized_hbm_memory.get(device.device_kind.replace(' ', ''), None)
   if memory is None:
     raise ValueError(
         f'get_device_memory is not supported for {device.device_kind}.'
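
The effect of the normalization can be checked outside of Orbax with a small standalone sketch. The helper name `lookup_hbm_memory` and the table `EXAMPLE_HBM_MEMORY` are hypothetical and exist only for illustration; the two table entries are the ones visible in the diff context above, and the real function reads the device kind from a JAX device rather than taking it as an argument.

```python
# Standalone sketch of the space-insensitive lookup; not the Orbax implementation.
# `lookup_hbm_memory` and `EXAMPLE_HBM_MEMORY` are hypothetical names.

# Only the two entries visible in the diff context are reproduced here.
EXAMPLE_HBM_MEMORY = {
    'NVIDIA B200': int(183e9),
    'NVIDIA B300 SXM6 AC': int(275e9),
}


def lookup_hbm_memory(device_kind: str, hbm_memory: dict[str, int]) -> int:
  """Returns the HBM size for `device_kind`, ignoring spaces in the name."""
  # Strip spaces from both the table keys and the query so that
  # "NVIDIA B200" and "NVIDIAB200" resolve to the same entry.
  normalized = {k.replace(' ', ''): v for k, v in hbm_memory.items()}
  memory = normalized.get(device_kind.replace(' ', ''))
  if memory is None:
    raise ValueError(
        f'get_device_memory is not supported for {device_kind}.'
    )
  return memory


if __name__ == '__main__':
  # Both spellings resolve to the same entry.
  print(lookup_hbm_memory('NVIDIA B200', EXAMPLE_HBM_MEMORY))
  print(lookup_hbm_memory('NVIDIAB200', EXAMPLE_HBM_MEMORY))
```

Only ASCII spaces are removed, so differences in letter case or other whitespace would still miss. The sketch also rebuilds the normalized dictionary on every call, as the patched line in the diff does; for a table this small that cost is negligible.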