Skip to content

Commit fd43d7f

Browse files
feat(solvers): add Universal Solvers
1 parent a3df9fd commit fd43d7f

File tree

2 files changed

+312
-5
lines changed

2 files changed

+312
-5
lines changed

neurodiffeq/conditions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,11 @@ def parameterize(self, output_tensor, t):
202202
:rtype: `torch.Tensor`
203203
"""
204204
if self.u_0_prime is None:
205+
if isinstance(self.u_0, list):
206+
parameterized = torch.zeros_like(output_tensor)
207+
for i in range(len(self.u_0)):
208+
parameterized[:, i] = (self.u_0[i] + (1 - torch.exp(-t + self.t_0)) * output_tensor[:, i].view(-1, 1))[:, 0]
209+
return parameterized
205210
return self.u_0 + (1 - torch.exp(-t + self.t_0)) * output_tensor
206211
else:
207212
return self.u_0 + (t - self.t_0) * self.u_0_prime + ((1 - torch.exp(-t + self.t_0)) ** 2) * output_tensor

neurodiffeq/solvers.py

Lines changed: 307 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from .generators import Generator2D
2121
from .generators import GeneratorND
2222
from .function_basis import RealSphericalHarmonics
23-
from .conditions import BaseCondition
23+
from .conditions import BaseCondition, NoCondition
2424
from .neurodiffeq import safe_diff as diff
2525
from .losses import _losses
2626

@@ -113,7 +113,7 @@ class BaseSolver(ABC, PretrainedSolver):
113113
def __init__(self, diff_eqs, conditions,
114114
nets=None, train_generator=None, valid_generator=None, analytic_solutions=None,
115115
optimizer=None, loss_fn=None, n_batches_train=1, n_batches_valid=4,
116-
metrics=None, n_input_units=None, n_output_units=None,
116+
metrics=None, n_input_units=None, n_output_units=None, system_parameters=None,
117117
# deprecated arguments are listed below
118118
shuffle=None, batch_size=None):
119119
# deprecate argument `shuffle`
@@ -130,6 +130,9 @@ def __init__(self, diff_eqs, conditions,
130130
)
131131

132132
self.diff_eqs = diff_eqs
133+
self.system_parameters = {}
134+
if system_parameters is not None:
135+
self.system_parameters = system_parameters
133136
self.conditions = conditions
134137
self.n_funcs = len(conditions)
135138
if nets is None:
@@ -370,7 +373,7 @@ def closure(zero_grad=True):
370373
for name in self.metrics_fn:
371374
value = self.metrics_fn[name](*funcs, *batch).item()
372375
metric_values[name] += value
373-
residuals = self.diff_eqs(*funcs, *batch)
376+
residuals = self.diff_eqs(*funcs, *batch, **self.system_parameters)
374377
residuals = torch.cat(residuals, dim=1)
375378
try:
376379
loss = self.loss_fn(residuals, funcs, batch) + self.additional_loss(residuals, funcs, batch)
@@ -1095,7 +1098,7 @@ class Solver1D(BaseSolver):
10951098

10961099
def __init__(self, ode_system, conditions, t_min=None, t_max=None,
10971100
nets=None, train_generator=None, valid_generator=None, analytic_solutions=None, optimizer=None,
1098-
loss_fn=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_output_units=1,
1101+
loss_fn=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_output_units=1, system_parameters=None,
10991102
# deprecated arguments are listed below
11001103
batch_size=None, shuffle=None):
11011104

@@ -1126,6 +1129,7 @@ def __init__(self, ode_system, conditions, t_min=None, t_max=None,
11261129
metrics=metrics,
11271130
n_input_units=1,
11281131
n_output_units=n_output_units,
1132+
system_parameters=system_parameters,
11291133
shuffle=shuffle,
11301134
batch_size=batch_size,
11311135
)
@@ -1154,11 +1158,12 @@ def get_solution(self, copy=True, best=True):
11541158
:rtype: BaseSolution
11551159
"""
11561160
nets = self.best_nets if best else self.nets
1161+
print(nets)
11571162
conditions = self.conditions
11581163
if copy:
11591164
nets = deepcopy(nets)
11601165
conditions = deepcopy(conditions)
1161-
1166+
print(nets)
11621167
return Solution1D(nets, conditions)
11631168

11641169
def _get_internal_variables(self):
@@ -1563,3 +1568,300 @@ def _get_internal_variables(self):
15631568
'xy_max': self.xy_max,
15641569
})
15651570
return available_variables
1571+
1572+
class _SingleSolver1D(GenericSolver):
    r"""Solver for a single member (one initial condition, one parameter set) of an ODE family.

    Every network handed to the parent solver is a *head*: a trainable linear layer stacked on top of
    a base network supplied by the caller. All conditions passed to the parent are ``NoCondition``
    instances, so the initial condition is not enforced by re-parameterization; it is encouraged
    through the penalty computed in :meth:`additional_loss`.

    :param bases:
        Base network(s). A sequence indexed per equation when ``is_system`` is True,
        otherwise a single module shared by the one multi-output head.
    :param HeadClass:
        Class used to build the heads; it is called as ``HeadClass(u_0, base, n_input[, n_output])``.
        If None, :class:`_SingleSolver1D.Head` is used.
    :param initial_conditions:
        Initial values, one entry per function to solve for.
    :param n_last_layer_head:
        Number of input features of the head's linear layer
        (presumably the output width of the base network -- confirm against the base in use).
    :type n_last_layer_head: int
    :param diff_eqs:
        The differential equation(s); forwarded unchanged to the parent solver.
    :type diff_eqs: callable
    :param system_parameters:
        Keyword arguments forwarded to the parent solver as ``system_parameters``.
        Defaults to None. The previous default ``[{}]`` was a list and cannot be unpacked with ``**``.
    :type system_parameters: dict, optional
    :param optimizer:
        Either a ``torch.optim.Optimizer`` instance (used as is) or an optimizer class,
        which is instantiated on the parameters of all heads.
    :param optimizer_args:
        Extra positional arguments for the optimizer class. Defaults to ``()``.
    :param optimizer_kwargs:
        Extra keyword arguments for the optimizer class. None selects ``{"lr": 1e-3}``.
    :type optimizer_kwargs: dict, optional
    :param train_generator:
        Generator for training points. Required.
    :param valid_generator:
        Generator for validation points. Required.
    :param n_batches_train:
        Forwarded to the parent solver. Defaults to 1.
    :type n_batches_train: int, optional
    :param n_batches_valid:
        Forwarded to the parent solver. Defaults to 4.
    :type n_batches_valid: int, optional
    :param loss_fn:
        Forwarded to the parent solver.
    :param metrics:
        Forwarded to the parent solver.
    :param is_system:
        If True, one single-output head is built per initial condition (each on ``bases[i]``).
        If False, a single head with ``len(initial_conditions)`` outputs is built on ``bases``.
    :type is_system: bool
    :raises ValueError: If either generator is None.
    :raises TypeError: If ``optimizer`` is neither an optimizer instance nor an optimizer class.
    """

    class Head(nn.Module):
        r"""Default head: a base network followed by one linear layer.

        :param u_0: Initial value(s) this head is expected to reproduce; read by ``additional_loss``.
        :param base: Module applied to the input before the linear layer.
        :param n_input: Number of input features of the linear layer.
        :param n_output: Number of output features of the linear layer. Defaults to 1.
        """

        def __init__(self, u_0, base, n_input, n_output=1):
            super().__init__()
            # Stored as a plain attribute (neither a parameter nor a registered buffer).
            self.u_0 = u_0
            self.base = base
            self.last_layer = nn.Linear(n_input, n_output)

        def forward(self, x):
            return self.last_layer(self.base(x))

    def __init__(self, bases, HeadClass, initial_conditions, n_last_layer_head, diff_eqs,
                 system_parameters=None,
                 optimizer=torch.optim.Adam, optimizer_args=None, optimizer_kwargs=None,
                 train_generator=None, valid_generator=None, n_batches_train=1, n_batches_valid=4,
                 loss_fn=None, metrics=None, is_system=False):

        if train_generator is None or valid_generator is None:
            raise ValueError("Train and Valid Generator cannot be None")

        # Resolve the default here rather than in the signature so no dict is shared between solvers.
        if optimizer_kwargs is None:
            optimizer_kwargs = {"lr": 1e-3}

        self.num = len(initial_conditions)
        self.bases = bases

        head_cls = self.Head if HeadClass is None else HeadClass
        if is_system:
            # One single-output head per function, each on its own base network.
            self.head = [
                head_cls(u_0, self.bases[i], n_last_layer_head)
                for i, u_0 in enumerate(initial_conditions)
            ]
        else:
            # A single head producing all outputs at once; u_0 becomes a (1, num) row tensor.
            self.head = [
                head_cls(
                    torch.Tensor(initial_conditions).view(1, -1),
                    self.bases,
                    n_last_layer_head,
                    len(initial_conditions),
                )
            ]

        self.optimizer_args = optimizer_args or ()
        self.optimizer_kwargs = optimizer_kwargs or {}

        if isinstance(optimizer, torch.optim.Optimizer):
            self.optimizer = optimizer
        elif isinstance(optimizer, type) and issubclass(optimizer, torch.optim.Optimizer):
            # The isinstance(type) guard keeps issubclass() from raising its own,
            # unrelated TypeError on non-class values.
            params = chain.from_iterable(n.parameters() for n in self.head)
            self.optimizer = optimizer(params, *self.optimizer_args, **self.optimizer_kwargs)
        else:
            # Report the offending argument; ``self.optimizer`` does not exist on this path.
            raise TypeError(f"Unknown optimizer instance/type {optimizer}")

        super().__init__(
            diff_eqs=diff_eqs,
            conditions=[NoCondition()] * self.num,
            train_generator=train_generator,
            valid_generator=valid_generator,
            nets=self.head,
            system_parameters=system_parameters,
            optimizer=self.optimizer,
            n_batches_train=n_batches_train,
            n_batches_valid=n_batches_valid,
            loss_fn=loss_fn,
            metrics=metrics,
        )

    def additional_loss(self, residuals, funcs, coords):
        r"""Penalty that pulls each network's output at the zero input towards its ``u_0``.

        For every network the squared difference between ``net.u_0`` and the output at a ``(1, 1)``
        zero tensor is averaged, and the averages are summed.

        :param residuals: Unused.
        :param funcs: Unused.
        :param coords: Unused.
        :return: Sum over networks of the mean squared initial-condition mismatch.
        """
        # NOTE(review): the probe point is hard-coded to 0 and created on the default device;
        # this presumably assumes the initial time is t = 0 -- confirm.
        loss = 0
        for net in self.nets:
            out = net(torch.zeros((1, 1)))
            loss += ((net.u_0 - out) ** 2).mean()
        return loss
1640+
1641+
1642+
class UniversalSolver1D(ABC):
    r"""A solver class for solving a family of ODEs (for different initial conditions and parameters).

    One ``_SingleSolver1D`` is created per combination of initial condition and parameter set.
    All of them are built on the same base network(s); :meth:`fit` either trains the solvers
    created with ``build_source=True`` or freezes the base network(s) and trains the solvers
    created with ``build_source=False``.

    :param diff_eqs:
        The ODE system to solve, which maps a torch.Tensor to a tuple of ODE residuals,
        both the input and output must have shape (n_samples, 1).
    :type diff_eqs: callable
    :param is_system:
        If True, one base network is created per entry of the first initial condition
        and each solver gets one head per function.
        If False, a single base network with one multi-output head per solver is used.
        Defaults to True.
    :type is_system: bool
    """

    class Base(nn.Module):
        r"""Default base network: three linear layers (1 -> 10 -> 10 -> 10), each followed by ``tanh``."""

        def __init__(self):
            super().__init__()
            self.linear_1 = nn.Linear(1, 10)
            self.linear_2 = nn.Linear(10, 10)
            self.linear_3 = nn.Linear(10, 10)

        def forward(self, x):
            for layer in (self.linear_1, self.linear_2, self.linear_3):
                x = torch.tanh(layer(x))
            return x

    def __init__(self, diff_eqs, is_system = True):

        self.diff_eqs = diff_eqs
        self.is_system = is_system

        self.t_min = None
        self.t_max = None
        self.train_generator = None
        self.valid_generator = None

        # Populated by ``build``; None means "not built yet" so that misuse gives a clear error
        # instead of an AttributeError.
        self.bases = None
        self.solvers_base = None
        self.solvers_head = None

    def build(self, u_0s=None,
              system_parameters=None,
              BaseClass=Base,
              HeadClass=None,
              n_last_layer_head=10,
              build_source=False,
              optimizer=torch.optim.Adam,
              optimizer_args=None, optimizer_kwargs=None,
              t_min=None,
              t_max=None,
              train_generator=None,
              valid_generator=None,
              n_batches_train=1,
              n_batches_valid=4,
              loss_fn=None,
              metrics=None):
        r"""Create one solver per (initial condition, parameter set) pair.

        :param u_0s:
            List of initial conditions; one solver is created for each entry
            (and for each entry of ``system_parameters``). Required.
        :type u_0s: list
        :param system_parameters:
            List of dictionaries of parameters for which the solver will be trained.
            Defaults to ``[{}]`` (a single, empty parameter set).
        :type system_parameters: list[dict]
        :param BaseClass:
            Neural network class for base networks; instantiated without arguments.
            Only used when ``build_source`` is True.
        :type BaseClass: type[torch.nn.Module]
        :param HeadClass:
            Neural network class for the heads, forwarded to every solver.
            If None, the solver's default head is used.
        :type HeadClass: type[torch.nn.Module], optional
        :param n_last_layer_head:
            Number of neurons in the last layer for each network
        :type n_last_layer_head: int
        :param build_source:
            If True, create fresh base networks and store the resulting solvers in ``solvers_base``.
            If False, reuse the existing base networks and store the solvers in ``solvers_head``.
        :type build_source: bool
        :param optimizer:
            Optimizer (class or instance) to be used for training; forwarded to every solver.
            Defaults to ``torch.optim.Adam``.
        :type optimizer: ``torch.optim.Optimizer``, optional
        :param optimizer_args:
            Extra positional arguments for the optimizer class. Defaults to ``()``.
        :type optimizer_args: tuple, optional
        :param optimizer_kwargs:
            Extra keyword arguments for the optimizer class. None selects ``{"lr": 1e-3}``.
            Every solver receives its own copy.
        :type optimizer_kwargs: dict, optional
        :param t_min:
            Lower bound of input (start time).
            Ignored if ``train_generator`` and ``valid_generator`` are both set.
        :type t_min: float, optional
        :param t_max:
            Upper bound of input (end time).
            Ignored if ``train_generator`` and ``valid_generator`` are both set.
        :type t_max: float, optional
        :param train_generator:
            Generator for sampling training points,
            which must provide a ``.get_examples()`` method and a ``.size`` field.
            ``train_generator`` must be specified if ``t_min`` and ``t_max`` are not set.
        :type train_generator: `neurodiffeq.generators.BaseGenerator`, optional
        :param valid_generator:
            Generator for sampling validation points,
            which must provide a ``.get_examples()`` method and a ``.size`` field.
            ``valid_generator`` must be specified if ``t_min`` and ``t_max`` are not set.
        :type valid_generator: `neurodiffeq.generators.BaseGenerator`, optional
        :param n_batches_train:
            Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
            Defaults to 1.
        :type n_batches_train: int, optional
        :param n_batches_valid:
            Number of batches to validate in every epoch, where batch-size equals ``valid_generator.size``.
            Defaults to 4.
        :type n_batches_valid: int, optional
        :param loss_fn:
            The loss function used for training.

            - If a str, must be present in the keys of `neurodiffeq.losses._losses`.
            - If a `torch.nn.modules.loss._Loss` instance, just pass the instance.
            - If any other callable, it must map
              A) a residual tensor (shape `(n_points, n_equations)`),
              B) a function values tuple (length `n_funcs`, each element a tensor of shape `(n_points, 1)`), and
              C) a coordinate values tuple (length `n_coords`, each element a tensor of shape `(n_coords, 1)`)
              to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the
              computational graph, so that backpropagation can be performed.

        :type loss_fn:
            str or `torch.nn.modules.loss._Loss` or callable
        :param metrics:
            Additional metrics to be logged (besides loss). ``metrics`` should be a dict where

            - Keys are metric names (e.g. 'analytic_mse');
            - Values are functions (callables) that computes the metric value.
              These functions must accept the same input as the differential equation ``diff_eqs``.

        :type metrics: dict[str, callable], optional
        :raises ValueError: If ``u_0s`` is None or no generators are available.
        :raises RuntimeError: If ``build_source`` is False but no base networks have been built yet.
        """

        # Defaults are resolved here (not in the signature) so that no list/dict is shared across calls.
        if system_parameters is None:
            system_parameters = [{}]
        if optimizer_kwargs is None:
            optimizer_kwargs = {"lr": 1e-3}

        self.u_0s = u_0s
        self.system_parameters = system_parameters
        self.n_last_layer_head = n_last_layer_head

        if t_min is not None:
            self.t_min = t_min
        if t_max is not None:
            self.t_max = t_max

        if self.t_min is not None and self.t_max is not None:
            self.train_generator = Generator1D(32, t_min=self.t_min, t_max=self.t_max, method='equally-spaced-noisy')
            self.valid_generator = Generator1D(32, t_min=self.t_min, t_max=self.t_max, method='equally-spaced-noisy')

        # Explicitly supplied generators take precedence over the ones derived from t_min / t_max.
        if train_generator is not None:
            self.train_generator = train_generator
        if valid_generator is not None:
            self.valid_generator = valid_generator

        if self.u_0s is None:
            raise ValueError("ICs must be specified")
        if self.train_generator is None or self.valid_generator is None:
            raise ValueError(
                "Train and valid generators cannot be None. Either provide `t_min` and `t_max` "
                "or provide the generators as arguments"
            )

        self.optimizer = optimizer
        self.optimizer_args = optimizer_args or ()
        self.optimizer_kwargs = optimizer_kwargs or {}

        if build_source:
            if self.is_system:
                # One base network per function of the system (size taken from the first IC).
                self.bases = [BaseClass() for _ in range(len(u_0s[0]))]
            else:
                self.bases = BaseClass()
            self.solvers_base = self._make_solvers(
                HeadClass, n_batches_train, n_batches_valid, loss_fn, metrics,
            )
        else:
            if self.bases is None:
                raise RuntimeError(
                    "No base networks available; call `build(..., build_source=True)` first"
                )
            self.solvers_head = self._make_solvers(
                HeadClass, n_batches_train, n_batches_valid, loss_fn, metrics,
            )

    def _make_solvers(self, HeadClass, n_batches_train, n_batches_valid, loss_fn, metrics):
        r"""Return one ``_SingleSolver1D`` per (initial condition, parameter set) pair on the current bases."""
        return [
            _SingleSolver1D(
                bases=self.bases,
                HeadClass=HeadClass,
                initial_conditions=u_0,
                n_last_layer_head=self.n_last_layer_head,
                diff_eqs=self.diff_eqs,
                train_generator=self.train_generator,
                valid_generator=self.valid_generator,
                system_parameters=params,
                optimizer=self.optimizer,
                optimizer_args=self.optimizer_args,
                # Copy so that solvers never share (and cannot mutate) one kwargs dict.
                optimizer_kwargs=dict(self.optimizer_kwargs),
                n_batches_train=n_batches_train,
                n_batches_valid=n_batches_valid,
                loss_fn=loss_fn,
                metrics=metrics,
                is_system=self.is_system,
            )
            for u_0 in self.u_0s
            for params in self.system_parameters
        ]

    def fit(self, epochs=10, freeze_source=True):
        r"""Train the solvers created by :meth:`build`.

        :param epochs:
            Number of epochs for training
        :type epochs: int
        :param freeze_source:
            If False, train the solvers in ``solvers_base`` (built with ``build_source=True``).
            If True, set ``requires_grad = False`` on every base-network parameter and train
            the solvers in ``solvers_head`` (built with ``build_source=False``).
        :type freeze_source: bool
        :raises RuntimeError: If the solvers required by the chosen mode have not been built.
        """

        if not freeze_source:
            if self.solvers_base is None:
                raise RuntimeError("Base solvers not built; call `build(..., build_source=True)` first")
            # NOTE(review): parameters frozen by an earlier `fit(freeze_source=True)` stay frozen here.
            for solver in self.solvers_base:
                solver.fit(max_epochs=epochs)
            return

        if self.bases is None or self.solvers_head is None:
            raise RuntimeError("Head solvers not built; call `build(..., build_source=False)` first")

        base_nets = self.bases if self.is_system else [self.bases]
        for net in base_nets:
            for param in net.parameters():
                param.requires_grad = False
        for solver in self.solvers_head:
            solver.fit(max_epochs=epochs)

    def get_solution(self, base=False):
        r"""Return the solutions of all solvers of one kind.

        :param base:
            Boolean value indicating whether to get solutions for those conditions for which the base
            was trained or solutions for those conditions for which only the last layer was trained
        :type base: bool
        :return: One solution per solver, in the order the solvers were built.
        :rtype: list[BaseSolution]
        :raises RuntimeError: If the requested solvers have not been built.
        """

        solvers = self.solvers_base if base else self.solvers_head
        if solvers is None:
            raise RuntimeError("Requested solvers have not been built; call `build` first")
        return [solver.get_solution() for solver in solvers]

0 commit comments

Comments
 (0)