Skip to content

Commit 3c7bfcf

Browse files
fanyunqianTracin
fanyunqian
authored andcommitted
[Fix] PoT Quantization for multi-gpus
[Fix]
1 parent 5a7bc23 commit 3c7bfcf

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

mqbench/observer.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ def close(self):
6767
def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
6868
r"""Calculates the quantization parameters."""
6969
scale, zero_point = self._calculate_qparams(self.min_val, self.max_val)
70-
if self.pot_scale:
71-
scale = pot_quantization(scale)
7270
scale.data = sync_tensor(scale).data
7371
zero_point.data = sync_tensor(zero_point).data
72+
if self.pot_scale:
73+
scale = pot_quantization(scale)
7474
return scale, zero_point
7575

7676
@torch.jit.export
@@ -456,14 +456,14 @@ def forward(self, x_orig):
456456

457457
def calculate_qparams(self):
458458
scale = 2 * self.tensor_norm / math.sqrt(self.quant_max)
459+
zero_point = torch.zeros_like(self.tensor_norm)
460+
sync_tensor(scale)
461+
sync_tensor(zero_point)
459462
if self.pot_scale:
460463
scale = pot_quantization(scale)
461-
zero_point = torch.zeros_like(self.tensor_norm)
462464
if not is_symmetric_quant(self.qscheme):
463465
if self.min_val >= 0.:
464466
zero_point = self.quant_min - torch.round(self.min_val / scale)
465-
sync_tensor(scale)
466-
sync_tensor(zero_point)
467467
return scale, zero_point
468468

469469

@@ -505,14 +505,14 @@ def forward(self, x_orig):
505505
def calculate_qparams(self):
506506
scale = torch.maximum((self.mean - 3 * self.std).abs(),
507507
(self.mean + 3 * self.std).abs()) / (self.quant_max - self.quant_min + 1)
508+
sync_tensor(scale)
509+
sync_tensor(zero_point)
508510
if self.pot_scale:
509511
scale = pot_quantization(scale)
510512
zero_point = torch.zeros_like(self.mean)
511513
if not is_symmetric_quant(self.qscheme):
512514
if self.min_val >= 0.:
513515
zero_point = self.quant_min - torch.round(self.min_val / scale)
514-
sync_tensor(scale)
515-
sync_tensor(zero_point)
516516
return scale, zero_point
517517

518518

0 commit comments

Comments
 (0)