[optim] Fix bug when default dtype is BF16 #2286

Merged: 2 commits, Jun 4, 2025

29 changes: 24 additions & 5 deletions test/test_low_bit_optim.py
@@ -37,7 +37,6 @@
 from torchao.optim.subclass_fp8 import OptimStateFp8
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_7,
     get_available_devices,
@@ -128,8 +127,6 @@ class TestOptim(TestCase):
     @skip_if_rocm("ROCm enablement in progress")
     def test_optim_smoke(self, optim_name, dtype, device):
         if optim_name.endswith("Fp8") and device == "cuda":
-            if not TORCH_VERSION_AT_LEAST_2_4:
-                pytest.skip("FP8 CUDA requires PyTorch >= 2.4")
             if torch.cuda.get_device_capability() < (8, 9):
                 pytest.skip("FP8 CUDA requires compute capability >= 8.9")
 
@@ -166,6 +163,30 @@ def test_optim_smoke(self, optim_name, dtype, device):
         for p1, p2 in zip(model.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1)
 
+    @parametrize("optim_name", ["Adam8bit", "Adam4bit", "AdamFp8"])
+    @parametrize("device", _DEVICES)
+    def test_optim_default_dtype_bf16(self, optim_name, device):
+        if optim_name.endswith("Fp8") and device == "cuda":
+            if torch.cuda.get_device_capability() < (8, 9):
+                pytest.skip("FP8 CUDA requires compute capability >= 8.9")
+
+        old_dtype = torch.get_default_dtype()
+        torch.set_default_dtype(torch.bfloat16)
+
+        try:
+            model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32))
+            model.to(device=device)
+            optimizer = getattr(optim, optim_name)(model.parameters())
+
+            x = torch.randn(4, 32, device=device)
+            loss = model(x).sum()
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+
+        finally:
+            torch.set_default_dtype(old_dtype)
+
     # aten.slice is required for dcp.load() when world size changes i.e. re-sharding
     # however, it's cumbersome to test it directly, since we would need to run distributed
     # test 2 times with different world size, and persist checkpoint across the 2 runs.
@@ -178,8 +199,6 @@ def test_subclass_slice(self, subclass, shape, device):
         if subclass == OptimStateFp8:
             if device == "cpu" and len(shape) > 1 and not TORCH_VERSION_AT_LEAST_2_5:
                 pytest.skip("fill_cpu not implemented for Float8_e4m3fn for torch<2.5")
-            if device == "cuda" and not TORCH_VERSION_AT_LEAST_2_4:
-                pytest.skip("FP8 CUDA requires PyTorch >= 2.4")
             if device == "cuda" and torch.cuda.get_device_capability() < (8, 9):
                 pytest.skip("FP8 CUDA requires compute capability >= 8.9")
 
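For context: the root cause is that torch.tensor() infers its dtype from the global default when given Python floats, so the qmap lookup tables silently became bf16 once a user called torch.set_default_dtype(torch.bfloat16). A minimal standalone repro, not part of this PR:

    import torch

    # Under the stock fp32 default, torch.tensor() on Python floats gives fp32.
    qmap = torch.tensor([-1.0, 0.0, 1.0])
    assert qmap.dtype is torch.float32

    # After a user flips the global default, the same call returns bf16,
    # which is what broke the low-bit optim states before this fix.
    torch.set_default_dtype(torch.bfloat16)
    qmap = torch.tensor([-1.0, 0.0, 1.0])
    assert qmap.dtype is torch.bfloat16

    # The approach taken in this PR: pin the qmap dtype explicitly.
    qmap = torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float32)
    assert qmap.dtype is torch.float32

    torch.set_default_dtype(torch.float32)  # restore the default
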
6 changes: 3 additions & 3 deletions torchao/optim/subclass_4bit.py
@@ -69,6 +69,7 @@ def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape):
         assert codes.dtype is torch.uint8
         assert codes.ndim == 1  # flattened buffer
         assert scale.ndim == 1
+        assert qmap.dtype is torch.float32
         self.codes = codes
         self.scale = scale
         self.qmap = qmap
@@ -101,9 +102,8 @@ def zeros(cls, shape, signed: bool = True, block_size: int = 128, device=None):
 
         codes = torch.zeros(n_elems // 2, dtype=torch.uint8, device=device)
         scale = torch.zeros(n_elems // block_size, device=device)
-        qmap = torch.tensor(
-            get_qmap_signed() if signed else get_qmap_unsigned(), device=device
-        )
+        qmap_list = get_qmap_signed() if signed else get_qmap_unsigned()
+        qmap = torch.tensor(qmap_list, dtype=torch.float32, device=device)
         return cls(codes, scale, qmap, signed, shape)
 
     def __repr__(self):
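
A quick sanity check of the new constructor behavior; a sketch assuming the subclass is exported as OptimState4bit (the Fp8 counterpart is imported that way in the test file above):

    import torch
    from torchao.optim.subclass_4bit import OptimState4bit

    torch.set_default_dtype(torch.bfloat16)
    try:
        # 64 * 64 = 4096 elements, divisible by the default block_size of 128.
        state = OptimState4bit.zeros((64, 64))
        # With the dtype pinned inside zeros(), the new __init__ assert
        # holds even while the global default is bf16.
        assert state.qmap.dtype is torch.float32
    finally:
        torch.set_default_dtype(torch.float32)
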
6 changes: 3 additions & 3 deletions torchao/optim/subclass_8bit.py
@@ -62,6 +62,7 @@ def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool):
         """
         assert codes.dtype is torch.uint8
         assert scale.ndim == 1
+        assert qmap.dtype is torch.float32
         self.codes = codes
         self.scale = scale
         self.qmap = qmap
@@ -89,9 +90,8 @@ def dequantize(self, output_dtype=None):
     def zeros(cls, shape, signed: bool = True, block_size: int = 256, device=None):
         codes = torch.zeros(shape, dtype=torch.uint8, device=device)
         scale = torch.zeros(codes.numel() // block_size, device=device)
-        qmap = torch.tensor(
-            get_qmap_signed() if signed else get_qmap_unsigned(), device=device
-        )
+        qmap_list = get_qmap_signed() if signed else get_qmap_unsigned()
+        qmap = torch.tensor(qmap_list, dtype=torch.float32, device=device)
         return cls(codes, scale, qmap, signed)
 
     def __repr__(self):
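
End to end, the user-visible effect is that the low-bit optimizers can now be built and stepped while a bf16 default dtype is active. A sketch mirroring the new test, assuming torchao.optim exposes Adam8bit as the test's getattr(optim, optim_name) implies:

    import torch
    import torch.nn as nn
    from torchao import optim

    device = "cuda" if torch.cuda.is_available() else "cpu"  # the test covers both

    torch.set_default_dtype(torch.bfloat16)
    try:
        model = nn.Linear(128, 128).to(device)  # bf16 weights via the default
        opt = optim.Adam8bit(model.parameters())

        model(torch.randn(8, 128, device=device)).sum().backward()
        opt.step()       # per the PR title, this path hit the default-dtype bug before
        opt.zero_grad()
    finally:
        torch.set_default_dtype(torch.float32)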