Skip to content

Commit 0d176be

Browse files
dzdang authored and pytorchmergebot committed
[quant][improvement] Added quantized fill test for per channel quantized tensors
Summary: Previously, the quantization test suite only tested the fill operator for per-tensor quantized tensors. This PR adds a test case for per-channel quantized tensors. The existing `test_qtensor_fill` case, which the newly introduced test function `test_qtensor_fill_per_channel` is based on, had some ambiguous naming conventions. This PR renames some of those variables to be clearer. Test Plan: ``` python test/test_quantization.py test_qtensor_fill_per_channel ``` Pull Request resolved: pytorch#78661 Approved by: https://github.com/jerryzh168
1 parent 9da5def commit 0d176be

File tree

1 file changed

+41
-13
lines changed

1 file changed

+41
-13
lines changed

test/quantization/core/test_quantized_tensor.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -927,35 +927,63 @@ def test_clone(self):
927927
# Check to make sure the scale and zero_point has been copied.
928928
self.assertEqual(q, q2)
929929

930-
def test_qtensor_fill(self):
930+
def test_qtensor_fill_per_tensor(self):
931931
numel = 10
932932
scale = 0.5
933933
zero_point = 10
934934

935935
ones = torch.ones(numel).to(torch.float)
936936

937-
types = [torch.qint8, torch.quint8, torch.qint32]
938-
fills = [-1, 1, 2**32] # positive, negative, overflow
937+
qtypes = [torch.qint8, torch.quint8, torch.qint32]
938+
vals2fill = [-1, 1, 2**32] # positive, negative, overflow
939939

940940
# `fill_` uses `copy_(float)`, which doesn't support CUDA
941941
device = 'cpu'
942-
ones = ones.to(device)
943-
for qtype, fill_with in itertools.product(types, fills):
942+
for qtype, val2fill in itertools.product(qtypes, vals2fill):
944943
q_filled = torch._empty_affine_quantized(
945944
[numel], scale=scale, zero_point=zero_point, device=device,
946945
dtype=qtype)
947-
q_filled.fill_(fill_with)
948-
int_repr = torch.quantize_per_tensor(ones * fill_with, scale,
949-
zero_point, qtype)
950-
fill_with = int_repr.dequantize()
951-
int_repr = int_repr.int_repr()
952-
953-
self.assertEqual(q_filled.int_repr(), int_repr)
954-
self.assertEqual(q_filled.dequantize(), fill_with)
946+
q_filled.fill_(val2fill)
947+
# reference tensor for comparing q_filled
948+
q_ref = torch.quantize_per_tensor(ones * val2fill, scale,
949+
zero_point, qtype)
950+
self.assertEqual(q_filled.int_repr(), q_ref.int_repr())
951+
self.assertEqual(q_filled.dequantize(), q_ref.dequantize())
955952
# Make sure the scale and zero_point don't change
956953
self.assertEqual(q_filled.q_scale(), scale)
957954
self.assertEqual(q_filled.q_zero_point(), zero_point)
958955

956+
# adapted from test_qtensor_fill_per_tensor
957+
def test_qtensor_fill_per_channel(self):
958+
dims = [4, 5]
959+
axis = 0
960+
# adding a constant to avoid too small of a scale
961+
scales = torch.rand(dims[axis], dtype=torch.float64) + 0.1
962+
zero_points = torch.randint(low=0, high=10, size=(dims[axis], ))
963+
964+
ones = torch.ones(dims).to(torch.float)
965+
966+
qtypes = [torch.qint8, torch.quint8, torch.qint32]
967+
vals2fill = [-1, 1, 2**32] # positive, negative, overflow
968+
969+
devices = ["cpu", "cuda"] if TEST_CUDA else ["cpu"]
970+
for qtype, val2fill, device in itertools.product(qtypes, vals2fill, devices):
971+
scales = scales.to(device)
972+
zero_points = zero_points.to(device)
973+
ones = ones.to(device)
974+
q_filled = torch._empty_per_channel_affine_quantized(
975+
dims, scales=scales, zero_points=zero_points, device=device,
976+
axis=axis, dtype=qtype)
977+
q_filled.fill_(val2fill)
978+
# reference tensor for comparing q_filled
979+
q_ref = torch.quantize_per_channel(ones * val2fill, scales=scales,
980+
zero_points=zero_points, axis=axis, dtype=qtype)
981+
self.assertEqual(q_filled.int_repr(), q_ref.int_repr())
982+
self.assertEqual(q_filled.dequantize(), q_ref.dequantize())
983+
# Make sure the scale and zero_point don't change
984+
self.assertEqual(q_filled.q_per_channel_scales(), scales)
985+
self.assertEqual(q_filled.q_per_channel_zero_points(), zero_points)
986+
959987
    @unittest.skipIf(not TEST_CUDA, "No gpu is available.")
    def test_qtensor_index_select_cuda(self):
        # Thin wrapper: runs the shared index_select checks (defined
        # elsewhere in this class) against the CUDA backend.
        self._test_qtensor_index_select('cuda')

0 commit comments

Comments
 (0)