Skip to content

Commit 4adcde7

Browse files
committed
Bug fix: re-generate data each step
1 parent de7e3e3 commit 4adcde7

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

deepxde/data/ide.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .helper import one_function
99
from .pde import PDE
1010
from .. import config
11-
from ..utils import run_if_any_none
11+
from ..utils import run_if_all_none
1212

1313

1414
class IDE(PDE):
@@ -81,7 +81,7 @@ def losses_test():
8181

8282
return tf.cond(tf.equal(model.net.data_id, 0), losses_train, losses_test)
8383

84-
@run_if_any_none("train_x", "train_y")
84+
@run_if_all_none("train_x", "train_y")
8585
def train_next_batch(self, batch_size=None):
8686
self.train_x = self.train_points()
8787
x_bc = self.bc_points()
@@ -90,7 +90,7 @@ def train_next_batch(self, batch_size=None):
9090
self.train_y = self.func(self.train_x) if self.func else None
9191
return self.train_x, self.train_y
9292

93-
@run_if_any_none("test_x", "test_y")
93+
@run_if_all_none("test_x", "test_y")
9494
def test(self):
9595
if self.num_test is None:
9696
self.test_x = self.train_x[sum(self.num_bcs) :]

deepxde/data/pde.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from .data import Data
99
from .. import config
10-
from ..utils import run_if_any_none
10+
from ..utils import run_if_all_none
1111

1212

1313
class PDE(Data):
@@ -68,14 +68,14 @@ def losses_test():
6868

6969
return tf.cond(tf.equal(model.net.data_id, 0), losses_train, losses_test)
7070

71-
@run_if_any_none("train_x", "train_y")
71+
@run_if_all_none("train_x", "train_y")
7272
def train_next_batch(self, batch_size=None):
7373
self.train_x = self.train_points()
7474
self.train_x = np.vstack((self.bc_points(), self.train_x))
7575
self.train_y = self.func(self.train_x) if self.func else None
7676
return self.train_x, self.train_y
7777

78-
@run_if_any_none("test_x", "test_y")
78+
@run_if_all_none("test_x", "test_y")
7979
def test(self):
8080
if self.num_test is None:
8181
self.test_x = self.train_x[sum(self.num_bcs) :]

0 commit comments

Comments
 (0)