diff --git a/deepxde/nn/pytorch/mionet.py b/deepxde/nn/pytorch/mionet.py
index 6001f5ffb..7f723134f 100644
--- a/deepxde/nn/pytorch/mionet.py
+++ b/deepxde/nn/pytorch/mionet.py
@@ -5,8 +5,8 @@
 from .. import activations
 
 
-class MIONetCartesianProd(NN):
-    """MIONet with two input functions for Cartesian product format."""
+class MIONet(NN):
+    """Multiple-input operator network with two input functions."""
 
     def __init__(
         self,
@@ -29,9 +29,9 @@ def __init__(
             self.activation_branch2 = activations.get(activation["branch2"])
             self.activation_trunk = activations.get(activation["trunk"])
         else:
-            self.activation_branch1 = (
-                self.activation_branch2
-            ) = self.activation_trunk = activations.get(activation)
+            self.activation_branch1 = self.activation_branch2 = (
+                self.activation_trunk
+            ) = activations.get(activation)
         if callable(layer_sizes_branch1[1]):
             # User-defined network
             self.branch1 = layer_sizes_branch1[1]
@@ -81,6 +81,89 @@ def __init__(
         self.merge_operation = merge_operation
         self.output_merge_operation = output_merge_operation
 
+    def forward(self, inputs):
+        """Evaluate the network for one batch of inputs.
+
+        Args:
+            inputs: Indexable with three entries: ``inputs[0]`` is passed to the
+                branch1 net, ``inputs[1]`` to the branch2 net, and ``inputs[2]``
+                to the trunk net.
+
+        Returns:
+            The combined branch/trunk output plus the bias ``self.b``, passed
+            through the output transform when one is set.
+
+        Raises:
+            AssertionError: If the last dimensions of the outputs being combined
+                do not match.
+            NotImplementedError: If ``merge_operation`` or
+                ``output_merge_operation`` is not a supported value.
+        """
+        x_func1 = inputs[0]
+        x_func2 = inputs[1]
+        x_loc = inputs[2]
+        # Branch net to encode the input function
+        y_func1 = self.branch1(x_func1)
+        y_func2 = self.branch2(x_func2)
+        if self.merge_operation == "cat":
+            x_merger = torch.cat((y_func1, y_func2), 1)
+        else:
+            if y_func1.shape[-1] != y_func2.shape[-1]:
+                raise AssertionError(
+                    "Output sizes of branch1 net and branch2 net do not match."
+                )
+            if self.merge_operation == "add":
+                x_merger = y_func1 + y_func2
+            elif self.merge_operation == "mul":
+                x_merger = torch.mul(y_func1, y_func2)
+            else:
+                raise NotImplementedError(
+                    f"{self.merge_operation} operation to be implimented"
+                )
+        # Optional merger net
+        if self.merger is not None:
+            y_func = self.merger(x_merger)
+        else:
+            y_func = x_merger
+        # Trunk net to encode the domain of the output function
+        if self._input_transform is not None:
+            x_loc = self._input_transform(x_loc)
+        y_loc = self.trunk(x_loc)
+        if self.trunk_last_activation:
+            y_loc = self.activation_trunk(y_loc)
+        # Merger and trunk outputs must agree in their last dimension
+        if y_func.shape[-1] != y_loc.shape[-1]:
+            raise AssertionError(
+                "Output sizes of merger net and trunk net do not match."
+            )
+        if self.output_merger is None:
+            # Dot product: elementwise multiply, then sum over dim 1
+            y = torch.mul(y_func, y_loc)
+            y = torch.sum(y, dim=1, keepdim=True)
+        else:
+            # Output merger net
+            if self.output_merge_operation == "mul":
+                y = torch.mul(y_func, y_loc)
+            elif self.output_merge_operation == "add":
+                y = y_func + y_loc
+            elif self.output_merge_operation == "cat":
+                y = torch.cat((y_func, y_loc), dim=1)
+            else:
+                # Fail clearly instead of hitting an unbound ``y`` below
+                raise NotImplementedError(
+                    f"{self.output_merge_operation} operation to be implemented"
+                )
+            y = self.output_merger(y)
+        # Add bias
+        y += self.b
+        if self._output_transform is not None:
+            y = self._output_transform(inputs, y)
+        return y
+
+
+class MIONetCartesianProd(MIONet):
+    """MIONet with two input functions for Cartesian product format."""
+
     def forward(self, inputs):
         x_func1 = inputs[0]
         x_func2 = inputs[1]
@@ -170,9 +253,9 @@ def __init__(
             self.activation_trunk = activations.get(activation["trunk"])
             self.activation_merger = activations.get(activation["merger"])
         else:
-            self.activation_branch1 = (
-                self.activation_branch2
-            ) = self.activation_trunk = activations.get(activation)
+            self.activation_branch1 = self.activation_branch2 = (
+                self.activation_trunk
+            ) = activations.get(activation)
         self.pod_basis = torch.as_tensor(pod_basis, dtype=torch.float32)
         if callable(layer_sizes_branch1[1]):
             # User-defined network