@@ -1133,62 +1133,6 @@ def embedding_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
         baseline_out = embedding_forward_4w(x2, fq_embedding.weight)
         torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)
 
-    @unittest.skipIf(
-        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
-    )
-    def test_qat_prototype_bc(self):
-        """
-        Just to make sure we can import all the old prototype paths.
-        We will remove this test in the near future when we actually break BC.
-        """
-        from torchao.quantization.prototype.qat import (  # noqa: F401, F811, I001
-            disable_4w_fake_quant,
-            disable_8da4w_fake_quant,
-            enable_4w_fake_quant,
-            enable_8da4w_fake_quant,
-            ComposableQATQuantizer,
-            Int8DynActInt4WeightQATLinear,
-            Int4WeightOnlyEmbeddingQATQuantizer,
-            Int4WeightOnlyQATQuantizer,
-            Int8DynActInt4WeightQATQuantizer,
-        )
-        from torchao.quantization.prototype.qat._module_swap_api import (  # noqa: F401, F811
-            disable_4w_fake_quant_module_swap,
-            enable_4w_fake_quant_module_swap,
-            disable_8da4w_fake_quant_module_swap,
-            enable_8da4w_fake_quant_module_swap,
-            Int4WeightOnlyQATQuantizerModuleSwap,
-            Int8DynActInt4WeightQATQuantizerModuleSwap,
-        )
-        from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (  # noqa: F401, F811
-            AffineFakeQuantizedTensor,
-            to_affine_fake_quantized,
-        )
-        from torchao.quantization.prototype.qat.api import (  # noqa: F401, F811
-            ComposableQATQuantizer,
-            FakeQuantizeConfig,
-        )
-        from torchao.quantization.prototype.qat.embedding import (  # noqa: F401, F811
-            FakeQuantizedEmbedding,
-            Int4WeightOnlyEmbeddingQATQuantizer,
-            Int4WeightOnlyEmbedding,
-            Int4WeightOnlyQATEmbedding,
-        )
-        from torchao.quantization.prototype.qat.fake_quantizer import (  # noqa: F401, F811
-            FakeQuantizer,
-        )
-        from torchao.quantization.prototype.qat.linear import (  # noqa: F401, F811
-            disable_4w_fake_quant,
-            disable_8da4w_fake_quant,
-            enable_4w_fake_quant,
-            enable_8da4w_fake_quant,
-            FakeQuantizedLinear,
-            Int4WeightOnlyQATLinear,
-            Int4WeightOnlyQATQuantizer,
-            Int8DynActInt4WeightQATLinear,
-            Int8DynActInt4WeightQATQuantizer,
-        )
-
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
    )
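
The test removed above only verified that the deprecated torchao.quantization.prototype.qat import paths still resolve. As a reference-only sketch (assuming the graduated, non-prototype module is torchao.quantization.qat, which this diff does not itself state), the corresponding imports against the current paths would look like:

    # Sketch only: assumes the QAT quantizers are exposed under
    # torchao.quantization.qat after the prototype paths are dropped;
    # names mirror those imported by the deleted BC test.
    from torchao.quantization.qat import (
        ComposableQATQuantizer,
        Int4WeightOnlyEmbeddingQATQuantizer,
        Int4WeightOnlyQATQuantizer,
        Int8DynActInt4WeightQATQuantizer,
    )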