Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions src/ATen/native/quantized/FusedObsFakeQuant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,15 @@ std::tuple<at::Tensor, at::Tensor> fused_moving_avg_obs_fake_quant_xpu(
const int64_t ch_axis,
bool per_row_fq,
bool symmetric_quant) {
const auto x_dim = x.dim();
TORCH_CHECK(
ch_axis < x.dim(),
"Error in fused_moving_avg_obs_fq_helper: ch_axis must be < "
"self.dim()");
ch_axis >= -x_dim && ch_axis < x_dim,
"Error in fused_moving_avg_obs_fq_helper: ch_axis ",
ch_axis,
" is out of range for tensor with ",
x_dim,
" dimensions");
const auto wrapped_ch_axis = ch_axis < 0 ? ch_axis + x_dim : ch_axis;

const auto x_contig = x.contiguous();
// Calculate the size of the dimension we need to quantize over,
Expand All @@ -55,13 +60,13 @@ std::tuple<at::Tensor, at::Tensor> fused_moving_avg_obs_fake_quant_xpu(
if (x.dim() != 2) {
auto res = DimVector(x.sizes());
std::iota(res.begin(), res.end(), 0);
res[ch_axis] = 0;
res[0] = ch_axis;
res[wrapped_ch_axis] = 0;
res[0] = wrapped_ch_axis;

y = x.permute(res);
y = y.flatten(1);
}
size = x.size(ch_axis);
size = x.size(wrapped_ch_axis);
if (running_min.numel() == 0) {
running_min.resize_(size).fill_(at::numeric_limits<float>::upper_bound());
running_max.resize_(size).fill_(at::numeric_limits<float>::lower_bound());
Expand Down
70 changes: 70 additions & 0 deletions test/repro/test_fused_obs_fake_quant_ch_axis_bounds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2020-2026 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0

# Owner(s): ["module: intel"]

import unittest

import torch
from torch.testing._internal.common_utils import run_tests, TestCase


def _make_inputs(device):
x = torch.randn(2, 3, 4, device=device)
observer_on = torch.tensor([1], dtype=torch.long, device=device)
fake_quant_on = torch.tensor([1], dtype=torch.long, device=device)
running_min = torch.tensor([], dtype=torch.float, device=device)
running_max = torch.tensor([], dtype=torch.float, device=device)
scale = torch.tensor([], dtype=torch.float, device=device)
zero_point = torch.tensor([], dtype=torch.int32, device=device)
return x, observer_on, fake_quant_on, running_min, running_max, scale, zero_point


@unittest.skipIf(not torch.xpu.is_available(), "XPU not available")
class TestFusedObsFakeQuantChAxisBounds(TestCase):
    """Checks how ``torch._fused_moving_avg_obs_fq_helper`` treats ``ch_axis`` on XPU.

    Every test feeds the 3-D input from ``_make_inputs`` with per-row fake
    quantization enabled. Out-of-range axes must raise a ``RuntimeError``
    whose message contains "out of range"; ``-1`` must be accepted and
    resolved to the last dimension.
    """

    def _run_helper(self, args, ch_axis):
        """Invoke the fused helper with the fixed settings used by all tests.

        Args:
            args: Positional tensors as produced by ``_make_inputs``.
            ch_axis: Channel axis under test.

        Returns:
            Whatever ``torch._fused_moving_avg_obs_fq_helper`` returns.
        """
        return torch._fused_moving_avg_obs_fq_helper(
            *args,
            averaging_const=0.01,
            quant_min=0,
            quant_max=255,
            ch_axis=ch_axis,
            per_row_fake_quant=True,
            symmetric_quant=False,
        )

    def test_large_negative_ch_axis_raises(self):
        # Far below -x.dim(): must be rejected instead of being used as an index.
        args = _make_inputs("xpu")
        with self.assertRaisesRegex(RuntimeError, "out of range"):
            self._run_helper(args, -1250999896764)

    def test_negative_one_ch_axis_wraps(self):
        args = _make_inputs("xpu")
        x, running_min, running_max = args[0], args[3], args[4]

        self._run_helper(args, -1)

        # The running-stat buffers start empty and the op resizes them in
        # place to the length of the quantized axis. Checking that length
        # verifies that -1 resolved to the LAST dimension, not merely that the
        # call did not raise.
        expected_channels = x.size(-1)
        self.assertEqual(running_min.numel(), expected_channels)
        self.assertEqual(running_max.numel(), expected_channels)

    def test_positive_out_of_range_ch_axis_raises(self):
        # ch_axis == x.dim() is the first invalid non-negative axis.
        args = _make_inputs("xpu")
        with self.assertRaisesRegex(RuntimeError, "out of range"):
            self._run_helper(args, 3)


# Run the tests through PyTorch's test runner when this file is executed
# directly as a script.
if __name__ == "__main__":
    run_tests()
Loading