
Commit d3b39cf

Disable DSD and fix bitsandbytes test (#2314)
1 parent: d4465c8

2 files changed: +32 −27 lines


tests/torchtune/modules/low_precision/test_nf4_linear.py (+27 −23)
@@ -4,12 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-try:
-    import bitsandbytes as bnb
-
-    bnb_installed = True
-except ImportError:
-    bnb_installed = False
 import pytest
 import torch
 from torchao.dtypes.nf4tensor import NF4Tensor
@@ -22,19 +16,6 @@ def random():
     set_seed(31)
 
 
-def _build_bnb_linear(input_weight):
-    """
-    Builds a bnb.nn.LinearNF4 from a given input weight
-    """
-    param = bnb.nn.Params4bit(input_weight, requires_grad=False, quant_type="nf4")
-    bnb_linear = bnb.nn.LinearNF4(
-        input_weight.size(0), input_weight.size(1), bias=False
-    )
-    bnb_linear.weight = param
-    bnb_linear.cuda()
-    return bnb_linear
-
-
 class TestNF4Linear:
     """
     Class for testing our NF4Linear implementation.
@@ -88,18 +69,29 @@ def test_backward_dtype(self, dtype):
         assert inp.grad is not None and inp.grad.dtype == dtype
         assert nf4_linear.weight.grad is None
 
-    @pytest.mark.skipif(not bnb_installed, reason="bitsandbytes is not installed")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
     @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
     def test_nf4_reconstruction_vs_bnb(self, dtype):
         """
         Ensures a BNB NF4 linear and our FrozenNF4Linear have low error when
         reconstructing the respective original weights.
         """
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            pytest.skip("bitsandbytes is not installed")
+            return
+
         dim = 512
         nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=dtype)
         orig_weight = nf4_linear.weight.get_original_weight().clone().detach()
-        bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight)
+
+        param = bnb.nn.Params4bit(orig_weight, requires_grad=False, quant_type="nf4")
+        bnb_nf4_linear = bnb.nn.LinearNF4(
+            orig_weight.size(0), orig_weight.size(1), bias=False
+        )
+        bnb_nf4_linear.weight = param
+        bnb_nf4_linear.cuda()
 
         # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65
         bnb_reconstruction = bnb_nf4_linear(
@@ -110,18 +102,30 @@ def test_nf4_reconstruction_vs_bnb(self, dtype):
             bnb_reconstruction.T, nf4_linear.weight.get_original_weight(), 1e-2
         )
 
-    @pytest.mark.skipif(not bnb_installed, reason="bitsandbytes is not installed")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
     @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
     def test_nf4_bnb_linear(self, dtype):
         """
         This test ensures that nf4_linear is "no worse" than BNB by ensuring the
         error compared to a bf16 linear is not more than BNB's implementation.
         """
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            pytest.skip("bitsandbytes is not installed")
+            return
+
         dim = 512
         nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=dtype)
         orig_weight = nf4_linear.weight.get_original_weight().clone().detach()
-        bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight)
+
+        param = bnb.nn.Params4bit(orig_weight, requires_grad=False, quant_type="nf4")
+        bnb_nf4_linear = bnb.nn.LinearNF4(
+            orig_weight.size(0), orig_weight.size(1), bias=False
+        )
+        bnb_nf4_linear.weight = param
+        bnb_nf4_linear.cuda()
+
         bf16_linear = torch.nn.Linear(dim, dim, device="cuda", dtype=dtype)
 
         inp = torch.randn(2, 512, dtype=dtype, device="cuda")
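The change above moves the `bitsandbytes` import out of module scope and into the two tests that need it, replacing the module-level `bnb_installed` flag and its `skipif` marker with a runtime `pytest.skip`. The same pattern works for any optional test dependency; a minimal sketch, where `optional_dep` is a hypothetical package name:

```python
import pytest


def test_uses_optional_dep():
    # Import the optional dependency inside the test body so the module can
    # always be imported and collected; the test skips itself at runtime.
    try:
        import optional_dep  # hypothetical optional dependency
    except ImportError:
        pytest.skip("optional_dep is not installed")

    assert optional_dep.__name__ == "optional_dep"
```

Since `pytest.skip` raises an exception rather than returning, the `return` after it in the diff is never reached; it is a harmless defensive statement.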

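Both tests reconstruct the original weights from the quantized layer. The call that is cut off above, `bnb_reconstruction = bnb_nf4_linear(...)`, appears (per the linked transformer_nuggets test, and as the `.T` in the comparison suggests) to feed an identity matrix through the layer: a bias-free linear computes `y = x @ W.T`, so passing the identity returns `W.T`. A minimal sketch of the trick, with a plain `nn.Linear` standing in for the quantized layer:

```python
import torch

dim = 4
linear = torch.nn.Linear(dim, dim, bias=False)

# A bias-free linear computes y = x @ W.T, so feeding it the identity matrix
# returns W.T; transposing the output reconstructs the weight matrix.
reconstruction = linear(torch.eye(dim, dim))

torch.testing.assert_close(reconstruction.T, linear.weight)
```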
torchtune/training/_distributed.py (+5 −4)
@@ -36,17 +36,18 @@
 from torchtune.modules.peft import get_adapter_state_dict
 from torchtune.utils import get_device, get_logger
 from torchtune.utils._logging import deprecated
-from torchtune.utils._version import torch_version_ge
 
 _log: logging.Logger = get_logger()
 
 
 _valid_distributed_single_node_nnodes = ["1:1", "1"]
 
 torch_version = torch.__version__
-_DISTRIBUTED_STATE_DICT_API_IS_AVAILABLE = (
-    "dev" not in torch_version and torch_version_ge("2.6.0")
-) or ("dev" in torch_version and torch_version.split("dev")[1] >= "20241220")
+# TODO: Fix issues with DSD before uncommenting. See #2313 and #2277.
+# _DISTRIBUTED_STATE_DICT_API_IS_AVAILABLE = (
+#     "dev" not in torch_version and torch_version_ge("2.6.0")
+# ) or ("dev" in torch_version and torch_version.split("dev")[1] >= "20241220")
+_DISTRIBUTED_STATE_DICT_API_IS_AVAILABLE = False
 
 
 def _get_sharding_strategy(strategy: str) -> ShardingStrategy:
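For reference, the gate this commit disables distinguishes two cases: a release build of torch (no `dev` in the version string) must be at least 2.6.0, while a nightly build must carry a build date of 2024-12-20 or later, taken from the characters after `dev` in versions like `2.6.0.dev20241221`. A standalone sketch of that decision, using a naive string comparison where torchtune calls its `torch_version_ge` helper:

```python
import torch

torch_version = torch.__version__  # e.g. "2.6.0" or "2.6.0.dev20241221"

if "dev" not in torch_version:
    # Release build: require torch >= 2.6.0. A plain string compare stands in
    # here for torchtune's torch_version_ge helper.
    dsd_available = torch_version >= "2.6.0"
else:
    # Nightly build: the characters after "dev" are the build date; require a
    # nightly cut on or after 2024-12-20.
    dsd_available = torch_version.split("dev")[1] >= "20241220"

print(f"Distributed state-dict API available: {dsd_available}")
```

With this commit, that computation is bypassed entirely and the flag is pinned to `False` until the issues tracked in #2313 and #2277 are resolved.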
