20
20
from .generators import Generator2D
21
21
from .generators import GeneratorND
22
22
from .function_basis import RealSphericalHarmonics
23
- from .conditions import BaseCondition
23
+ from .conditions import BaseCondition , NoCondition
24
24
from .neurodiffeq import safe_diff as diff
25
25
from .losses import _losses
26
26
@@ -113,7 +113,7 @@ class BaseSolver(ABC, PretrainedSolver):
113
113
def __init__ (self , diff_eqs , conditions ,
114
114
nets = None , train_generator = None , valid_generator = None , analytic_solutions = None ,
115
115
optimizer = None , loss_fn = None , n_batches_train = 1 , n_batches_valid = 4 ,
116
- metrics = None , n_input_units = None , n_output_units = None ,
116
+ metrics = None , n_input_units = None , n_output_units = None , system_parameters = None ,
117
117
# deprecated arguments are listed below
118
118
shuffle = None , batch_size = None ):
119
119
# deprecate argument `shuffle`
@@ -130,6 +130,9 @@ def __init__(self, diff_eqs, conditions,
130
130
)
131
131
132
132
self .diff_eqs = diff_eqs
133
+ self .system_parameters = {}
134
+ if system_parameters is not None :
135
+ self .system_parameters = system_parameters
133
136
self .conditions = conditions
134
137
self .n_funcs = len (conditions )
135
138
if nets is None :
@@ -376,7 +379,7 @@ def closure(zero_grad=True):
376
379
for name in self .metrics_fn :
377
380
value = self .metrics_fn [name ](* funcs , * batch ).item ()
378
381
metric_values [name ] += value
379
- residuals = self .diff_eqs (* funcs , * batch )
382
+ residuals = self .diff_eqs (* funcs , * batch , ** self . system_parameters )
380
383
residuals = torch .cat (residuals , dim = 1 )
381
384
try :
382
385
loss = self .loss_fn (residuals , funcs , batch ) + self .additional_loss (residuals , funcs , batch )
@@ -1105,7 +1108,7 @@ class Solver1D(BaseSolver):
1105
1108
1106
1109
def __init__ (self , ode_system , conditions , t_min = None , t_max = None ,
1107
1110
nets = None , train_generator = None , valid_generator = None , analytic_solutions = None , optimizer = None ,
1108
- loss_fn = None , n_batches_train = 1 , n_batches_valid = 4 , metrics = None , n_output_units = 1 ,
1111
+ loss_fn = None , n_batches_train = 1 , n_batches_valid = 4 , metrics = None , n_output_units = 1 , system_parameters = None ,
1109
1112
# deprecated arguments are listed below
1110
1113
batch_size = None , shuffle = None ):
1111
1114
@@ -1136,6 +1139,7 @@ def __init__(self, ode_system, conditions, t_min=None, t_max=None,
1136
1139
metrics = metrics ,
1137
1140
n_input_units = 1 ,
1138
1141
n_output_units = n_output_units ,
1142
+ system_parameters = system_parameters ,
1139
1143
shuffle = shuffle ,
1140
1144
batch_size = batch_size ,
1141
1145
)
@@ -1164,11 +1168,12 @@ def get_solution(self, copy=True, best=True):
1164
1168
:rtype: BaseSolution
1165
1169
"""
1166
1170
nets = self .best_nets if best else self .nets
1171
+ print (nets )
1167
1172
conditions = self .conditions
1168
1173
if copy :
1169
1174
nets = deepcopy (nets )
1170
1175
conditions = deepcopy (conditions )
1171
-
1176
+ print ( nets )
1172
1177
return Solution1D (nets , conditions )
1173
1178
1174
1179
def _get_internal_variables (self ):
@@ -1590,3 +1595,300 @@ def _get_internal_variables(self):
1590
1595
'xy_max' : self .xy_max ,
1591
1596
})
1592
1597
return available_variables
1598
+
1599
class _SingleSolver1D(GenericSolver):
    r"""Internal solver that trains a head network on top of shared base networks.

    Used by ``UniversalSolver1D`` to solve one member of a family of ODEs,
    identified by a particular set of initial conditions and system parameters.
    The initial condition is enforced softly via ``additional_loss`` rather
    than by a reparametrization condition (hence ``NoCondition``).
    """

    class Head(nn.Module):
        """Default head: a (shared, possibly frozen) base followed by a trainable linear layer.

        ``u_0`` (the initial condition) is stored on the module so that
        ``_SingleSolver1D.additional_loss`` can penalize the deviation of the
        network output at t=0 from it.
        """

        def __init__(self, u_0, base, n_input, n_output=1):
            super().__init__()
            self.u_0 = u_0
            self.base = base
            self.last_layer = nn.Linear(n_input, n_output)

        def forward(self, x):
            x = self.base(x)
            x = self.last_layer(x)
            return x

    def __init__(self, bases, HeadClass, initial_conditions, n_last_layer_head, diff_eqs,
                 system_parameters=None,
                 optimizer=torch.optim.Adam, optimizer_args=None, optimizer_kwargs=None,
                 train_generator=None, valid_generator=None, n_batches_train=1, n_batches_valid=4,
                 loss_fn=None, metrics=None, is_system=False):
        # Fix: `system_parameters` and `optimizer_kwargs` previously used mutable
        # default arguments ([{}] and {"lr": 1e-3}); use None sentinels instead.
        if system_parameters is None:
            # NOTE(review): the historical default was a *list* [{}], although callers
            # (UniversalSolver1D) always pass a single dict — kept as-is; confirm.
            system_parameters = [{}]
        if optimizer_kwargs is None:
            optimizer_kwargs = {"lr": 1e-3}

        if train_generator is None or valid_generator is None:
            raise ValueError("Train and Valid Generator cannot be None")

        self.num = len(initial_conditions)
        self.bases = bases

        # Default to the built-in Head when no custom head class is supplied.
        head_cls = self.Head if HeadClass is None else HeadClass
        if is_system:
            # One head per equation, each with its own base network and scalar IC.
            self.head = [head_cls(initial_conditions[i], self.bases[i], n_last_layer_head)
                         for i in range(self.num)]
        else:
            # A single multi-output head sharing one base network.
            self.head = [head_cls(torch.Tensor(initial_conditions).view(1, -1),
                                  self.bases, n_last_layer_head, len(initial_conditions))]

        self.optimizer_args = optimizer_args or ()
        self.optimizer_kwargs = optimizer_kwargs or {}

        if isinstance(optimizer, torch.optim.Optimizer):
            self.optimizer = optimizer
        elif isinstance(optimizer, type) and issubclass(optimizer, torch.optim.Optimizer):
            # Only the head (and base, if unfrozen) parameters are optimized.
            params = chain.from_iterable(n.parameters() for n in self.head)
            self.optimizer = optimizer(params, *self.optimizer_args, **self.optimizer_kwargs)
        else:
            # Fix: the original referenced `self.optimizer` here, which is never
            # assigned on this path (AttributeError instead of the intended TypeError);
            # also `issubclass` itself raised on non-class arguments.
            raise TypeError(f"Unknown optimizer instance/type {optimizer}")

        super().__init__(
            diff_eqs=diff_eqs,
            conditions=[NoCondition()] * self.num,
            train_generator=train_generator,
            valid_generator=valid_generator,
            nets=self.head,
            system_parameters=system_parameters,
            optimizer=self.optimizer,
            n_batches_train=n_batches_train,
            n_batches_valid=n_batches_valid,
            loss_fn=loss_fn,
            metrics=metrics
        )

    def additional_loss(self, residuals, funcs, coords):
        """Soft initial-condition penalty: mean squared error between each
        network's output at t=0 and its stored initial condition ``u_0``."""
        loss = 0
        for net in self.nets:
            out = net(torch.zeros((1, 1)))
            loss += ((net.u_0 - out) ** 2).mean()
        return loss
1667
+
1668
+
1669
class UniversalSolver1D(ABC):
    r"""A solver class for solving a family of ODEs (for different initial conditions and parameters).

    :param diff_eqs:
        The ODE system to solve, which maps a torch.Tensor to a tuple of ODE residuals,
        both the input and output must have shape (n_samples, 1).
    :type diff_eqs: callable
    :param is_system:
        Whether ``diff_eqs`` is a system of equations (one base/head network per
        equation) or a single equation. Defaults to True.
    :type is_system: bool
    """

    class Base(nn.Module):
        """Default base (feature-extractor) network: 1 -> 10, three tanh-activated linear layers."""

        def __init__(self):
            super().__init__()
            self.linear_1 = nn.Linear(1, 10)
            self.linear_2 = nn.Linear(10, 10)
            self.linear_3 = nn.Linear(10, 10)

        def forward(self, x):
            x = torch.tanh(self.linear_1(x))
            x = torch.tanh(self.linear_2(x))
            x = torch.tanh(self.linear_3(x))
            return x

    def __init__(self, diff_eqs, is_system=True):
        self.diff_eqs = diff_eqs
        self.is_system = is_system

        self.t_min = None
        self.t_max = None
        self.train_generator = None
        self.valid_generator = None

        # Fix: populated by `build()`; initialized here so that `build()`,
        # `fit()` and `get_solution()` can fail with a clear error message
        # instead of an AttributeError when called out of order.
        self.bases = None
        self.solvers_base = []
        self.solvers_head = []

    def build(self, u_0s=None,
              system_parameters=None,
              BaseClass=Base,
              HeadClass=None,
              n_last_layer_head=10,
              build_source=False,
              optimizer=torch.optim.Adam,
              optimizer_args=None, optimizer_kwargs=None,
              t_min=None,
              t_max=None,
              train_generator=None,
              valid_generator=None,
              n_batches_train=1,
              n_batches_valid=4,
              loss_fn=None,
              metrics=None):
        r"""Build one internal solver per (initial condition, parameter set) pair.

        :param u_0s:
            List of initial conditions for which the solver will be trained. Required.
        :type u_0s: list
        :param system_parameters:
            List of dictionaries of parameters for which the solver will be trained.
            Defaults to ``[{}]``.
        :type system_parameters: list[dict]
        :param BaseClass:
            Neural network class for base networks.
        :type BaseClass: torch.nn.Module
        :param HeadClass:
            Optional neural network class for head networks; defaults to the built-in head.
        :type HeadClass: torch.nn.Module, optional
        :param n_last_layer_head:
            Number of neurons in the last layer for each network.
        :type n_last_layer_head: int
        :param build_source:
            Boolean value for training the base networks or freezing their weights.
        :type build_source: bool
        :param optimizer:
            Optimizer to be used for training.
            Defaults to a ``torch.optim.Adam`` instance that trains on all parameters of the networks.
        :type optimizer: ``torch.optim.Optimizer``, optional
        :param t_min:
            Lower bound of input (start time).
            Ignored if ``train_generator`` and ``valid_generator`` are both set.
        :type t_min: float, optional
        :param t_max:
            Upper bound of input (end time).
            Ignored if ``train_generator`` and ``valid_generator`` are both set.
        :type t_max: float, optional
        :param train_generator:
            Generator for sampling training points,
            which must provide a ``.get_examples()`` method and a ``.size`` field.
            ``train_generator`` must be specified if ``t_min`` and ``t_max`` are not set.
        :type train_generator: `neurodiffeq.generators.BaseGenerator`, optional
        :param valid_generator:
            Generator for sampling validation points,
            which must provide a ``.get_examples()`` method and a ``.size`` field.
            ``valid_generator`` must be specified if ``t_min`` and ``t_max`` are not set.
        :type valid_generator: `neurodiffeq.generators.BaseGenerator`, optional
        :param n_batches_train:
            Number of batches to train in every epoch, where batch-size equals ``train_generator.size``.
            Defaults to 1.
        :type n_batches_train: int, optional
        :param n_batches_valid:
            Number of batches to validate in every epoch, where batch-size equals ``valid_generator.size``.
            Defaults to 4.
        :type n_batches_valid: int, optional
        :param loss_fn:
            The loss function used for training.

            - If a str, must be present in the keys of `neurodiffeq.losses._losses`.
            - If a `torch.nn.modules.loss._Loss` instance, just pass the instance.
            - If any other callable, it must map
              A) a residual tensor (shape `(n_points, n_equations)`),
              B) a function values tuple (length `n_funcs`, each element a tensor of shape `(n_points, 1)`), and
              C) a coordinate values tuple (length `n_coords`, each element a tensor of shape `(n_coords, 1)`)
              to a tensor of empty shape (i.e. a scalar). The returned tensor must be connected to the
              computational graph, so that backpropagation can be performed.
        :type loss_fn:
            str or `torch.nn.modules.loss._Loss` or callable
        :param metrics:
            Additional metrics to be logged (besides loss). ``metrics`` should be a dict where

            - Keys are metric names (e.g. 'analytic_mse');
            - Values are functions (callables) that computes the metric value.
              These functions must accept the same input as the differential equation ``ode_system``.
        :type metrics: dict[str, callable], optional
        """
        # Fix: `system_parameters` and `optimizer_kwargs` previously used mutable
        # default arguments ([{}] and {"lr": 1e-3}); use None sentinels instead.
        if system_parameters is None:
            system_parameters = [{}]
        if optimizer_kwargs is None:
            optimizer_kwargs = {"lr": 1e-3}

        self.u_0s = u_0s
        self.system_parameters = system_parameters
        self.n_last_layer_head = n_last_layer_head

        if t_min is not None:
            self.t_min = t_min
        if t_max is not None:
            self.t_max = t_max

        if self.t_min is not None and self.t_max is not None:
            self.train_generator = Generator1D(32, t_min=self.t_min, t_max=self.t_max, method='equally-spaced-noisy')
            self.valid_generator = Generator1D(32, t_min=self.t_min, t_max=self.t_max, method='equally-spaced-noisy')

        # Explicitly supplied generators take precedence over the t_min/t_max ones.
        if train_generator is not None:
            self.train_generator = train_generator
        if valid_generator is not None:
            self.valid_generator = valid_generator

        if self.u_0s is None:
            raise ValueError("ICs must be specified")
        if self.train_generator is None or self.valid_generator is None:
            raise ValueError(
                "Train and valid generators cannot be None. Either provide `t_min` and `t_max` "
                "or provide the generators as arguments"
            )

        self.optimizer = optimizer
        self.optimizer_args = optimizer_args or ()
        self.optimizer_kwargs = optimizer_kwargs or {}

        if build_source:
            # Fresh base networks: they will be trained together with the heads.
            if self.is_system:
                self.bases = [BaseClass() for _ in range(len(u_0s[0]))]
            else:
                self.bases = BaseClass()
        elif self.bases is None:
            # Fix: the original dereferenced the missing `self.bases` and raised
            # AttributeError here; fail with an actionable message instead.
            raise RuntimeError(
                "No base networks available. Call `build(..., build_source=True)` once "
                "before building head-only solvers."
            )

        solvers = self._make_solvers(HeadClass, optimizer, optimizer_args, optimizer_kwargs,
                                     n_batches_train, n_batches_valid, loss_fn, metrics)
        if build_source:
            self.solvers_base = solvers
        else:
            self.solvers_head = solvers

    def _make_solvers(self, HeadClass, optimizer, optimizer_args, optimizer_kwargs,
                      n_batches_train, n_batches_valid, loss_fn, metrics):
        """Create one ``_SingleSolver1D`` per (initial condition, parameter set) pair.

        Factored out of `build()`, which previously duplicated this construction
        in two nearly-identical branches.
        """
        return [
            _SingleSolver1D(
                bases=self.bases,
                HeadClass=HeadClass,
                initial_conditions=self.u_0s[i],
                n_last_layer_head=self.n_last_layer_head,
                diff_eqs=self.diff_eqs,
                train_generator=self.train_generator,
                valid_generator=self.valid_generator,
                system_parameters=self.system_parameters[p],
                optimizer=optimizer, optimizer_args=optimizer_args, optimizer_kwargs=optimizer_kwargs,
                n_batches_train=n_batches_train,
                n_batches_valid=n_batches_valid,
                loss_fn=loss_fn,
                metrics=metrics,
                is_system=self.is_system,
            )
            for i in range(len(self.u_0s))
            for p in range(len(self.system_parameters))
        ]

    def fit(self, epochs=10, freeze_source=True):
        r"""Train the internal solvers.

        :param epochs:
            Number of epochs for training.
        :type epochs: int
        :param freeze_source:
            If True, freeze the base networks' weights and train only the head solvers;
            otherwise train the base solvers (bases and heads together).
        :type freeze_source: bool
        """
        if not freeze_source:
            if not self.solvers_base:
                raise RuntimeError("No base solvers; call `build(..., build_source=True)` first.")
            for solver in self.solvers_base:
                solver.fit(max_epochs=epochs)
        else:
            if not self.solvers_head:
                raise RuntimeError("No head solvers; call `build(...)` first.")
            # Freeze every base network so only the head layers receive gradients.
            bases = self.bases if self.is_system else [self.bases]
            for net in bases:
                for param in net.parameters():
                    param.requires_grad = False
            for solver in self.solvers_head:
                solver.fit(max_epochs=epochs)

    def get_solution(self, base=False):
        r"""Get the trained solutions in callable form.

        :param base:
            Boolean value indicating whether to get solutions for those conditions for which the base
            was trained or solutions for those conditions for which only the last layer was trained.
        :type base: bool
        :rtype: list[BaseSolution]
        """
        solvers = self.solvers_base if base else self.solvers_head
        return [solver.get_solution() for solver in solvers]
0 commit comments