Skip to content

Commit 5e86a82

Browse files
authored
Merge pull request #243 from pollytur/add_adaptive
adaptive regularization added
2 parents a2055cd + d266351 commit 5e86a82

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

neuralpredictors/layers/readouts/gaussian.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,13 +277,24 @@ def __init__(
277277
feature_reg_weight=None,
278278
gamma_readout=None, # depricated, use feature_reg_weight instead
279279
return_weighted_features=False,
280+
regularizer_type="l1",
281+
gamma_sigma=0.1,
280282
**kwargs,
281283
) -> None:
282284
super().__init__()
283285
self.feature_reg_weight = self.resolve_deprecated_gamma_readout(feature_reg_weight, gamma_readout, default=1.0)
284286
self.mean_activity = mean_activity
285287
# determines whether the Gaussian is isotropic or not
286288
self.gauss_type = gauss_type
289+
self._regularizer_type = regularizer_type
290+
291+
if self._regularizer_type == "adaptive_log_norm":
292+
self.gamma_sigma = gamma_sigma
293+
self.adaptive_neuron_reg_coefs = torch.nn.Parameter(
294+
torch.normal(mean=torch.ones(1, outdims), std=torch.ones(1, outdims))
295+
)
296+
elif self._regularizer_type != "l1":
297+
raise ValueError(f"regularizer_type should be 'l1' or 'adaptive_log_norm' but got {self._regularizer_type}")
287298

288299
if init_mu_range > 1.0 or init_mu_range <= 0.0 or init_sigma <= 0.0:
289300
raise ValueError("either init_mu_range doesn't belong to [0.0, 1.0] or init_sigma_range is non-positive")
@@ -380,6 +391,26 @@ def feature_l1(self, reduction="sum", average=None):
380391
else:
381392
return 0
382393

394+
def adaptive_feature_l1_lognorm(self, reduction="sum", average=None):
    """L1 feature regularizer with learnable per-neuron adaptive coefficients.

    Each neuron's feature-weight L1 penalty is scaled by the learnable
    coefficient in ``self.adaptive_neuron_reg_coefs`` (betas); an additional
    term acts as a log-normal prior on those coefficients with width
    ``self.gamma_sigma``.

    Args:
        reduction: reduction over the elementwise penalty ("sum"/"mean"),
            forwarded to ``self.apply_reduction``.
        average: deprecated flag, forwarded to ``self.apply_reduction``.

    Returns:
        Scalar regularization term, or 0 when this readout does not own its
        feature weights (``self._original_features`` is falsy).
    """
    if not self._original_features:
        # Features are shared from another readout: nothing to regularize here.
        return 0
    # Per-neuron adaptively scaled L1 penalty on the feature weights.
    scaled_features = self.adaptive_neuron_reg_coefs.abs() * self.features
    features_regularization = (
        self.apply_reduction(scaled_features.abs(), reduction=reduction, average=average)
        * self.feature_reg_weight
    )
    # adaptive_neuron_reg_coefs (betas) are supposed to be lognormal-distributed;
    # this is the corresponding (negative log-)prior term on the coefficients.
    coef_prior = 1 / (self.gamma_sigma**2) * ((torch.log(self.adaptive_neuron_reg_coefs.abs()) ** 2).sum())
    # BUG FIX: the original returned the undefined name `regularization_loss`
    # (NameError at runtime); the computed value is `features_regularization`.
    return features_regularization + coef_prior
405+
406+
def regularizer(self, reduction="sum", average=None):
    """Dispatch to the regularizer selected by ``self._regularizer_type``.

    "l1" applies the plain feature L1 penalty scaled by
    ``self.feature_reg_weight``; "adaptive_log_norm" applies the per-neuron
    adaptive variant (which applies the weight internally).  Any other value
    raises ``NotImplementedError``.
    """
    reg_type = self._regularizer_type
    if reg_type == "adaptive_log_norm":
        return self.adaptive_feature_l1_lognorm(reduction=reduction, average=average)
    if reg_type == "l1":
        return self.feature_l1(reduction=reduction, average=average) * self.feature_reg_weight
    raise NotImplementedError(f"Regularizer_type {self._regularizer_type} is not implemented")
413+
383414
def regularizer(self, reduction="sum", average=None):
    """Return the feature regularization penalty for this readout.

    BUG FIX: this pre-existing (second) ``def regularizer`` appears after the
    newly added dispatching version, so Python's later-definition-wins rule
    made it shadow that method — the "adaptive_log_norm" path was unreachable
    and adaptive readouts silently fell back to plain L1.  It now honours
    ``self._regularizer_type`` (set in ``__init__``), preserving the original
    L1 behavior for the default "l1" type.
    """
    if self._regularizer_type == "l1":
        return self.feature_l1(reduction=reduction, average=average) * self.feature_reg_weight
    elif self._regularizer_type == "adaptive_log_norm":
        return self.adaptive_feature_l1_lognorm(reduction=reduction, average=average)
    else:
        raise NotImplementedError(f"Regularizer_type {self._regularizer_type} is not implemented")
385416

0 commit comments

Comments
 (0)