Skip to content

Commit 23bbb2b

Browse files
committed
make base TorchModel abstract class
1 parent 47f5c01 commit 23bbb2b

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

pydens/model_torch.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
""" Contains classes for solving differential equations with neural networks. """
22

3+
from abc import ABC, abstractmethod
34
from contextvars import ContextVar, copy_context
45

56
import numpy as np
@@ -13,7 +14,7 @@
1314

1415
current_model = ContextVar("current_model")
1516

16-
class TorchModel(nn.Module):
17+
class TorchModel(ABC, nn.Module):
1718
""" Pytorch model for solving differential equations with neural networks.
1819
"""
1920
def __init__(self, initial_condition=None, boundary_condition=None, ndims=1, nparams=0, **kwargs):
@@ -38,6 +39,11 @@ def __init__(self, initial_condition=None, boundary_condition=None, ndims=1, npa
3839
# and boundary conditions.
3940
self.log_scale = nn.Parameter(torch.tensor(0.0, requires_grad=True))
4041

42+
@abstractmethod
43+
def forward(self, xs):
44+
""" Forward of the model-network. """
45+
pass
46+
4147
def freeze_trainable(self, layers=None, variables=None):
4248
""" Freeze layers and trainable variables.
4349
"""

0 commit comments

Comments
 (0)