@@ -67,10 +67,10 @@ def close(self):
67
67
def calculate_qparams (self ) -> Tuple [torch .Tensor , torch .Tensor ]:
68
68
r"""Calculates the quantization parameters."""
69
69
scale , zero_point = self ._calculate_qparams (self .min_val , self .max_val )
70
- if self .pot_scale :
71
- scale = pot_quantization (scale )
72
70
scale .data = sync_tensor (scale ).data
73
71
zero_point .data = sync_tensor (zero_point ).data
72
+ if self .pot_scale :
73
+ scale = pot_quantization (scale )
74
74
return scale , zero_point
75
75
76
76
@torch .jit .export
@@ -456,14 +456,14 @@ def forward(self, x_orig):
456
456
457
457
def calculate_qparams (self ):
458
458
scale = 2 * self .tensor_norm / math .sqrt (self .quant_max )
459
+ zero_point = torch .zeros_like (self .tensor_norm )
460
+ sync_tensor (scale )
461
+ sync_tensor (zero_point )
459
462
if self .pot_scale :
460
463
scale = pot_quantization (scale )
461
- zero_point = torch .zeros_like (self .tensor_norm )
462
464
if not is_symmetric_quant (self .qscheme ):
463
465
if self .min_val >= 0. :
464
466
zero_point = self .quant_min - torch .round (self .min_val / scale )
465
- sync_tensor (scale )
466
- sync_tensor (zero_point )
467
467
return scale , zero_point
468
468
469
469
@@ -505,14 +505,14 @@ def forward(self, x_orig):
505
505
def calculate_qparams (self ):
506
506
scale = torch .maximum ((self .mean - 3 * self .std ).abs (),
507
507
(self .mean + 3 * self .std ).abs ()) / (self .quant_max - self .quant_min + 1 )
508
+ sync_tensor (scale )
509
+ sync_tensor (zero_point )
508
510
if self .pot_scale :
509
511
scale = pot_quantization (scale )
510
512
zero_point = torch .zeros_like (self .mean )
511
513
if not is_symmetric_quant (self .qscheme ):
512
514
if self .min_val >= 0. :
513
515
zero_point = self .quant_min - torch .round (self .min_val / scale )
514
- sync_tensor (scale )
515
- sync_tensor (zero_point )
516
516
return scale , zero_point
517
517
518
518
0 commit comments