@@ -528,11 +528,12 @@ def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
528
528
ch_axis , pot_scale , factory_kwargs )
529
529
self .p = p
530
530
531
def lp_loss(self, pred, tgt, dim=None):
    """
    loss function measured in L_p Norm

    Computes ``|pred - tgt| ** self.p`` element-wise and averages it.

    Args:
        pred: tensor to evaluate.
        tgt: reference tensor that ``pred`` is compared against.
        dim: dimension (int) or tuple of dimensions to average over.
            ``None`` (the default) averages over every element.

    Returns:
        The mean of the element-wise error, reduced over ``dim`` when it
        is given, otherwise over all elements.
    """
    err = (pred - tgt).abs().pow(self.p)
    # Test against None explicitly: ``dim=0`` is a valid axis but is falsy,
    # so a truthiness check would silently fall back to the global mean.
    if dim is None:
        return err.mean()
    return err.mean(dim)
536
+
536
537
537
538
def mse (self , x : torch .Tensor , x_min : torch .Tensor , x_max : torch .Tensor , iter = 80 ):
538
539
best_score = 1e+10
@@ -552,6 +553,26 @@ def mse(self, x: torch.Tensor, x_min: torch.Tensor, x_max: torch.Tensor, iter=80
552
553
best_min , best_max = new_min , new_max
553
554
return best_min , best_max
554
555
556
def mse_perchannel(self, x: torch.Tensor, x_min: torch.Tensor, x_max: torch.Tensor, iter=80, ch_axis=0):
    """
    Per-channel search for the clipping range with the lowest ``lp_loss``.

    Candidate ``i`` (for ``i`` in ``range(iter)``) scales both starting
    bounds by ``1 - 0.01 * i``. ``x`` is fake-quantized per channel with
    the qparams derived from that candidate, and the loss is averaged over
    every dimension except ``ch_axis``. For each channel the candidate
    with the strictly lowest loss seen so far is kept.

    Args:
        x: tensor to calibrate on.
        x_min: starting lower bounds (presumably one entry per channel --
            confirm against the caller).
        x_max: starting upper bounds; must have the same shape as ``x_min``.
        iter: number of candidate ranges to evaluate.
        ch_axis: non-negative index of the channel dimension of ``x``.

    Returns:
        Tuple ``(best_min, best_max)`` holding the selected bounds.
    """
    assert x_min.shape == x_max.shape
    assert ch_axis >= 0, f'{ch_axis}'

    # Every axis except the channel axis is averaged away by the loss.
    reduce_dim = tuple(axis for axis in range(x.dim()) if axis != ch_axis)

    best_min = x_min.clone()
    best_max = x_max.clone()
    best_score = torch.ones_like(x_min) * 1e+10

    for step in range(iter):
        shrink = 1.0 - (step * 0.01)
        cand_min = x_min * shrink
        cand_max = x_max * shrink

        scale, zero_point = self._calculate_qparams(cand_min, cand_max)
        x_q = torch.fake_quantize_per_channel_affine(
            x, scale, zero_point.long(), ch_axis,
            self.quant_min, self.quant_max)
        score = self.lp_loss(x_q, x, reduce_dim)

        # Only channels whose loss strictly improved take the new candidate.
        improved = score < best_score
        for kept, candidate in ((best_score, score),
                                (best_min, cand_min),
                                (best_max, cand_max)):
            kept[improved] = candidate[improved]

    return best_min, best_max
575
+
555
576
def forward (self , x_orig ):
556
577
r"""Records the running minimum and maximum of ``x``."""
557
578
if x_orig .numel () == 0 :
@@ -568,8 +589,7 @@ def forward(self, x_orig):
568
589
x_channel = x .permute (new_axis_list )
569
590
y = torch .flatten (x_channel , start_dim = 1 )
570
591
min_val_cur , max_val_cur = torch ._aminmax (y , 1 )
571
- for ch , val in enumerate (min_val_cur ):
572
- min_val_cur [ch ], max_val_cur [ch ] = self .mse (x_channel [ch ], min_val_cur [ch ], max_val_cur [ch ], iter = 80 )
592
+ min_val_cur , max_val_cur = self .mse_perchannel (x , min_val_cur , max_val_cur , iter = 80 , ch_axis = self .ch_axis )
573
593
574
594
self .min_val = torch .min (self .min_val , min_val_cur )
575
595
self .max_val = torch .max (self .max_val , max_val_cur )
@@ -588,11 +608,11 @@ def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
588
608
self .ema_ratio = ema_ratio
589
609
self .p = p
590
610
591
def lp_loss(self, pred, tgt, dim=None):
    """
    loss function measured in L_p Norm

    Computes ``|pred - tgt| ** self.p`` element-wise and averages it.

    Args:
        pred: tensor to evaluate.
        tgt: reference tensor that ``pred`` is compared against.
        dim: dimension (int) or tuple of dimensions to average over.
            ``None`` (the default) averages over every element.

    Returns:
        The mean of the element-wise error, reduced over ``dim`` when it
        is given, otherwise over all elements.
    """
    err = (pred - tgt).abs().pow(self.p)
    # Test against None explicitly: ``dim=0`` is a valid axis but is falsy,
    # so a truthiness check would silently fall back to the global mean.
    if dim is None:
        return err.mean()
    return err.mean(dim)
596
616
597
617
def mse (self , x : torch .Tensor , x_min : torch .Tensor , x_max : torch .Tensor , iter = 80 ):
598
618
best_score = 1e+10
@@ -612,6 +632,26 @@ def mse(self, x: torch.Tensor, x_min: torch.Tensor, x_max: torch.Tensor, iter=80
612
632
best_min , best_max = new_min , new_max
613
633
return best_min , best_max
614
634
635
def mse_perchannel(self, x: torch.Tensor, x_min: torch.Tensor, x_max: torch.Tensor, iter=80, ch_axis=0):
    """
    Per-channel search for the clipping range with the lowest ``lp_loss``.

    Candidate ``i`` (for ``i`` in ``range(iter)``) scales both starting
    bounds by ``1 - 0.01 * i``. ``x`` is fake-quantized per channel with
    the qparams derived from that candidate, and the loss is averaged over
    every dimension except ``ch_axis``. For each channel the candidate
    with the strictly lowest loss seen so far is kept.

    Args:
        x: tensor to calibrate on.
        x_min: starting lower bounds (presumably one entry per channel --
            confirm against the caller).
        x_max: starting upper bounds; must have the same shape as ``x_min``.
        iter: number of candidate ranges to evaluate.
        ch_axis: non-negative index of the channel dimension of ``x``.

    Returns:
        Tuple ``(best_min, best_max)`` holding the selected bounds.
    """
    assert x_min.shape == x_max.shape
    assert ch_axis >= 0, f'{ch_axis}'

    # Every axis except the channel axis is averaged away by the loss.
    reduce_dim = tuple(axis for axis in range(x.dim()) if axis != ch_axis)

    best_min = x_min.clone()
    best_max = x_max.clone()
    best_score = torch.ones_like(x_min) * 1e+10

    for step in range(iter):
        shrink = 1.0 - (step * 0.01)
        cand_min = x_min * shrink
        cand_max = x_max * shrink

        scale, zero_point = self._calculate_qparams(cand_min, cand_max)
        x_q = torch.fake_quantize_per_channel_affine(
            x, scale, zero_point.long(), ch_axis,
            self.quant_min, self.quant_max)
        score = self.lp_loss(x_q, x, reduce_dim)

        # Only channels whose loss strictly improved take the new candidate.
        improved = score < best_score
        for kept, candidate in ((best_score, score),
                                (best_min, cand_min),
                                (best_max, cand_max)):
            kept[improved] = candidate[improved]

    return best_min, best_max
654
+
615
655
def forward (self , x_orig ):
616
656
r"""Records the running minimum and maximum of ``x``."""
617
657
if x_orig .numel () == 0 :
@@ -628,9 +668,7 @@ def forward(self, x_orig):
628
668
x_channel = x .permute (new_axis_list )
629
669
y = torch .flatten (x_channel , start_dim = 1 )
630
670
min_val_cur , max_val_cur = torch ._aminmax (y , 1 )
631
- for ch , val in enumerate (min_val_cur ):
632
- min_val_cur [ch ], max_val_cur [ch ] = self .mse (x_channel [ch ], min_val_cur [ch ],
633
- max_val_cur [ch ], iter = 80 )
671
+ min_val_cur , max_val_cur = self .mse_perchannel (x , min_val_cur , max_val_cur , iter = 80 , ch_axis = self .ch_axis )
634
672
635
673
if self .max_val .numel () <= 1 and self .max_val .isinf ():
636
674
self .min_val = min_val_cur
0 commit comments