Skip to content

Commit 8db9f39

Browse files
authored
Revert 2 PRs for quant fix (#484)
Revert #480 #481
1 parent 05f91e8 commit 8db9f39

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

python/hidet/graph/ops/quant/symmetric.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from typing import Union, List
1313
from hidet import ir
1414
from hidet.ir.type import DataType
15-
from hidet.ir.dtypes import int32
1615
from hidet.ir.expr import cast, if_then_else
1716
from hidet.ir.compute.primitives import TensorNode, compute
1817
from hidet.ir import primitives as prim
@@ -37,9 +36,7 @@ def __init__(self, w: TensorNode, quant_type: DataType, dims: Union[int, List[in
3736

3837
def scale_weight(*indices):
3938
scale_indices = [indices[i] for i in range(len(indices)) if not i in dims]
40-
# Have to cast to int32 first because there are several ways convert bf16 to int8
41-
cast_to_int = cast(prim.round(w[indices] / scale[scale_indices]), int32)
42-
return cast(cast_to_int, quant_type)
39+
return cast(prim.round(w[indices] / scale[scale_indices]), quant_type)
4340

4441
wq = compute(name='quantize', shape=w.shape, fcompute=scale_weight)
4542
super().__init__(

0 commit comments

Comments
 (0)