Skip to content

Commit 3fc617b

Browse files
YIWENX14facebook-github-bot
authored andcommitted
primitive scale fix
Differential Revision: D74446877
1 parent 554cb60 commit 3fc617b

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torchao/quantization/quant_primitives.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -948,7 +948,7 @@ def _choose_qparams_affine(
948948
scale = torch.clamp(scale, min=eps)
949949
else:
950950
assert mapping_type == MappingType.ASYMMETRIC.name
951-
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
951+
scale = (max_val_pos - min_val_neg) / torch.tensor([float(quant_max - quant_min)], dtype=input.dtype, device=input.device)
952952
scale = torch.clamp(scale, min=eps)
953953
if zero_point_domain == ZeroPointDomain.NONE.name:
954954
zero_point = None

0 commit comments

Comments
 (0)