20
20
from .generators import Generator2D
21
21
from .generators import GeneratorND
22
22
from .function_basis import RealSphericalHarmonics
23
- from .conditions import BaseCondition
23
+ from .conditions import BaseCondition , NoCondition
24
24
from .neurodiffeq import safe_diff as diff
25
25
from .losses import _losses
26
26
@@ -113,7 +113,7 @@ class BaseSolver(ABC, PretrainedSolver):
113
113
def __init__ (self , diff_eqs , conditions ,
114
114
nets = None , train_generator = None , valid_generator = None , analytic_solutions = None ,
115
115
optimizer = None , loss_fn = None , n_batches_train = 1 , n_batches_valid = 4 ,
116
- metrics = None , n_input_units = None , n_output_units = None ,
116
+ metrics = None , n_input_units = None , n_output_units = None , system_parameters = None ,
117
117
# deprecated arguments are listed below
118
118
shuffle = None , batch_size = None ):
119
119
# deprecate argument `shuffle`
@@ -130,6 +130,9 @@ def __init__(self, diff_eqs, conditions,
130
130
)
131
131
132
132
self .diff_eqs = diff_eqs
133
+ self .system_parameters = {}
134
+ if system_parameters is not None :
135
+ self .system_parameters = system_parameters
133
136
self .conditions = conditions
134
137
self .n_funcs = len (conditions )
135
138
if nets is None :
@@ -370,7 +373,7 @@ def closure(zero_grad=True):
370
373
for name in self .metrics_fn :
371
374
value = self .metrics_fn [name ](* funcs , * batch ).item ()
372
375
metric_values [name ] += value
373
- residuals = self .diff_eqs (* funcs , * batch )
376
+ residuals = self .diff_eqs (* funcs , * batch , ** self . system_parameters )
374
377
residuals = torch .cat (residuals , dim = 1 )
375
378
try :
376
379
loss = self .loss_fn (residuals , funcs , batch ) + self .additional_loss (residuals , funcs , batch )
@@ -1095,7 +1098,7 @@ class Solver1D(BaseSolver):
1095
1098
1096
1099
def __init__ (self , ode_system , conditions , t_min = None , t_max = None ,
1097
1100
nets = None , train_generator = None , valid_generator = None , analytic_solutions = None , optimizer = None ,
1098
- loss_fn = None , n_batches_train = 1 , n_batches_valid = 4 , metrics = None , n_output_units = 1 ,
1101
+ loss_fn = None , n_batches_train = 1 , n_batches_valid = 4 , metrics = None , n_output_units = 1 , system_parameters = None ,
1099
1102
# deprecated arguments are listed below
1100
1103
batch_size = None , shuffle = None ):
1101
1104
@@ -1126,6 +1129,7 @@ def __init__(self, ode_system, conditions, t_min=None, t_max=None,
1126
1129
metrics = metrics ,
1127
1130
n_input_units = 1 ,
1128
1131
n_output_units = n_output_units ,
1132
+ system_parameters = system_parameters ,
1129
1133
shuffle = shuffle ,
1130
1134
batch_size = batch_size ,
1131
1135
)
@@ -1154,11 +1158,12 @@ def get_solution(self, copy=True, best=True):
1154
1158
:rtype: BaseSolution
1155
1159
"""
1156
1160
nets = self .best_nets if best else self .nets
1161
+ print (nets )
1157
1162
conditions = self .conditions
1158
1163
if copy :
1159
1164
nets = deepcopy (nets )
1160
1165
conditions = deepcopy (conditions )
1161
-
1166
+ print ( nets )
1162
1167
return Solution1D (nets , conditions )
1163
1168
1164
1169
def _get_internal_variables (self ):
@@ -1563,3 +1568,300 @@ def _get_internal_variables(self):
1563
1568
'xy_max' : self .xy_max ,
1564
1569
})
1565
1570
return available_variables
1571
+
1572
+ class _SingleSolver1D (GenericSolver ):
1573
+
1574
+ class Head (nn .Module ):
1575
+ def __init__ (self , u_0 , base , n_input , n_output = 1 ):
1576
+ super ().__init__ ()
1577
+ self .u_0 = u_0
1578
+ self .base = base
1579
+ self .last_layer = nn .Linear (n_input , n_output )
1580
+
1581
+ def forward (self , x ):
1582
+ x = self .base (x )
1583
+ x = self .last_layer (x )
1584
+ return x
1585
+
1586
+ def __init__ (self , bases , HeadClass , initial_conditions , n_last_layer_head , diff_eqs ,
1587
+ system_parameters = [{}],
1588
+ optimizer = torch .optim .Adam , optimizer_args = None , optimizer_kwargs = {"lr" :1e-3 },
1589
+ train_generator = None , valid_generator = None , n_batches_train = 1 , n_batches_valid = 4 ,
1590
+ loss_fn = None , metrics = None , is_system = False ):
1591
+
1592
+ if train_generator is None or valid_generator is None :
1593
+ raise Exception (f"Train and Valid Generator cannot be None" )
1594
+
1595
+ self .num = len (initial_conditions )
1596
+ self .bases = bases
1597
+ if HeadClass is None :
1598
+ if is_system :
1599
+ self .head = [self .Head (initial_conditions [i ], self .bases [i ], n_last_layer_head ) for i in range (self .num )]
1600
+ else :
1601
+ self .head = [self .Head (torch .Tensor (initial_conditions ).view (1 , - 1 ), self .bases , n_last_layer_head , len (initial_conditions ))]
1602
+ else :
1603
+ if is_system :
1604
+ self .head = [HeadClass (initial_conditions [i ], self .bases [i ], n_last_layer_head ) for i in range (self .num )]
1605
+ else :
1606
+ self .head = [HeadClass (torch .Tensor (initial_conditions ).view (1 , - 1 ), self .bases , n_last_layer_head , len (initial_conditions ))]
1607
+
1608
+ self .optimizer_args = optimizer_args or ()
1609
+ self .optimizer_kwargs = optimizer_kwargs or {}
1610
+
1611
+ if isinstance (optimizer , torch .optim .Optimizer ):
1612
+ self .optimizer = optimizer
1613
+ elif issubclass (optimizer , torch .optim .Optimizer ):
1614
+ params = chain .from_iterable (n .parameters () for n in self .head )
1615
+ self .optimizer = optimizer (params , * self .optimizer_args , ** self .optimizer_kwargs )
1616
+ else :
1617
+ raise TypeError (f"Unknown optimizer instance/type { self .optimizer } " )
1618
+
1619
+ super ().__init__ (
1620
+ diff_eqs = diff_eqs ,
1621
+ conditions = [NoCondition ()]* self .num ,
1622
+ train_generator = train_generator ,
1623
+ valid_generator = valid_generator ,
1624
+ nets = self .head ,
1625
+ system_parameters = system_parameters ,
1626
+ optimizer = self .optimizer ,
1627
+ n_batches_train = n_batches_train ,
1628
+ n_batches_valid = n_batches_valid ,
1629
+ loss_fn = loss_fn ,
1630
+ metrics = metrics
1631
+ )
1632
+
1633
+ def additional_loss (self , residuals , funcs , coords ):
1634
+
1635
+ loss = 0
1636
+ for i in range (len (self .nets )):
1637
+ out = self .nets [i ](torch .zeros ((1 ,1 )))
1638
+ loss += ((self .nets [i ].u_0 - out )** 2 ).mean ()
1639
+ return loss
1640
+
1641
+
1642
+ class UniversalSolver1D (ABC ):
1643
+ r"""A solver class for solving a family of ODEs (for different initial conditions and parameters)
1644
+
1645
+ :param ode_system:
1646
+ The ODE system to solve, which maps a torch.Tensor to a tuple of ODE residuals,
1647
+ both the input and output must have shape (n_samples, 1).
1648
+ :type ode_system: callable
1649
+ """
1650
+
1651
+ class Base (nn .Module ):
1652
+ def __init__ (self ):
1653
+ super ().__init__ ()
1654
+ self .linear_1 = nn .Linear (1 , 10 )
1655
+ self .linear_2 = nn .Linear (10 , 10 )
1656
+ self .linear_3 = nn .Linear (10 , 10 )
1657
+
1658
+ def forward (self , x ):
1659
+ x = self .linear_1 (x )
1660
+ x = torch .tanh (x )
1661
+ x = self .linear_2 (x )
1662
+ x = torch .tanh (x )
1663
+ x = self .linear_3 (x )
1664
+ x = torch .tanh (x )
1665
+ return x
1666
+
1667
+ def __init__ (self , diff_eqs , is_system = True ):
1668
+
1669
+ self .diff_eqs = diff_eqs
1670
+ self .is_system = is_system
1671
+
1672
+ self .t_min = None
1673
+ self .t_max = None
1674
+ self .train_generator = None
1675
+ self .valid_generator = None
1676
+
1677
+ def build (self ,u_0s = None ,
1678
+ system_parameters = [{}],
1679
+ BaseClass = Base ,
1680
+ HeadClass = None ,
1681
+ n_last_layer_head = 10 ,
1682
+ build_source = False ,
1683
+ optimizer = torch .optim .Adam ,
1684
+ optimizer_args = None , optimizer_kwargs = {"lr" :1e-3 },
1685
+ t_min = None ,
1686
+ t_max = None ,
1687
+ train_generator = None ,
1688
+ valid_generator = None ,
1689
+ n_batches_train = 1 ,
1690
+ n_batches_valid = 4 ,
1691
+ loss_fn = None ,
1692
+ metrics = None ):
1693
+
1694
+ r"""
1695
+ :param system_parameters:
1696
+ List of dictionaries of parameters for which the solver will be trained
1697
+ :type system_parameters: list[dict]
1698
+ :param BaseClass:
1699
+ Neural network class for base networks
1700
+ :type nets: torch.nn.Module
1701
+ :param n_last_layer_head:
1702
+ Number of neurons in the last layer for each network
1703
+ :type n_last_layer_head: int
1704
+ :param build_source:
1705
+ Boolean value for training the base networks or freezing their weights
1706
+ :type build_source: bool
1707
+ :param optimizer:
1708
+ Optimizer to be used for training.
1709
+ Defaults to a ``torch.optim.Adam`` instance that trains on all parameters of ``nets``.
1710
+ :type optimizer: ``torch.nn.optim.Optimizer``, optional
1711
+ :param t_min:
1712
+ Lower bound of input (start time).
1713
+ Ignored if ``train_generator`` and ``valid_generator`` are both set.
1714
+ :type t_min: float, optional
1715
+ :param t_max:
1716
+ Upper bound of input (start time).
1717
+ Ignored if ``train_generator`` and ``valid_generator`` are both set.
1718
+ :type t_max: float, optional
1719
+ :param train_generator:
1720
+ Generator for sampling training points,
1721
+ which must provide a ``.get_examples()`` method and a ``.size`` field.
1722
+ ``train_generator`` must be specified if ``t_min`` and ``t_max`` are not set.
1723
+ :type train_generator: `neurodiffeq.generators.BaseGenerator`, optional
1724
+ :param valid_generator:
1725
+ Generator for sampling validation points,
1726
+ which must provide a ``.get_examples()`` method and a ``.size`` field.
1727
+ ``valid_generator`` must be specified if ``t_min`` and ``t_max`` are not set.
1728
+ :type valid_generator: `neurodiffeq.generators.BaseGenerator`, optional
1729
+ :param n_batches_train:
1730
+ Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
1731
+ Defaults to 1.
1732
+ :type n_batches_train: int, optional
1733
+ :param n_batches_valid:
1734
+ Number of batches to validate in every epoch, where batch-size equals ``valid_generator.size``.
1735
+ Defaults to 4.
1736
+ :type n_batches_valid: int, optional
1737
+ :param loss_fn:
1738
+ The loss function used for training.
1739
+
1740
+ - If a str, must be present in the keys of `neurodiffeq.losses._losses`.
1741
+ - If a `torch.nn.modules.loss._Loss` instance, just pass the instance.
1742
+ - If any other callable, it must map
1743
+ A) a residual tensor (shape `(n_points, n_equations)`),
1744
+ B) a function values tuple (length `n_funcs`, each element a tensor of shape `(n_points, 1)`), and
1745
+ C) a coordinate values tuple (length `n_coords`, each element a tensor of shape `(n_coords, 1)`
1746
+ to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the computational graph,
1747
+ so that backpropagation can be performed.
1748
+
1749
+ :type loss_fn:
1750
+ str or `torch.nn.moduesl.loss._Loss` or callable
1751
+ :param metrics:
1752
+ Additional metrics to be logged (besides loss). ``metrics`` should be a dict where
1753
+
1754
+ - Keys are metric names (e.g. 'analytic_mse');
1755
+ - Values are functions (callables) that computes the metric value.
1756
+ These functions must accept the same input as the differential equation ``ode_system``.
1757
+
1758
+ :type metrics: dict[str, callable], optional
1759
+ """
1760
+
1761
+ self .u_0s = u_0s
1762
+ self .system_parameters = system_parameters
1763
+ self .n_last_layer_head = n_last_layer_head
1764
+
1765
+ if t_min is not None :
1766
+ self .t_min = t_min
1767
+ if t_max is not None :
1768
+ self .t_max = t_max
1769
+
1770
+ if self .t_min is not None and self .t_max is not None :
1771
+ self .train_generator = Generator1D (32 , t_min = self .t_min , t_max = self .t_max , method = 'equally-spaced-noisy' )
1772
+ self .valid_generator = Generator1D (32 , t_min = self .t_min , t_max = self .t_max , method = 'equally-spaced-noisy' )
1773
+
1774
+ if train_generator is not None :
1775
+ self .train_generator = train_generator
1776
+ if valid_generator is not None :
1777
+ self .valid_generator = valid_generator
1778
+
1779
+ if self .u_0s is None :
1780
+ raise Exception ("ICs must be specified" )
1781
+ if self .train_generator is None or self .valid_generator is None :
1782
+ raise Exception (f"Train and valid generators cannot be None. Either provide `t_min` and `t_max` \
1783
+ or provide the generators as arguments" )
1784
+
1785
+ self .optimizer = optimizer
1786
+ self .optimizer_args = optimizer_args or ()
1787
+ self .optimizer_kwargs = optimizer_kwargs or {}
1788
+
1789
+ if build_source :
1790
+ if self .is_system :
1791
+ self .bases = [BaseClass () for _ in range (len (u_0s [0 ]))]
1792
+ else :
1793
+ self .bases = BaseClass ()
1794
+
1795
+ self .solvers_base = [_SingleSolver1D (
1796
+ bases = self .bases ,
1797
+ HeadClass = HeadClass ,
1798
+ initial_conditions = self .u_0s [i ],
1799
+ n_last_layer_head = n_last_layer_head ,
1800
+ diff_eqs = self .diff_eqs ,
1801
+ train_generator = self .train_generator ,
1802
+ valid_generator = self .valid_generator ,
1803
+ system_parameters = self .system_parameters [p ],
1804
+ optimizer = optimizer ,optimizer_args = optimizer_args , optimizer_kwargs = optimizer_kwargs ,
1805
+ n_batches_train = n_batches_train ,
1806
+ n_batches_valid = n_batches_valid ,
1807
+ loss_fn = loss_fn ,
1808
+ metrics = metrics ,
1809
+ is_system = self .is_system
1810
+ ) for i in range (len (u_0s )) for p in range (len (self .system_parameters ))]
1811
+ else :
1812
+ self .solvers_head = [_SingleSolver1D (
1813
+ bases = self .bases ,
1814
+ HeadClass = HeadClass ,
1815
+ initial_conditions = self .u_0s [i ],
1816
+ n_last_layer_head = self .n_last_layer_head ,
1817
+ diff_eqs = self .diff_eqs ,
1818
+ train_generator = self .train_generator ,
1819
+ valid_generator = self .valid_generator ,
1820
+ system_parameters = self .system_parameters [p ],
1821
+ optimizer = optimizer ,optimizer_args = optimizer_args , optimizer_kwargs = optimizer_kwargs ,
1822
+ n_batches_train = n_batches_train ,
1823
+ n_batches_valid = n_batches_valid ,
1824
+ loss_fn = loss_fn ,
1825
+ metrics = metrics ,
1826
+ is_system = self .is_system
1827
+ ) for i in range (len (self .u_0s )) for p in range (len (self .system_parameters ))]
1828
+
1829
+
1830
+ def fit (self , epochs = 10 , freeze_source = True ):
1831
+ r"""
1832
+ :param epochs:
1833
+ Number of epochs for training
1834
+ :type epochs: int
1835
+ :param freeze_source:
1836
+ Boolean value indicating whether to freeze the base networks or not
1837
+ :type freeze_source: bool
1838
+ """
1839
+
1840
+ if not freeze_source :
1841
+ for i in range (len (self .solvers_base )):
1842
+ self .solvers_base [i ].fit (max_epochs = epochs )
1843
+ else :
1844
+ if self .is_system :
1845
+ for net in self .bases :
1846
+ for param in net .parameters ():
1847
+ param .requires_grad = False
1848
+ else :
1849
+ for param in self .bases .parameters ():
1850
+ param .requires_grad = False
1851
+ for i in range (len (self .solvers_head )):
1852
+ self .solvers_head [i ].fit (max_epochs = epochs )
1853
+
1854
+
1855
+ def get_solution (self , base = False ):
1856
+ r"""
1857
+ :param base:
1858
+ Boolean value indicating whether to get solutions for those conditions for which the base
1859
+ was trained or solutions for those conditions for which only the last layer was trained
1860
+ :type base: bool
1861
+ :rtype: list[BaseSolution]
1862
+ """
1863
+
1864
+ if base :
1865
+ return [self .solvers_base [i ].get_solution () for i in range (len (self .solvers_base ))]
1866
+ else :
1867
+ return [self .solvers_head [i ].get_solution () for i in range (len (self .solvers_head ))]
0 commit comments