26 | 26 | from simulai.optimization._adjusters import AnnealingWeights
27 | 27 |
28 | 28 | class LossBasics:
   | 29 | +
29 | 30 |     def __init__(self):
30 | 31 |         """
31 | 32 |         Loss functions parent class
32 | 33 |         """
33 | 34 |         self.loss_states = None
   | 35 | +         self.tol = 1e-16
   | 36 | +
   | 37 | +     def _single_term_loss(self, res: torch.Tensor) -> torch.Tensor:
   | 38 | +
   | 39 | +         return torch.square(res)
   | 40 | +
   | 41 | +     def _two_term_loss(self, res_a: torch.Tensor, res_b: torch.Tensor) -> torch.Tensor:
   | 42 | +
   | 43 | +         return torch.square(res_a - res_b)
   | 44 | +
   | 45 | +     def _two_term_log_loss(self, res_a: torch.Tensor, res_b: torch.Tensor) -> torch.Tensor:
   | 46 | +
   | 47 | +         if torch.all(res_a <= self.tol) or torch.all(res_b <= self.tol):
   | 48 | +
   | 49 | +             return self._two_term_loss(res_a, res_b)
   | 50 | +         else:
   | 51 | +             return torch.square(torch.log(res_a) - torch.log(res_b))
34 | 52 |
35 | 53 |     # Choosing the kind of multiplication to be done for each
36 | 54 |     # type of lambda penalties and regularization terms
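For context, the helpers added above reduce to a plain squared difference whenever either tensor lies entirely at or below `self.tol`, and otherwise compare the two terms in log space. A minimal standalone sketch of that dispatch, with illustrative names (`TOL`, `two_term_log_loss`) that are not part of simulai:

import torch

TOL = 1e-16  # mirrors self.tol above

def two_term_log_loss(res_a: torch.Tensor, res_b: torch.Tensor) -> torch.Tensor:
    # Fall back to the ordinary squared difference when either tensor is
    # entirely at or below the tolerance (the log would blow up there);
    # otherwise compare the two terms in log space.
    if torch.all(res_a <= TOL) or torch.all(res_b <= TOL):
        return torch.square(res_a - res_b)
    return torch.square(torch.log(res_a) - torch.log(res_b))

out = torch.tensor([1e-2, 1e-1])
tgt = torch.tensor([1e-3, 1e-1])
print(two_term_log_loss(out, tgt))             # squared log-space differences
print(two_term_log_loss(out, torch.zeros(2)))  # falls back to (out - 0) ** 2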
@@ -545,7 +563,7 @@ def _data_loss(
545 | 563 |         target_split = torch.split(target_data_tensor, self.split_dim, dim=-1)
546 | 564 |
547 | 565 |         data_losses = [
548 |     | -             self.loss_evaluator(out_split - tgt_split) / (self.norm_evaluator(tgt_split) or torch.tensor(1.0).to(self.device))
    | 566 | +             self.loss_evaluator_data((out_split, tgt_split)) / (self.norm_evaluator(tgt_split) or torch.tensor(1.0).to(self.device))
549 | 567 |             for i, (out_split, tgt_split) in enumerate(zip(output_split, target_split))
550 | 568 |         ]
551 | 569 |
@@ -583,7 +601,7 @@ def _data_loss_adaptive(
583 | 601 |             operator=self.operator)
584 | 602 |
585 | 603 |         data_losses = [
586 |     | -             weights[i]*self.loss_evaluator(out_split - tgt_split)
    | 604 | +             weights[i]*self.loss_evaluator_data((out_split, tgt_split))
587 | 605 |             for i, (out_split, tgt_split) in enumerate(zip(output_split, target_split))
588 | 606 |         ]
589 | 607 |
@@ -771,6 +789,7 @@ def __call__(
771 | 789 |         residual_weights_estimator: Callable = None,
772 | 790 |         data_weights_estimator: Callable = None,
773 | 791 |         use_mean: bool = True,
    | 792 | +         use_data_log: bool = False,
774 | 793 |     ) -> Callable:
775 | 794 |         self.residual = residual
776 | 795 |
@@ -861,10 +880,20 @@ def __call__(
861 | 880 |         else:
862 | 881 |             self.extra_data = self._no_extra_data
863 | 882 |
    | 883 | +         if use_data_log == True:
    | 884 | +             self.inner_square = self._two_term_log_loss
    | 885 | +         else:
    | 886 | +             self.inner_square = self._two_term_loss
    | 887 | +
864 | 888 |         if use_mean == True:
865 |     | -             self.loss_evaluator = lambda res: torch.mean(torch.square((res)))
    | 889 | +             self.loss_evaluator = lambda res: torch.mean(self._single_term_loss(res))
866 | 890 |         else:
867 |     | -             self.loss_evaluator = lambda res: torch.sum(torch.square((res)))
    | 891 | +             self.loss_evaluator = lambda res: torch.sum(self._single_term_loss(res))
    | 892 | +
    | 893 | +         if use_mean == True:
    | 894 | +             self.loss_evaluator_data = lambda res: torch.mean(self.inner_square(*res))
    | 895 | +         else:
    | 896 | +             self.loss_evaluator_data = lambda res: torch.sum(self.inner_square(*res))
868 | 897 |
869 | 898 |         # Relative norm or not
870 | 899 |         if relative == True:
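Taken together with the first hunk, `use_data_log` only selects which two-term kernel `loss_evaluator_data` applies to each `(output, target)` pair, while `use_mean` still chooses the reduction. A rough standalone sketch of that wiring (the `make_data_evaluator` name is illustrative, not simulai API):

import torch

def make_data_evaluator(use_mean: bool = True, use_data_log: bool = False,
                        tol: float = 1e-16):
    # Two-term kernels, mirroring _two_term_loss / _two_term_log_loss above.
    def two_term(a, b):
        return torch.square(a - b)

    def two_term_log(a, b):
        if torch.all(a <= tol) or torch.all(b <= tol):
            return two_term(a, b)
        return torch.square(torch.log(a) - torch.log(b))

    inner = two_term_log if use_data_log else two_term
    reduce_fn = torch.mean if use_mean else torch.sum
    # The evaluator takes an (output, target) tuple, matching the
    # self.loss_evaluator_data((out_split, tgt_split)) calls in the hunks above.
    return lambda pair: reduce_fn(inner(*pair))

evaluator = make_data_evaluator(use_mean=True, use_data_log=True)
loss = evaluator((torch.rand(8) + 0.1, torch.rand(8) + 0.1))
print(loss)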