-
Notifications
You must be signed in to change notification settings - Fork 101
Expand file tree
/
Copy pathequation_condition_base.py
More file actions
51 lines (42 loc) · 2.03 KB
/
equation_condition_base.py
File metadata and controls
51 lines (42 loc) · 2.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""Module for the EquationConditionBase class."""
import torch

from pina._src.condition.condition_base import ConditionBase
class EquationConditionBase(ConditionBase):
    """
    Base class for conditions that involve an equation.

    This class provides the :meth:`evaluate` method, which computes the
    non-aggregated loss of the equation residual given the input samples and
    a solver. It is intended to be subclassed by conditions that define an
    ``equation`` attribute, such as
    :class:`~pina.condition.DomainEquationCondition` and
    :class:`~pina.condition.InputEquationCondition`.
    """

    def evaluate(self, batch, solver, loss):
        """
        Evaluate the equation loss on the given batch using the solver.

        This method performs a forward pass of the solver's model on the
        input samples, evaluates the equation residual on the result, and
        compares that residual against a zero target using ``loss``. No
        reduction (mean, sum, etc.) is performed here, so as long as
        ``loss`` is itself non-aggregating the per-sample values are
        preserved.

        .. note::
            ``requires_grad`` is enabled **in place** on the tensor stored
            in ``batch["input"]``.

        :param batch: The batch containing the ``input`` entry.
        :type batch: dict | _DataManager
        :param solver: The solver containing the model and any additional
            parameters (e.g., unknown parameters for inverse problems).
        :type solver: ~pina.solver.solver.SolverInterface
        :param loss: The non-aggregating loss function, called as
            ``loss(residual, zeros)`` where ``zeros`` has the same shape as
            the residual.
        :type loss: torch.nn.Module
        :return: The non-aggregated loss tensor (its exact type and shape
            depend on ``loss``).
        :rtype: torch.Tensor

        :Example:
            >>> loss = torch.nn.MSELoss(reduction="none")
            >>> values = condition.evaluate(
            ...     {"input": input_samples}, solver, loss
            ... )
            >>> # values is non-reduced and has the residual's shape
        """
        # Track gradients on the inputs; presumably the equation needs
        # derivatives of the model output with respect to them.
        samples = batch["input"].requires_grad_(True)
        residual = self.equation.residual(
            samples, solver.forward(samples), solver._params
        )
        # The equation is satisfied when its residual vanishes, so the
        # residual is scored against a zero target of matching shape.
        return loss(residual, torch.zeros_like(residual))