File tree Expand file tree Collapse file tree 2 files changed +6
-6
lines changed Expand file tree Collapse file tree 2 files changed +6
-6
lines changed Original file line number Diff line number Diff line change 88from .helper import one_function
99from .pde import PDE
1010from .. import config
11- from ..utils import run_if_any_none
11+ from ..utils import run_if_all_none
1212
1313
1414class IDE (PDE ):
@@ -81,7 +81,7 @@ def losses_test():
8181
8282 return tf .cond (tf .equal (model .net .data_id , 0 ), losses_train , losses_test )
8383
84- @run_if_any_none ("train_x" , "train_y" )
84+ @run_if_all_none ("train_x" , "train_y" )
8585 def train_next_batch (self , batch_size = None ):
8686 self .train_x = self .train_points ()
8787 x_bc = self .bc_points ()
@@ -90,7 +90,7 @@ def train_next_batch(self, batch_size=None):
9090 self .train_y = self .func (self .train_x ) if self .func else None
9191 return self .train_x , self .train_y
9292
93- @run_if_any_none ("test_x" , "test_y" )
93+ @run_if_all_none ("test_x" , "test_y" )
9494 def test (self ):
9595 if self .num_test is None :
9696 self .test_x = self .train_x [sum (self .num_bcs ) :]
Original file line number Diff line number Diff line change 77
88from .data import Data
99from .. import config
10- from ..utils import run_if_any_none
10+ from ..utils import run_if_all_none
1111
1212
1313class PDE (Data ):
@@ -68,14 +68,14 @@ def losses_test():
6868
6969 return tf .cond (tf .equal (model .net .data_id , 0 ), losses_train , losses_test )
7070
71- @run_if_any_none ("train_x" , "train_y" )
71+ @run_if_all_none ("train_x" , "train_y" )
7272 def train_next_batch (self , batch_size = None ):
7373 self .train_x = self .train_points ()
7474 self .train_x = np .vstack ((self .bc_points (), self .train_x ))
7575 self .train_y = self .func (self .train_x ) if self .func else None
7676 return self .train_x , self .train_y
7777
78- @run_if_any_none ("test_x" , "test_y" )
78+ @run_if_all_none ("test_x" , "test_y" )
7979 def test (self ):
8080 if self .num_test is None :
8181 self .test_x = self .train_x [sum (self .num_bcs ) :]
You can’t perform that action at this time.
0 commit comments