@@ -89,26 +89,22 @@ def _linear_bf16_act_uint4_weight_float_zero_impl(input_tensor, weight_tensor, b
    return y.to(orig_dtype)


-def _linear_bf16_act_uint4_weight_int8_zero_check(input_tensor, weight_tensor, bias):
+def _linear_fp_act_uint4_weight_int8_zero_check(input_tensor, weight_tensor, bias):
    return (
-        # input is native bfloat16 tensor
        not is_traceable_wrapper_subclass(input_tensor)
-        and input_tensor.dtype == torch.bfloat16
        and
        # weight is uint4, group quantized tensor_core_tiled tensor impl affine quantized tensor
        isinstance(weight_tensor, AffineQuantizedTensor)
        and _aqt_is_xpu_layout_uint4(weight_tensor)
-        and weight_tensor.dtype == torch.bfloat16
        and len(weight_tensor.shape) == 2
        and weight_tensor.zero_point_domain == ZeroPointDomain.INT
        and weight_tensor.tensor_impl.scale_and_zero is None
-        and weight_tensor.tensor_impl.scale.dtype == torch.bfloat16
        and weight_tensor.tensor_impl.zero.dtype == torch.int8
        and isinstance(weight_tensor._layout, Int4XPULayout)
    )


-def _linear_bf16_act_uint4_weight_int8_zero_impl(input_tensor, weight_tensor, bias):
+def _linear_fp_act_uint4_weight_int8_zero_impl(input_tensor, weight_tensor, bias):
    assert weight_tensor.block_size[0] == 1, (
        f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
    )
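The hunk above renames the dispatch check from `_linear_bf16_act_uint4_weight_int8_zero_check` to `_linear_fp_act_uint4_weight_int8_zero_check` and drops the three `torch.bfloat16` dtype conditions, so the XPU uint4-weight/int8-zero path is no longer restricted to bfloat16 activations, weights, and scales. A minimal sketch of the effective change in the dtype part of the predicate, using plain tensors and hypothetical names rather than the torchao types:

```python
# Illustration only (assumed names, not the torchao code): before the change the
# check required a native bfloat16 activation; after it, any floating-point
# activation dtype (bf16, fp16, fp32) passes the dtype part of the check.
import torch

def old_dtype_check(x: torch.Tensor) -> bool:
    return x.dtype == torch.bfloat16   # bf16-only

def new_dtype_check(x: torch.Tensor) -> bool:
    return x.is_floating_point()       # any float dtype

x_fp16 = torch.randn(2, 4, dtype=torch.float16)
print(old_dtype_check(x_fp16), new_dtype_check(x_fp16))  # False True
```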
@@ -129,7 +125,7 @@ def _linear_bf16_act_uint4_weight_int8_zero_impl(input_tensor, weight_tensor, bi
    orig_act_size = act_mat.size()
    orig_dtype = act_mat.dtype

-    act_mat = act_mat.reshape(-1, act_mat.shape[-1]).to(torch.bfloat16)
+    act_mat = act_mat.reshape(-1, act_mat.shape[-1])

    # groupwise int4 quantization
    groupsize = weight_tensor.block_size[1]
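The second hunk makes the matching change in the impl: the flattened activation keeps its incoming dtype instead of being cast to bfloat16, and the existing `return y.to(orig_dtype)` at the top of this diff still restores the caller's dtype on the way out. A small sketch of the reshape behaviour, with an assumed activation shape:

```python
# Sketch with assumed shapes (not the torchao code): flatten the activation to
# 2-D without forcing bfloat16, as the second hunk does.
import torch

act_mat = torch.randn(2, 3, 8, dtype=torch.float16)
orig_dtype = act_mat.dtype

# old: act_mat = act_mat.reshape(-1, act_mat.shape[-1]).to(torch.bfloat16)
act_mat = act_mat.reshape(-1, act_mat.shape[-1])   # new: dtype preserved

assert act_mat.shape == (6, 8) and act_mat.dtype == orig_dtype
```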