Skip to content

Commit 31f59f9

Browse files
feat(solvers): add Universal Solvers
1 parent 7c108c3 commit 31f59f9

File tree

2 files changed

+312
-5
lines changed

2 files changed

+312
-5
lines changed

neurodiffeq/conditions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,11 @@ def parameterize(self, output_tensor, t):
262262
:rtype: `torch.Tensor`
263263
"""
264264
if self.u_0_prime is None:
265+
if isinstance(self.u_0, list):
266+
parameterized = torch.zeros_like(output_tensor)
267+
for i in range(len(self.u_0)):
268+
parameterized[:, i] = (self.u_0[i] + (1 - torch.exp(-t + self.t_0)) * output_tensor[:, i].view(-1, 1))[:, 0]
269+
return parameterized
265270
return self.u_0 + (1 - torch.exp(-t + self.t_0)) * output_tensor
266271
else:
267272
return self.u_0 + (t - self.t_0) * self.u_0_prime + ((1 - torch.exp(-t + self.t_0)) ** 2) * output_tensor

neurodiffeq/solvers.py

Lines changed: 307 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from .generators import Generator2D
2121
from .generators import GeneratorND
2222
from .function_basis import RealSphericalHarmonics
23-
from .conditions import BaseCondition
23+
from .conditions import BaseCondition, NoCondition
2424
from .neurodiffeq import safe_diff as diff
2525
from .losses import _losses
2626

@@ -113,7 +113,7 @@ class BaseSolver(ABC, PretrainedSolver):
113113
def __init__(self, diff_eqs, conditions,
114114
nets=None, train_generator=None, valid_generator=None, analytic_solutions=None,
115115
optimizer=None, loss_fn=None, n_batches_train=1, n_batches_valid=4,
116-
metrics=None, n_input_units=None, n_output_units=None,
116+
metrics=None, n_input_units=None, n_output_units=None, system_parameters=None,
117117
# deprecated arguments are listed below
118118
shuffle=None, batch_size=None):
119119
# deprecate argument `shuffle`
@@ -130,6 +130,9 @@ def __init__(self, diff_eqs, conditions,
130130
)
131131

132132
self.diff_eqs = diff_eqs
133+
self.system_parameters = {}
134+
if system_parameters is not None:
135+
self.system_parameters = system_parameters
133136
self.conditions = conditions
134137
self.n_funcs = len(conditions)
135138
if nets is None:
@@ -376,7 +379,7 @@ def closure(zero_grad=True):
376379
for name in self.metrics_fn:
377380
value = self.metrics_fn[name](*funcs, *batch).item()
378381
metric_values[name] += value
379-
residuals = self.diff_eqs(*funcs, *batch)
382+
residuals = self.diff_eqs(*funcs, *batch, **self.system_parameters)
380383
residuals = torch.cat(residuals, dim=1)
381384
try:
382385
loss = self.loss_fn(residuals, funcs, batch) + self.additional_loss(residuals, funcs, batch)
@@ -1105,7 +1108,7 @@ class Solver1D(BaseSolver):
11051108

11061109
def __init__(self, ode_system, conditions, t_min=None, t_max=None,
11071110
nets=None, train_generator=None, valid_generator=None, analytic_solutions=None, optimizer=None,
1108-
loss_fn=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_output_units=1,
1111+
loss_fn=None, n_batches_train=1, n_batches_valid=4, metrics=None, n_output_units=1, system_parameters=None,
11091112
# deprecated arguments are listed below
11101113
batch_size=None, shuffle=None):
11111114

@@ -1136,6 +1139,7 @@ def __init__(self, ode_system, conditions, t_min=None, t_max=None,
11361139
metrics=metrics,
11371140
n_input_units=1,
11381141
n_output_units=n_output_units,
1142+
system_parameters=system_parameters,
11391143
shuffle=shuffle,
11401144
batch_size=batch_size,
11411145
)
@@ -1164,11 +1168,12 @@ def get_solution(self, copy=True, best=True):
11641168
:rtype: BaseSolution
11651169
"""
11661170
nets = self.best_nets if best else self.nets
1171+
print(nets)
11671172
conditions = self.conditions
11681173
if copy:
11691174
nets = deepcopy(nets)
11701175
conditions = deepcopy(conditions)
1171-
1176+
print(nets)
11721177
return Solution1D(nets, conditions)
11731178

11741179
def _get_internal_variables(self):
@@ -1590,3 +1595,300 @@ def _get_internal_variables(self):
15901595
'xy_max': self.xy_max,
15911596
})
15921597
return available_variables
1598+
1599+
class _SingleSolver1D(GenericSolver):
    r"""Solver for a single (initial condition, parameter set) pair on top of given base networks.

    Every network trained by this solver is a "head": a base network followed by a
    linear layer owned by the head. The initial condition is not enforced through a
    re-parameterization (``NoCondition`` is used for every function); instead
    :meth:`additional_loss` penalizes the mismatch between the head output at ``t = 0``
    and the stored initial value.

    :param bases:
        Base network(s). A sequence with one module per equation if ``is_system`` is True,
        otherwise a single module.
    :type bases: list[`torch.nn.Module`] or `torch.nn.Module`
    :param HeadClass:
        Class used to build the heads. It is called as ``HeadClass(u_0, base, n_last_layer_head)``
        when ``is_system`` is True and as ``HeadClass(u_0, base, n_last_layer_head, n_outputs)``
        otherwise. If None, :class:`_SingleSolver1D.Head` is used.
    :type HeadClass: type, optional
    :param initial_conditions:
        Initial values. One entry per equation if ``is_system`` is True; otherwise the whole
        sequence is converted to a tensor of shape ``(1, len(initial_conditions))``.
    :type initial_conditions: list
    :param n_last_layer_head:
        Number of input features of the linear layer of each head.
    :type n_last_layer_head: int
    :param diff_eqs:
        The differential equation(s) to solve, forwarded to the parent solver.
    :type diff_eqs: callable
    :param system_parameters:
        Keyword arguments forwarded to the parent solver as ``system_parameters``.
        Defaults to None (no extra parameters).
    :type system_parameters: dict, optional
    :param optimizer:
        Either an optimizer instance (used as is) or an optimizer class, which is
        instantiated on the parameters of all heads.
    :type optimizer: `torch.optim.Optimizer` or type
    :param optimizer_args:
        Positional arguments for the optimizer class (after the parameters).
    :type optimizer_args: tuple, optional
    :param optimizer_kwargs:
        Keyword arguments for the optimizer class. A falsy value means no keyword arguments.
    :type optimizer_kwargs: dict, optional
    :param train_generator:
        Generator for sampling training points. Must not be None.
    :param valid_generator:
        Generator for sampling validation points. Must not be None.
    :param n_batches_train:
        Forwarded to the parent solver.
    :type n_batches_train: int
    :param n_batches_valid:
        Forwarded to the parent solver.
    :type n_batches_valid: int
    :param loss_fn:
        Forwarded to the parent solver.
    :param metrics:
        Forwarded to the parent solver.
    :param is_system:
        If True, build one head per entry of ``initial_conditions`` (each on its own base);
        if False, build a single head with ``len(initial_conditions)`` outputs on ``bases``.
    :type is_system: bool
    :raises ValueError: If either generator is None.
    :raises TypeError: If ``optimizer`` is neither an optimizer instance nor an optimizer class.
    """

    class Head(nn.Module):
        r"""Default head network: ``base`` followed by a single linear layer.

        :param u_0:
            Initial value(s) this head should reproduce; only stored, it is read by
            ``_SingleSolver1D.additional_loss``.
        :param base:
            Module applied to the input before the linear layer.
        :type base: `torch.nn.Module`
        :param n_input:
            Number of input features of the linear layer.
        :type n_input: int
        :param n_output:
            Number of output features of the linear layer. Defaults to 1.
        :type n_output: int
        """

        def __init__(self, u_0, base, n_input, n_output=1):
            super().__init__()
            self.u_0 = u_0
            # Assigning the module registers ``base`` as a sub-module, so its parameters
            # are part of ``self.parameters()`` as well.
            self.base = base
            self.last_layer = nn.Linear(n_input, n_output)

        def forward(self, x):
            return self.last_layer(self.base(x))

    def __init__(self, bases, HeadClass, initial_conditions, n_last_layer_head, diff_eqs,
                 system_parameters=None,
                 optimizer=torch.optim.Adam, optimizer_args=None, optimizer_kwargs={"lr": 1e-3},
                 train_generator=None, valid_generator=None, n_batches_train=1, n_batches_valid=4,
                 loss_fn=None, metrics=None, is_system=False):

        if train_generator is None or valid_generator is None:
            raise ValueError("Train and Valid Generator cannot be None")

        self.num = len(initial_conditions)
        self.bases = bases

        head_cls = self.Head if HeadClass is None else HeadClass
        if is_system:
            # One single-output head per equation, each on top of its own base network.
            self.head = [
                head_cls(initial_conditions[i], self.bases[i], n_last_layer_head)
                for i in range(self.num)
            ]
        else:
            # A single head with one output per initial value on top of the only base network.
            u_0 = torch.Tensor(initial_conditions).view(1, -1)
            self.head = [head_cls(u_0, self.bases, n_last_layer_head, len(initial_conditions))]

        self.optimizer_args = optimizer_args or ()
        # The signature default is kept for interface compatibility; it is copied here so
        # the shared default dict is never stored on (or mutated through) an instance.
        self.optimizer_kwargs = dict(optimizer_kwargs) if optimizer_kwargs else {}

        if isinstance(optimizer, torch.optim.Optimizer):
            self.optimizer = optimizer
        elif isinstance(optimizer, type) and issubclass(optimizer, torch.optim.Optimizer):
            params = chain.from_iterable(n.parameters() for n in self.head)
            self.optimizer = optimizer(params, *self.optimizer_args, **self.optimizer_kwargs)
        else:
            # Report the offending argument; ``self.optimizer`` is not set on this path.
            raise TypeError(f"Unknown optimizer instance/type {optimizer}")

        super().__init__(
            diff_eqs=diff_eqs,
            conditions=[NoCondition()] * self.num,
            train_generator=train_generator,
            valid_generator=valid_generator,
            nets=self.head,
            # Must be a mapping (or None): the training closure calls
            # ``diff_eqs(*funcs, *batch, **system_parameters)``.
            system_parameters=system_parameters,
            optimizer=self.optimizer,
            n_batches_train=n_batches_train,
            n_batches_valid=n_batches_valid,
            loss_fn=loss_fn,
            metrics=metrics,
        )

    def additional_loss(self, residuals, funcs, coords):
        r"""Return the initial-condition penalty summed over all heads.

        For every network the squared difference between its stored ``u_0`` and its
        output at a ``(1, 1)`` zero input is averaged and accumulated.
        The arguments are accepted for interface compatibility and are not used.

        :return: The accumulated penalty (``0`` if there are no networks).
        """
        # NOTE(review): the input is hard-coded to t = 0 and created on the default
        # device -- assumes the initial time is 0 and the nets live there; confirm.
        loss = 0
        for net in self.nets:
            out = net(torch.zeros((1, 1)))
            loss += ((net.u_0 - out) ** 2).mean()
        return loss
1667+
1668+
1669+
class UniversalSolver1D(ABC):
    r"""A solver class for solving a family of ODEs (for different initial conditions and parameters)

    Typical usage is to :meth:`build` and :meth:`fit` with ``build_source=True`` /
    ``freeze_source=False`` to train base networks together with their heads, and then to
    :meth:`build` and :meth:`fit` again with ``build_source=False`` / ``freeze_source=True``
    to train only new heads on top of the frozen bases.

    :param diff_eqs:
        The ODE system to solve; it is forwarded unchanged to every underlying solver.
    :type diff_eqs: callable
    :param is_system:
        If True, one base network is created per equation (``len(u_0s[0])`` networks);
        if False, a single base network is shared by all outputs. Defaults to True.
    :type is_system: bool
    """

    class Base(nn.Module):
        r"""Default base network: three fully connected layers (1 -> 10 -> 10 -> 10), each followed by ``tanh``."""

        def __init__(self):
            super().__init__()
            self.linear_1 = nn.Linear(1, 10)
            self.linear_2 = nn.Linear(10, 10)
            self.linear_3 = nn.Linear(10, 10)

        def forward(self, x):
            x = torch.tanh(self.linear_1(x))
            x = torch.tanh(self.linear_2(x))
            x = torch.tanh(self.linear_3(x))
            return x

    def __init__(self, diff_eqs, is_system = True):

        self.diff_eqs = diff_eqs
        self.is_system = is_system

        self.t_min = None
        self.t_max = None
        self.train_generator = None
        self.valid_generator = None

        # State created by ``build``; None means "not built yet".
        self.bases = None
        self.solvers_base = None
        self.solvers_head = None
        # Parameters whose gradients were switched off by ``fit(freeze_source=True)``.
        self._frozen_params = []

    def build(self, u_0s=None,
              system_parameters=[{}],
              BaseClass=Base,
              HeadClass=None,
              n_last_layer_head=10,
              build_source=False,
              optimizer=torch.optim.Adam,
              optimizer_args=None, optimizer_kwargs={"lr": 1e-3},
              t_min=None,
              t_max=None,
              train_generator=None,
              valid_generator=None,
              n_batches_train=1,
              n_batches_valid=4,
              loss_fn=None,
              metrics=None):
        r"""Create one solver per (initial condition, parameter set) pair.

        :param u_0s:
            List of initial conditions. If ``is_system`` is True, every entry must itself
            be a sequence with one initial value per equation.
        :type u_0s: list
        :param system_parameters:
            List of dictionaries of parameters for which the solver will be trained
        :type system_parameters: list[dict]
        :param BaseClass:
            Neural network class for base networks; instantiated without arguments.
            Only used if ``build_source`` is True.
        :type BaseClass: type
        :param HeadClass:
            Neural network class for head networks. If None, the default head of the
            underlying solver is used.
        :type HeadClass: type, optional
        :param n_last_layer_head:
            Number of input features of the last (linear) layer of each head,
            i.e. the output width of the base networks (10 for the default ``Base``).
        :type n_last_layer_head: int
        :param build_source:
            If True, create new base networks and store the solvers in ``solvers_base``.
            If False, reuse the existing base networks and store the solvers in ``solvers_head``.
        :type build_source: bool
        :param optimizer:
            Optimizer to be used for training.
            Either an optimizer instance or an optimizer class. Defaults to ``torch.optim.Adam``.
        :type optimizer: ``torch.optim.Optimizer`` or type, optional
        :param optimizer_args:
            Positional arguments passed to the optimizer class.
        :type optimizer_args: tuple, optional
        :param optimizer_kwargs:
            Keyword arguments passed to the optimizer class. Defaults to ``{"lr": 1e-3}``.
        :type optimizer_kwargs: dict, optional
        :param t_min:
            Lower bound of input (start time).
            Ignored if ``train_generator`` and ``valid_generator`` are both set.
        :type t_min: float, optional
        :param t_max:
            Upper bound of input (end time).
            Ignored if ``train_generator`` and ``valid_generator`` are both set.
        :type t_max: float, optional
        :param train_generator:
            Generator for sampling training points,
            which must provide a ``.get_examples()`` method and a ``.size`` field.
            ``train_generator`` must be specified if ``t_min`` and ``t_max`` are not set.
        :type train_generator: `neurodiffeq.generators.BaseGenerator`, optional
        :param valid_generator:
            Generator for sampling validation points,
            which must provide a ``.get_examples()`` method and a ``.size`` field.
            ``valid_generator`` must be specified if ``t_min`` and ``t_max`` are not set.
        :type valid_generator: `neurodiffeq.generators.BaseGenerator`, optional
        :param n_batches_train:
            Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
            Defaults to 1.
        :type n_batches_train: int, optional
        :param n_batches_valid:
            Number of batches to validate in every epoch, where batch-size equals ``valid_generator.size``.
            Defaults to 4.
        :type n_batches_valid: int, optional
        :param loss_fn:
            The loss function used for training.

            - If a str, must be present in the keys of `neurodiffeq.losses._losses`.
            - If a `torch.nn.modules.loss._Loss` instance, just pass the instance.
            - If any other callable, it must map
              A) a residual tensor (shape `(n_points, n_equations)`),
              B) a function values tuple (length `n_funcs`, each element a tensor of shape `(n_points, 1)`), and
              C) a coordinate values tuple (length `n_coords`, each element a tensor of shape `(n_coords, 1)`
              to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
              so that backpropagation can be performed.

        :type loss_fn:
            str or `torch.nn.modules.loss._Loss` or callable
        :param metrics:
            Additional metrics to be logged (besides loss). ``metrics`` should be a dict where

            - Keys are metric names (e.g. 'analytic_mse');
            - Values are functions (callables) that computes the metric value.
              These functions must accept the same input as the differential equation ``diff_eqs``.

        :type metrics: dict[str, callable], optional
        :raises ValueError: If ``u_0s`` is None or no generators are available.
        :raises RuntimeError: If ``build_source`` is False and no base networks exist yet.
        """

        self.u_0s = u_0s
        # The signature defaults are kept for interface compatibility; copies are stored
        # so the shared default objects can never be mutated through an instance.
        self.system_parameters = [dict(params) for params in system_parameters]
        self.n_last_layer_head = n_last_layer_head

        if t_min is not None:
            self.t_min = t_min
        if t_max is not None:
            self.t_max = t_max

        if self.t_min is not None and self.t_max is not None:
            self.train_generator = Generator1D(32, t_min=self.t_min, t_max=self.t_max, method='equally-spaced-noisy')
            self.valid_generator = Generator1D(32, t_min=self.t_min, t_max=self.t_max, method='equally-spaced-noisy')

        # Explicitly passed generators take precedence over the ones derived from the bounds.
        if train_generator is not None:
            self.train_generator = train_generator
        if valid_generator is not None:
            self.valid_generator = valid_generator

        if self.u_0s is None:
            raise ValueError("ICs must be specified")
        if self.train_generator is None or self.valid_generator is None:
            raise ValueError(
                "Train and valid generators cannot be None. Either provide `t_min` and `t_max` "
                "or provide the generators as arguments"
            )

        self.optimizer = optimizer
        self.optimizer_args = optimizer_args or ()
        self.optimizer_kwargs = dict(optimizer_kwargs) if optimizer_kwargs else {}

        solver_options = dict(
            HeadClass=HeadClass,
            n_batches_train=n_batches_train,
            n_batches_valid=n_batches_valid,
            loss_fn=loss_fn,
            metrics=metrics,
        )

        if build_source:
            if self.is_system:
                self.bases = [BaseClass() for _ in range(len(u_0s[0]))]
            else:
                self.bases = BaseClass()
            self.solvers_base = self._make_solvers(**solver_options)
        else:
            if getattr(self, 'bases', None) is None:
                raise RuntimeError(
                    "No base networks available; call `build(build_source=True)` "
                    "before building heads with `build_source=False`"
                )
            self.solvers_head = self._make_solvers(**solver_options)

    def _make_solvers(self, HeadClass, n_batches_train, n_batches_valid, loss_fn, metrics):
        r"""Return one ``_SingleSolver1D`` per (initial condition, parameter set) pair, using the current state."""
        return [
            _SingleSolver1D(
                bases=self.bases,
                HeadClass=HeadClass,
                initial_conditions=u_0,
                n_last_layer_head=self.n_last_layer_head,
                diff_eqs=self.diff_eqs,
                train_generator=self.train_generator,
                valid_generator=self.valid_generator,
                system_parameters=params,
                optimizer=self.optimizer,
                optimizer_args=self.optimizer_args,
                optimizer_kwargs=self.optimizer_kwargs,
                n_batches_train=n_batches_train,
                n_batches_valid=n_batches_valid,
                loss_fn=loss_fn,
                metrics=metrics,
                is_system=self.is_system,
            )
            for u_0 in self.u_0s
            for params in self.system_parameters
        ]

    def _built(self, attr, how):
        r"""Return the attribute ``attr`` or raise a descriptive error if it was never built."""
        value = getattr(self, attr, None)
        if value is None:
            raise RuntimeError(f"`{attr}` is not available; call `{how}` first")
        return value

    def _freeze_bases(self):
        r"""Disable gradients of all base parameters, remembering the ones that were trainable."""
        frozen = getattr(self, '_frozen_params', [])
        base_nets = self.bases if self.is_system else [self.bases]
        for net in base_nets:
            for param in net.parameters():
                if param.requires_grad:
                    param.requires_grad = False
                    frozen.append(param)
        self._frozen_params = frozen

    def _unfreeze_bases(self):
        r"""Re-enable gradients of exactly those parameters that ``_freeze_bases`` disabled."""
        for param in getattr(self, '_frozen_params', []):
            param.requires_grad = True
        self._frozen_params = []

    def fit(self, epochs=10, freeze_source=True):
        r"""Train the solvers created by :meth:`build`.

        :param epochs:
            Number of epochs for training
        :type epochs: int
        :param freeze_source:
            If True, freeze the base networks and train the solvers in ``solvers_head``.
            If False, train the solvers in ``solvers_base`` (bases included); parameters
            frozen by an earlier call are made trainable again first.
        :type freeze_source: bool
        :raises RuntimeError: If the required solvers have not been built.
        """

        if not freeze_source:
            solvers = self._built('solvers_base', 'build(build_source=True)')
            # Without this, a previous `fit(freeze_source=True)` would leave the bases
            # frozen forever and this branch would silently not train them.
            self._unfreeze_bases()
        else:
            solvers = self._built('solvers_head', 'build(build_source=False)')
            self._freeze_bases()

        for solver in solvers:
            solver.fit(max_epochs=epochs)

    def get_solution(self, base=False):
        r"""Return the solutions of all solvers of one kind.

        :param base:
            Boolean value indicating whether to get solutions for those conditions for which the base
            was trained or solutions for those conditions for which only the last layer was trained
        :type base: bool
        :rtype: list[BaseSolution]
        :raises RuntimeError: If the requested solvers have not been built.
        """

        if base:
            solvers = self._built('solvers_base', 'build(build_source=True)')
        else:
            solvers = self._built('solvers_head', 'build(build_source=False)')
        return [solver.get_solution() for solver in solvers]

0 commit comments

Comments
 (0)