@@ -13,6 +13,7 @@ class BC(DomainND):
1313 def __init__ (self ):
1414 self .isPeriodic = False
1515 self .isInit = False
16+ self .isNeumann = False
1617
1718 def compile (self ):
1819 self .input = self .create_input ()
@@ -113,6 +114,62 @@ def create_target(self):
113114 print (fun_vals )
114115 self .val = convertTensor (np .reshape (fun_vals , (- 1 , 1 ))[self .nums ])
115116
117+ class FunctionNeumannBC (BC ):
118+ def __init__ (self , domain , fun , var , target , deriv_model , func_inputs , n_values = None ):
119+ self .n_values = n_values
120+ self .domain = domain
121+ self .fun = fun
122+ self .var = var
123+ self .target = target
124+ super ().__init__ ()
125+ self .deriv_model = [get_tf_model (model ) for model in deriv_model ]
126+ self .isNeumann = True
127+ self .func_inputs = func_inputs
128+ self .compile ()
129+ self .create_target ()
130+
131+ def get_input_upper_lower (self , var ):
132+ self .repeat = self .create_target_input_repeat (var , self .target )
133+ #self.lower_repeat = self.create_target_input_repeat(var, self.dict_["range"][0])
134+
135+ def compile (self ):
136+ self .input = []
137+ for var in self .var :
138+ self .dicts_ = [item for item in self .domain .domaindict if item ["identifier" ] != var ]
139+ self .dict_ = next (item for item in self .domain .domaindict if item ["identifier" ] == var )
140+ self .get_input_upper_lower (var )
141+ mesh = flatten_and_stack (multimesh (self .get_not_dims (var )))
142+ self .input .append (np .insert (mesh , self .domain .vars .index (var ), self .repeat .flatten (), axis = 1 ))
143+ # self.lower.append(np.insert(mesh, self.domain.vars.index(var), self.lower_repeat.flatten(), axis=1))
144+
145+ if self .n_values is not None :
146+ self .nums = np .random .randint (0 , high = len (self .input [0 ]), size = self .n_values )
147+ else :
148+ self .nums = np .random .randint (0 , high = len (self .input [0 ]), size = len (self .input [0 ]))
149+
150+ self .input = self .unroll (self .input )
151+ # self.lower = self.unroll(self.lower)
152+
153+ def u_x_model (self , u_model , inputs ):
154+ return [model (u_model , * inputs ) for model in self .deriv_model ]
155+
156+ def unroll (self , inp ):
157+ outer = []
158+ for _ , lst in enumerate (inp ):
159+ tmp = [np .reshape (vec , (- 1 , 1 ))[self .nums ] for vec in lst .T ]
160+ outer .append (np .asarray (tmp ))
161+ return outer
162+
163+ def create_target (self ):
164+ fun_vals = []
165+ for i , var_ in enumerate (self .func_inputs ):
166+ arg_list = []
167+ for j , var in enumerate (var_ ):
168+ var_dict = self .get_dict (var )
169+ arg_list .append (get_linspace (var_dict ))
170+ inp = flatten_and_stack (multimesh (arg_list ))
171+ fun_vals .append (self .fun [i ](* inp .T ))
172+ self .val = convertTensor (np .reshape (fun_vals , (- 1 , 1 ))[self .nums ])
116173
117174def get_function_out (func , var , dict_ ):
118175 linspace = get_linspace (dict_ )
@@ -122,6 +179,7 @@ def get_function_out(func, var, dict_):
122179class IC (BC ):
123180 def __init__ (self , domain , fun , var , n_values = None ):
124181 self .isPeriodic = False
182+ self .isNeumann = False
125183 self .isInit = True
126184 self .n_values = n_values
127185 self .domain = domain
@@ -206,7 +264,3 @@ def unroll(self, inp):
206264
207265
208266
209- # TODO Add Neumann BC
210-
211-
212-
0 commit comments