@@ -23,6 +23,9 @@ class PDEOperator(Data):
2323 variables of the function by `function_variables=[0]`, where `0` indicates
2424 the first variable `x`. If ``None``, then we assume the domains of the
2525 function and the PDE are the same.
26+ num_test: The number of functions for testing PDE loss. The testing functions
27+ for BCs/ICs are the same functions used for training. If ``None``, then the
28+ training functions will be used for testing.
2629
2730 Attributes:
2831 train_x_bc: A triple of three Numpy arrays (v, x, vx) fed into PIDeepONet for
@@ -43,6 +46,7 @@ def __init__(
4346 evaluation_points ,
4447 num_function ,
4548 function_variables = None ,
49+ num_test = None ,
4650 ):
4751 self .pde = pde
4852 self .func_space = function_space
@@ -53,6 +57,7 @@ def __init__(
5357 if function_variables is not None
5458 else list (range (pde .geom .dim ))
5559 )
60+ self .num_test = num_test
5661
5762 self .num_bcs = [n * self .num_func for n in self .pde .num_bcs ]
5863 self .train_x_bc = None
@@ -92,33 +97,53 @@ def losses(self, targets, outputs, loss, model):
9297 def train_next_batch (self , batch_size = None ):
9398 func_feats = self .func_space .random (self .num_func )
9499 func_vals = self .func_space .eval_batch (func_feats , self .eval_pts )
95-
96100 v , x , vx = self .bc_inputs (func_feats , func_vals )
97101 if self .pde .pde is not None :
98- # Branch input: v
99- v_pde = np .repeat (func_vals , len (self .pde .train_x_all ), axis = 0 )
102+ v_pde , x_pde , vx_pde = self .gen_inputs (
103+ func_feats , func_vals , self .pde .train_x_all
104+ )
100105 v = np .vstack ((v , v_pde ))
101- # Trunk input: x
102- x_pde = np .tile (self .pde .train_x_all , (self .num_func , 1 ))
103106 x = np .vstack ((x , x_pde ))
104- # vx
105- vx_pde = self .func_space .eval_batch (
106- func_feats , self .pde .train_x_all [:, self .func_vars ]
107- ).reshape (- 1 , 1 )
108107 vx = np .vstack ((vx , vx_pde ))
109-
110108 self .train_x = (v , x , vx )
111109 self .train_y = None
112110 return self .train_x , self .train_y
113111
114112 @run_if_all_none ("test_x" , "test_y" )
115113 def test (self ):
116- # TODO: Use different BC data from self.train_x
117- # TODO
118- self .test_x = self .train_x
119- self .test_y = self .train_y
114+ if self .num_test is None :
115+ self .test_x = self .train_x
116+ else :
117+ func_feats = self .func_space .random (self .num_test )
118+ func_vals = self .func_space .eval_batch (func_feats , self .eval_pts )
119+ # TODO: Use different BC data from self.train_x
120+ v , x , vx = self .train_x_bc
121+ if self .pde .pde is not None :
122+ v_pde , x_pde , vx_pde = self .gen_inputs (
123+ func_feats , func_vals , self .pde .test_x [sum (self .pde .num_bcs ) :]
124+ )
125+ v = np .vstack ((v , v_pde ))
126+ x = np .vstack ((x , x_pde ))
127+ vx = np .vstack ((vx , vx_pde ))
128+ self .test_x = (v , x , vx )
129+ self .test_y = None
120130 return self .test_x , self .test_y
121131
132+ def gen_inputs (self , func_feats , func_vals , points ):
133+ # Format:
134+ # v1, x_1
135+ # ...
136+ # v1, x_N1
137+ # v2, x_1
138+ # ...
139+ # v2, x_N1
140+ v = np .repeat (func_vals , len (points ), axis = 0 )
141+ x = np .tile (points , (len (func_feats ), 1 ))
142+ vx = self .func_space .eval_batch (func_feats , points [:, self .func_vars ]).reshape (
143+ - 1 , 1
144+ )
145+ return v , x , vx
146+
122147 def bc_inputs (self , func_feats , func_vals ):
123148 if not self .pde .bcs :
124149 self .train_x_bc = (
@@ -127,26 +152,15 @@ def bc_inputs(self, func_feats, func_vals):
127152 np .empty ((0 , 1 ), dtype = config .real (np )),
128153 )
129154 return self .train_x_bc
130- # Format:
131- # v1, x_bc1_1
132- # ...
133- # v1, x_bc1_N1
134- # v2, x_bc1_1
135- # ...
136- # v2, x_bc1_N1
137155 v , x , vx = [], [], []
138156 bcs_start = np .cumsum ([0 ] + self .pde .num_bcs )
139- for i , num_bc in enumerate (self .pde .num_bcs ):
157+ for i , _ in enumerate (self .pde .num_bcs ):
140158 beg , end = bcs_start [i ], bcs_start [i + 1 ]
141- # Branch input: v
142- v .append (np .repeat (func_vals , num_bc , axis = 0 ))
143- # Trunk input: x
144- x .append (np .tile (self .pde .train_x_bc [beg :end ], (self .num_func , 1 )))
145- # vx
146- vx .append (
147- self .func_space .eval_batch (
148- func_feats , self .pde .train_x_bc [beg :end , self .func_vars ]
149- ).reshape (- 1 , 1 )
159+ vi , xi , vxi = self .gen_inputs (
160+ func_feats , func_vals , self .pde .train_x_bc [beg :end ]
150161 )
162+ v .append (vi )
163+ x .append (xi )
164+ vx .append (vxi )
151165 self .train_x_bc = (np .vstack (v ), np .vstack (x ), np .vstack (vx ))
152166 return self .train_x_bc
0 commit comments