Skip to content

Commit 4e9f2dc

Browse files
fanyunqianTracin
fanyunqian
authored andcommitted
[Refactor] Re-write the (EMA)MSEObserver to accelerate the calibration
phase.
1 parent 37752b1 commit 4e9f2dc

File tree

1 file changed

+47
-9
lines changed

1 file changed

+47
-9
lines changed

mqbench/observer.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -528,11 +528,12 @@ def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
528528
ch_axis, pot_scale, factory_kwargs)
529529
self.p = p
530530

531-
def lp_loss(self, pred, tgt):
531+
def lp_loss(self, pred, tgt, dim=None):
532532
"""
533533
loss function measured in L_p Norm
534534
"""
535-
return (pred - tgt).abs().pow(self.p).mean()
535+
return (pred - tgt).abs().pow(self.p).mean(dim) if dim else (pred - tgt).abs().pow(self.p).mean()
536+
536537

537538
def mse(self, x: torch.Tensor, x_min: torch.Tensor, x_max: torch.Tensor, iter=80):
538539
best_score = 1e+10
@@ -552,6 +553,26 @@ def mse(self, x: torch.Tensor, x_min: torch.Tensor, x_max: torch.Tensor, iter=80
552553
best_min, best_max = new_min, new_max
553554
return best_min, best_max
554555

556+
def mse_perchannel(self, x: torch.Tensor, x_min: torch.Tensor, x_max: torch.Tensor, iter=80, ch_axis=0):
557+
assert x_min.shape == x_max.shape
558+
assert ch_axis >= 0, f'{ch_axis}'
559+
best_score = 1e+10 * torch.ones_like(x_min)
560+
best_min, best_max = x_min.clone(), x_max.clone()
561+
reduce_dim = tuple([i for i in range(len(x.shape)) if i != ch_axis])
562+
for i in range(iter):
563+
new_min = x_min * (1.0 - (i * 0.01))
564+
new_max = x_max * (1.0 - (i * 0.01))
565+
scale, zero_point = self._calculate_qparams(new_min, new_max)
566+
x_q = torch.fake_quantize_per_channel_affine(
567+
x, scale, zero_point.long(), ch_axis,
568+
self.quant_min, self.quant_max)
569+
score = self.lp_loss(x_q, x, reduce_dim)
570+
update_idx = (score < best_score)
571+
best_score[update_idx] = score[update_idx]
572+
best_min[update_idx] = new_min[update_idx]
573+
best_max[update_idx] = new_max[update_idx]
574+
return best_min, best_max
575+
555576
def forward(self, x_orig):
556577
r"""Records the running minimum and maximum of ``x``."""
557578
if x_orig.numel() == 0:
@@ -568,8 +589,7 @@ def forward(self, x_orig):
568589
x_channel = x.permute(new_axis_list)
569590
y = torch.flatten(x_channel, start_dim=1)
570591
min_val_cur, max_val_cur = torch._aminmax(y, 1)
571-
for ch, val in enumerate(min_val_cur):
572-
min_val_cur[ch], max_val_cur[ch] = self.mse(x_channel[ch], min_val_cur[ch], max_val_cur[ch], iter=80)
592+
min_val_cur, max_val_cur = self.mse_perchannel(x, min_val_cur, max_val_cur, iter=80, ch_axis=self.ch_axis)
573593

574594
self.min_val = torch.min(self.min_val, min_val_cur)
575595
self.max_val = torch.max(self.max_val, max_val_cur)
@@ -588,11 +608,11 @@ def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
588608
self.ema_ratio = ema_ratio
589609
self.p = p
590610

591-
def lp_loss(self, pred, tgt):
611+
def lp_loss(self, pred, tgt, dim=None):
592612
"""
593613
loss function measured in L_p Norm
594614
"""
595-
return (pred - tgt).abs().pow(self.p).mean()
615+
return (pred - tgt).abs().pow(self.p).mean(dim) if dim else (pred - tgt).abs().pow(self.p).mean()
596616

597617
def mse(self, x: torch.Tensor, x_min: torch.Tensor, x_max: torch.Tensor, iter=80):
598618
best_score = 1e+10
@@ -612,6 +632,26 @@ def mse(self, x: torch.Tensor, x_min: torch.Tensor, x_max: torch.Tensor, iter=80
612632
best_min, best_max = new_min, new_max
613633
return best_min, best_max
614634

635+
def mse_perchannel(self, x: torch.Tensor, x_min: torch.Tensor, x_max: torch.Tensor, iter=80, ch_axis=0):
636+
assert x_min.shape == x_max.shape
637+
assert ch_axis >= 0, f'{ch_axis}'
638+
best_score = 1e+10 * torch.ones_like(x_min)
639+
best_min, best_max = x_min.clone(), x_max.clone()
640+
reduce_dim = tuple([i for i in range(len(x.shape)) if i != ch_axis])
641+
for i in range(iter):
642+
new_min = x_min * (1.0 - (i * 0.01))
643+
new_max = x_max * (1.0 - (i * 0.01))
644+
scale, zero_point = self._calculate_qparams(new_min, new_max)
645+
x_q = torch.fake_quantize_per_channel_affine(
646+
x, scale, zero_point.long(), ch_axis,
647+
self.quant_min, self.quant_max)
648+
score = self.lp_loss(x_q, x, reduce_dim)
649+
update_idx = (score < best_score)
650+
best_score[update_idx] = score[update_idx]
651+
best_min[update_idx] = new_min[update_idx]
652+
best_max[update_idx] = new_max[update_idx]
653+
return best_min, best_max
654+
615655
def forward(self, x_orig):
616656
r"""Records the running minimum and maximum of ``x``."""
617657
if x_orig.numel() == 0:
@@ -628,9 +668,7 @@ def forward(self, x_orig):
628668
x_channel = x.permute(new_axis_list)
629669
y = torch.flatten(x_channel, start_dim=1)
630670
min_val_cur, max_val_cur = torch._aminmax(y, 1)
631-
for ch, val in enumerate(min_val_cur):
632-
min_val_cur[ch], max_val_cur[ch] = self.mse(x_channel[ch], min_val_cur[ch],
633-
max_val_cur[ch], iter=80)
671+
min_val_cur, max_val_cur = self.mse_perchannel(x, min_val_cur, max_val_cur, iter=80, ch_axis=self.ch_axis)
634672

635673
if self.max_val.numel() <= 1 and self.max_val.isinf():
636674
self.min_val = min_val_cur

0 commit comments

Comments
 (0)