Skip to content

Commit f0126d5

Browse files
committed
Add a new Data as a constraint
1 parent f3c9889 commit f0126d5

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

deepxde/data/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import absolute_import
22

3+
from .constraint import Constraint
34
from .dataset import DataSet
45
from .fpde import FPDE
56
from .fpde import TimeFPDE
@@ -15,6 +16,7 @@
1516

1617

1718
__all__ = [
19+
"Constraint",
1820
"DataSet",
1921
"FPDE",
2022
"Func",

deepxde/data/constraint.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
5+
from .data import Data
6+
from .. import config
7+
from ..backend import tf
8+
9+
10+
class Constraint(Data):
11+
"""General constraints.
12+
"""
13+
14+
def __init__(self, constraint, train_x, test_x):
15+
self.constraint = constraint
16+
self.train_x = train_x
17+
self.test_x = test_x
18+
19+
def losses(self, targets, outputs, loss, model):
20+
f = tf.cond(
21+
tf.equal(model.net.data_id, 0),
22+
lambda: self.constraint(model.net.inputs, outputs, self.train_x),
23+
lambda: self.constraint(model.net.inputs, outputs, self.test_x),
24+
)
25+
return loss(tf.zeros(tf.shape(f), dtype=config.real(tf)), f)
26+
27+
def train_next_batch(self, batch_size=None):
28+
return self.train_x, None
29+
30+
def test(self):
31+
return self.test_x, None

0 commit comments

Comments
 (0)