
Commit c370348

Remove old torchao imports, require 0.7.0+ (#2513)
1 parent 1241231 · commit c370348

3 files changed: +14 −39 lines changed

tests/recipes/test_qat_lora_finetune_distributed.py (−4)

@@ -35,7 +35,6 @@
     safe_torch_load,
     SHARD_FNAME,
 )
-from torchtune.training.quantization import _torchao_0_7_supported


 class TestQATLoRAFinetuneDistributedRecipe:
@@ -63,7 +62,6 @@ def _fetch_expected_loss_values(self, model_type):
         "micro_batch_size, gradient_accumulation_steps, should_compile",
         [(4, 1, True), (1, 4, False)],
     )
-    @pytest.mark.skipif(not _torchao_0_7_supported, reason="needs torchao 0.7+")
     def test_loss(
         self,
         micro_batch_size,
@@ -116,7 +114,6 @@ def test_loss(
             ("llama3/8B_qat_lora", "llama3", "tune", False),
         ],
     )
-    @pytest.mark.skipif(not _torchao_0_7_supported, reason="needs torchao 0.7+")
     def test_training_state_on_resume(
         self,
         config,
@@ -217,7 +214,6 @@ def test_training_state_on_resume(
         ],
     )
     @gpu_test(gpu_count=2)
-    @pytest.mark.skipif(not _torchao_0_7_supported, reason="needs torchao 0.7+")
     def test_save_and_load_merged_weights(
         self, recipe_config, model_type, ckpt_type, tmpdir, monkeypatch
     ):
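All four removed skipif guards keyed off the deleted `_torchao_0_7_supported` flag. For illustration only, a hedged sketch of how such a gate could be written against the installed torchao version directly, without a torchtune feature flag. The `requires_torchao_0_7` helper is hypothetical and not part of this commit, which instead drops the guards and relies on the import-time check in torchtune/training/quantization.py:

    import pytest
    import torchao
    from packaging.version import Version

    # Hypothetical guard (not in this commit): skip tests instead of
    # hard-failing when the installed torchao predates 0.7.0.
    requires_torchao_0_7 = pytest.mark.skipif(
        Version(torchao.__version__) < Version("0.7.0"),
        reason="needs torchao 0.7+",
    )

    @requires_torchao_0_7
    def test_example():
        ...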

tests/torchtune/modules/peft/test_lora.py (−2)

@@ -15,7 +15,6 @@
 from torchtune import training
 from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook
 from torchtune.modules.peft import LoRALinear, QATLoRALinear
-from torchtune.training.quantization import _torchao_0_7_supported
 from torchtune.training.seed import set_seed


@@ -237,7 +236,6 @@ def test_quantized_state_dict(self, dtype):
             lora_linear.weight.quantized_data, lora_linear_reload.weight.quantized_data
         )

-    @pytest.mark.skipif(not _torchao_0_7_supported, reason="needs torchao 0.7+")
     def test_qat_lora_forward(self, inputs, lora_linear, out_dim) -> None:
         lora_linear = lora_linear(use_bias=True, dtype=torch.float32)
         qat_lora_linear = QATLoRALinear.from_lora_linear(lora_linear)
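For context, a minimal sketch of the pattern `test_qat_lora_forward` exercises: wrapping an existing LoRALinear so its base weight is fake-quantized during the forward pass. The dimensions and constructor arguments below are illustrative assumptions; the actual test builds its modules from fixtures:

    import torch
    from torchtune.modules.peft import LoRALinear, QATLoRALinear

    # Illustrative shapes; rank/alpha values are placeholders.
    lora_linear = LoRALinear(in_dim=64, out_dim=64, rank=4, alpha=1.0, use_bias=True)
    qat_lora_linear = QATLoRALinear.from_lora_linear(lora_linear)

    x = torch.randn(2, 64)
    out = qat_lora_linear(x)  # forward with fake-quantized base weight
    print(out.shape)          # torch.Size([2, 64])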

torchtune/training/quantization.py (+14 −33)

@@ -7,44 +7,26 @@
 from typing import Callable, Optional

 from torch import nn
-from torchtune.modules.peft.lora import LoRALinear, QATLoRALinear
-

-try:
-    # torchao 0.7+
-    from torchao.dtypes import TensorCoreTiledLayout
-except ImportError:
-    # torchao 0.6 and before
-    from torchao.dtypes import TensorCoreTiledLayoutType as TensorCoreTiledLayout
+from torchao.dtypes import TensorCoreTiledLayout

 from torchao.quantization import (
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     quantize_,
 )

-try:
-    # torchao 0.7+
-    from torchao.quantization.qat import (
-        Int4WeightOnlyQATQuantizer,
-        Int8DynActInt4WeightQATQuantizer,
-    )
-    from torchao.quantization.qat.linear import (
-        disable_4w_fake_quant,
-        disable_8da4w_fake_quant,
-        enable_4w_fake_quant,
-        enable_8da4w_fake_quant,
-    )
-except ImportError:
-    # torchao 0.6 and before
-    from torchao.quantization.prototype.qat import (
-        disable_4w_fake_quant,
-        disable_8da4w_fake_quant,
-        enable_4w_fake_quant,
-        enable_8da4w_fake_quant,
-        Int4WeightOnlyQATQuantizer,
-        Int8DynActInt4WeightQATQuantizer,
-    )
+from torchao.quantization.qat import (
+    Int4WeightOnlyQATQuantizer,
+    Int8DynActInt4WeightQATQuantizer,
+)
+from torchao.quantization.qat.linear import (
+    disable_4w_fake_quant,
+    disable_8da4w_fake_quant,
+    enable_4w_fake_quant,
+    enable_8da4w_fake_quant,
+)
+from torchtune.modules.peft.lora import LoRALinear, QATLoRALinear


 __all__ = [
@@ -58,11 +40,10 @@
 ]


-_torchao_0_7_supported = True
 try:
     from torchao.quantization import qat  # noqa: F401
-except ImportError:
-    _torchao_0_7_supported = False
+except ImportError as e:
+    raise ValueError("Need torchao version 0.7.0+") from e

 _quantizer_to_mode = {}
 _quantizer_mode_to_disable_fake_quant = {}
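The net effect of this file's change: instead of exporting a `_torchao_0_7_supported` flag for callers to check, the module fails at import time on an old torchao. A hedged sketch of the new behavior with torchao < 0.7.0 installed — depending on which import trips first, the failure surfaces as an ImportError from the direct torchao imports or as the guard's explicit ValueError:

    try:
        import torchtune.training.quantization  # noqa: F401
    except (ImportError, ValueError) as e:
        print(f"torchao too old: {e}")  # e.g. "Need torchao version 0.7.0+"

And a minimal sketch of how the now unconditionally imported QAT APIs are typically used, following torchao's standard prepare/convert flow; the model and shapes below are placeholders, not taken from this commit:

    from torch import nn
    from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
    from torchao.quantization.qat.linear import (
        disable_8da4w_fake_quant,
        enable_8da4w_fake_quant,
    )

    model = nn.Sequential(nn.Linear(256, 256))
    quantizer = Int8DynActInt4WeightQATQuantizer()
    model = quantizer.prepare(model)        # swap in fake-quantized linear modules
    model.apply(disable_8da4w_fake_quant)   # e.g. warm-up steps without fake quant
    model.apply(enable_8da4w_fake_quant)    # re-enable before QAT fine-tuning
    # ... training loop ...
    model = quantizer.convert(model)        # lower to actually quantized weights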

0 commit comments