Skip to content

Commit 17e199d

Browse files
Merge pull request #153 from IBM/extend/loss_types
Separating the loss evaluators into two classes
2 parents 86c32bc + c5c6e45 commit 17e199d

File tree

1 file changed

+33
-4
lines changed

1 file changed

+33
-4
lines changed

simulai/optimization/_losses.py

+33-4
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,29 @@
2626
from simulai.optimization._adjusters import AnnealingWeights
2727

2828
class LossBasics:
29+
2930
def __init__(self):
3031
"""
3132
Loss functions parent class
3233
"""
3334
self.loss_states = None
35+
self.tol = 1e-16
36+
37+
def _single_term_loss(self, res:torch.Tensor) -> torch.Tensor:
38+
39+
return torch.square(res)
40+
41+
def _two_term_loss(self, res_a:torch.Tensor, res_b:torch.Tensor) -> torch.Tensor:
42+
43+
return torch.square(res_a - res_b)
44+
45+
def _two_term_log_loss(self, res_a:torch.Tensor, res_b:torch.Tensor) -> torch.Tensor:
46+
47+
if torch.all(res_a <= self.tol) or torch.all(res_b <= self.tol):
48+
49+
return self._two_term_loss(res_a, res_b)
50+
else:
51+
return torch.square(torch.log(res_a) - torch.log(res_b))
3452

3553
# Choosing the kind of multiplication to be done for each
3654
# type of lambda penalties and regularization terms
@@ -545,7 +563,7 @@ def _data_loss(
545563
target_split = torch.split(target_data_tensor, self.split_dim, dim=-1)
546564

547565
data_losses = [
548-
self.loss_evaluator(out_split - tgt_split) / (self.norm_evaluator(tgt_split) or torch.tensor(1.0).to(self.device))
566+
self.loss_evaluator_data((out_split, tgt_split)) / (self.norm_evaluator(tgt_split) or torch.tensor(1.0).to(self.device))
549567
for i, (out_split, tgt_split) in enumerate(zip(output_split, target_split))
550568
]
551569

@@ -583,7 +601,7 @@ def _data_loss_adaptive(
583601
operator=self.operator)
584602

585603
data_losses = [
586-
weights[i]*self.loss_evaluator(out_split - tgt_split)
604+
weights[i]*self.loss_evaluator_data((out_split, tgt_split))
587605
for i, (out_split, tgt_split) in enumerate(zip(output_split, target_split))
588606
]
589607

@@ -771,6 +789,7 @@ def __call__(
771789
residual_weights_estimator: Callable = None,
772790
data_weights_estimator: Callable = None,
773791
use_mean: bool = True,
792+
use_data_log: bool = False,
774793
) -> Callable:
775794
self.residual = residual
776795

@@ -861,10 +880,20 @@ def __call__(
861880
else:
862881
self.extra_data = self._no_extra_data
863882

883+
if use_data_log == True:
884+
self.inner_square = self._two_term_log_loss
885+
else:
886+
self.inner_square = self._two_term_loss
887+
864888
if use_mean == True:
865-
self.loss_evaluator = lambda res: torch.mean(torch.square((res)))
889+
self.loss_evaluator = lambda res: torch.mean(self._single_term_loss(res))
866890
else:
867-
self.loss_evaluator = lambda res: torch.sum(torch.square((res)))
891+
self.loss_evaluator = lambda res: torch.sum(self._single_term_loss(res))
892+
893+
if use_mean == True:
894+
self.loss_evaluator_data = lambda res: torch.mean(self.inner_square(*res))
895+
else:
896+
self.loss_evaluator_data = lambda res: torch.sum(self.inner_square(*res))
868897

869898
# Relative norm or not
870899
if relative == True:

0 commit comments

Comments
 (0)