Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 68 additions & 8 deletions deepxde/nn/pytorch/mionet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from .. import activations


class MIONetCartesianProd(NN):
"""MIONet with two input functions for Cartesian product format."""
class MIONet(NN):
"""Multiple-input operator network with two input functions."""

def __init__(
self,
Expand All @@ -29,9 +29,9 @@ def __init__(
self.activation_branch2 = activations.get(activation["branch2"])
self.activation_trunk = activations.get(activation["trunk"])
else:
self.activation_branch1 = (
self.activation_branch2
) = self.activation_trunk = activations.get(activation)
self.activation_branch1 = self.activation_branch2 = (
self.activation_trunk
) = activations.get(activation)
if callable(layer_sizes_branch1[1]):
# User-defined network
self.branch1 = layer_sizes_branch1[1]
Expand Down Expand Up @@ -81,6 +81,66 @@ def __init__(
self.merge_operation = merge_operation
self.output_merge_operation = output_merge_operation

def forward(self, inputs):
x_func1 = inputs[0]
x_func2 = inputs[1]
x_loc = inputs[2]
# Branch net to encode the input function
y_func1 = self.branch1(x_func1)
y_func2 = self.branch2(x_func2)
if self.merge_operation == "cat":
x_merger = torch.cat((y_func1, y_func2), 1)
else:
if y_func1.shape[-1] != y_func2.shape[-1]:
raise AssertionError(
"Output sizes of branch1 net and branch2 net do not match."
)
if self.merge_operation == "add":
x_merger = y_func1 + y_func2
elif self.merge_operation == "mul":
x_merger = torch.mul(y_func1, y_func2)
else:
raise NotImplementedError(
f"{self.merge_operation} operation to be implimented"
)
# Optional merger net
if self.merger is not None:
y_func = self.merger(x_merger)
else:
y_func = x_merger
# Trunk net to encode the domain of the output function
if self._input_transform is not None:
x_loc = self._input_transform(x_loc)
y_loc = self.trunk(x_loc)
if self.trunk_last_activation:
y_loc = self.activation_trunk(y_loc)
# Dot product
if y_func.shape[-1] != y_loc.shape[-1]:
raise AssertionError(
"Output sizes of merger net and trunk net do not match."
)
# output merger net
if self.output_merger is None:
y = torch.mul(y_func, y_loc)
y = torch.sum(y, dim=1, keepdim=True)
else:
if self.output_merge_operation == "mul":
y = torch.mul(y_func, y_loc)
elif self.output_merge_operation == "add":
y = y_func + y_loc
elif self.output_merge_operation == "cat":
y = torch.cat((y_func, y_loc), dim=1)
y = self.output_merger(y)
# Add bias
y += self.b
if self._output_transform is not None:
y = self._output_transform(inputs, y)
return y


class MIONetCartesianProd(MIONet):
"""MIONet with two input functions for Cartesian product format."""

def forward(self, inputs):
x_func1 = inputs[0]
x_func2 = inputs[1]
Expand Down Expand Up @@ -170,9 +230,9 @@ def __init__(
self.activation_trunk = activations.get(activation["trunk"])
self.activation_merger = activations.get(activation["merger"])
else:
self.activation_branch1 = (
self.activation_branch2
) = self.activation_trunk = activations.get(activation)
self.activation_branch1 = self.activation_branch2 = (
self.activation_trunk
) = activations.get(activation)
self.pod_basis = torch.as_tensor(pod_basis, dtype=torch.float32)
if callable(layer_sizes_branch1[1]):
# User-defined network
Expand Down