@@ -277,13 +277,24 @@ def __init__(
277277 feature_reg_weight = None ,
278278 gamma_readout = None , # depricated, use feature_reg_weight instead
279279 return_weighted_features = False ,
280+ regularizer_type = "l1" ,
281+ gamma_sigma = 0.1 ,
280282 ** kwargs ,
281283 ) -> None :
282284 super ().__init__ ()
283285 self .feature_reg_weight = self .resolve_deprecated_gamma_readout (feature_reg_weight , gamma_readout , default = 1.0 )
284286 self .mean_activity = mean_activity
285287 # determines whether the Gaussian is isotropic or not
286288 self .gauss_type = gauss_type
289+ self ._regularizer_type = regularizer_type
290+
291+ if self ._regularizer_type == "adaptive_log_norm" :
292+ self .gamma_sigma = gamma_sigma
293+ self .adaptive_neuron_reg_coefs = torch .nn .Parameter (
294+ torch .normal (mean = torch .ones (1 , outdims ), std = torch .ones (1 , outdims ))
295+ )
296+ elif self ._regularizer_type != "l1" :
297+ raise ValueError (f"regularizer_type should be 'l1' or 'adaptive_log_norm' but got { self ._regularizer_type } " )
287298
288299 if init_mu_range > 1.0 or init_mu_range <= 0.0 or init_sigma <= 0.0 :
289300 raise ValueError ("either init_mu_range doesn't belong to [0.0, 1.0] or init_sigma_range is non-positive" )
@@ -380,6 +391,26 @@ def feature_l1(self, reduction="sum", average=None):
380391 else :
381392 return 0
382393
def adaptive_feature_l1_lognorm(self, reduction="sum", average=None):
    """L1 feature penalty with learnable per-neuron adaptive coefficients.

    Each neuron's feature weights are scaled by the absolute value of a
    learnable coefficient (``self.adaptive_neuron_reg_coefs``, the "betas")
    before the L1 penalty is reduced, and a log-normal prior on those
    coefficients is added so they cannot drift freely and switch the
    penalty off.

    Args:
        reduction: reduction mode forwarded to ``self.apply_reduction``
            ("sum" by default).
        average: averaging flag forwarded to ``self.apply_reduction``
            (deprecated alternative to ``reduction`` elsewhere in this class).

    Returns:
        Scalar regularization term (weighted L1 plus coefficient prior),
        or 0 when ``self._original_features`` is False (features are shared
        from another readout, so they are not penalized here).
    """
    if not self._original_features:
        # Features belong to another readout; avoid double-penalizing them.
        return 0

    scaled_features = self.adaptive_neuron_reg_coefs.abs() * self.features
    features_regularization = (
        self.apply_reduction(scaled_features.abs(), reduction=reduction, average=average)
        * self.feature_reg_weight
    )
    # adaptive_neuron_reg_coefs (betas) are supposed to follow a log-normal
    # distribution: penalize squared log|beta|, scaled by 1/sigma^2.
    coef_prior = (torch.log(self.adaptive_neuron_reg_coefs.abs()) ** 2).sum() / (self.gamma_sigma ** 2)
    # BUG FIX: the original returned the undefined name `regularization_loss`
    # (NameError at runtime); the computed penalty is `features_regularization`.
    return features_regularization + coef_prior
def regularizer(self, reduction="sum", average=None):
    """Dispatch to the feature regularizer selected at construction time.

    NOTE(review): the file previously defined ``regularizer`` twice; the
    later plain-L1 definition shadowed this dispatching one, silently
    disabling the "adaptive_log_norm" option. The two definitions are
    merged here — the "l1" branch is byte-equivalent to the old shadowing
    definition, so existing callers see no behavior change.

    Args:
        reduction: reduction mode forwarded to the underlying penalty.
        average: averaging flag forwarded to the underlying penalty.

    Returns:
        Scalar regularization term.

    Raises:
        NotImplementedError: if ``self._regularizer_type`` is not one of
            "l1" or "adaptive_log_norm".
    """
    if self._regularizer_type == "l1":
        return self.feature_l1(reduction=reduction, average=average) * self.feature_reg_weight
    elif self._regularizer_type == "adaptive_log_norm":
        # feature_reg_weight is applied inside adaptive_feature_l1_lognorm,
        # so it is not multiplied in again here.
        return self.adaptive_feature_l1_lognorm(reduction=reduction, average=average)
    else:
        raise NotImplementedError(f"Regularizer_type {self._regularizer_type} is not implemented")
0 commit comments