@@ -18,14 +18,6 @@ def __init__(self):
1818 def compile (self ):
1919 self .input = self .create_input ()
2020
21- # TODO Cleanup
22- def preds_init (self , model ):
23- self .preds = model (self .input )
24-
25- @tf .function
26- def update_values (self , model ):
27- self .preds = model (self .input )
28-
2921 def get_dict (self , var ):
3022 return next (item for item in self .domain .domaindict if item ["identifier" ] == var )
3123
@@ -65,12 +57,6 @@ def create_input(self):
6557 mesh = np .insert (mesh , self .domain .vars .index (self .var ), repeated_value .flatten (), axis = 1 )
6658 return mesh
6759
68- # def matching_dicts(obj, vars):
69- # ret_dicts=[]
70- # for variable in vars:
71- # ret_dicts.extend(next(item for item in obj.domain.domaindict if item["identifier"] == variable))
72- # return ret_dicts
73-
7460
7561class FunctionDirichletBC (BC ):
7662 def __init__ (self , domain , fun , var , target , func_inputs , n_values = None ):
@@ -109,9 +95,7 @@ def create_target(self):
10995 var_dict = self .get_dict (var )
11096 arg_list .append (get_linspace (var_dict ))
11197 inp = flatten_and_stack (multimesh (arg_list ))
112- print (* inp .T )
11398 fun_vals .append (self .fun [i ](* inp .T ))
114- print (fun_vals )
11599 self .val = convertTensor (np .reshape (fun_vals , (- 1 , 1 ))[self .nums ])
116100
117101class FunctionNeumannBC (BC ):
@@ -130,7 +114,6 @@ def __init__(self, domain, fun, var, target, deriv_model, func_inputs, n_values=
130114
131115 def get_input_upper_lower (self , var ):
132116 self .repeat = self .create_target_input_repeat (var , self .target )
133- #self.lower_repeat = self.create_target_input_repeat(var, self.dict_["range"][0])
134117
135118 def compile (self ):
136119 self .input = []
@@ -140,7 +123,6 @@ def compile(self):
140123 self .get_input_upper_lower (var )
141124 mesh = flatten_and_stack (multimesh (self .get_not_dims (var )))
142125 self .input .append (np .insert (mesh , self .domain .vars .index (var ), self .repeat .flatten (), axis = 1 ))
143- # self.lower.append(np.insert(mesh, self.domain.vars.index(var), self.lower_repeat.flatten(), axis=1))
144126
145127 if self .n_values is not None :
146128 self .nums = np .random .randint (0 , high = len (self .input [0 ]), size = self .n_values )
@@ -178,21 +160,20 @@ def get_function_out(func, var, dict_):
178160
179161class IC (BC ):
180162 def __init__ (self , domain , fun , var , n_values = None ):
181- self .isPeriodic = False
182- self .isNeumann = False
183163 self .isInit = True
184164 self .n_values = n_values
185165 self .domain = domain
186166 self .fun = fun
187167 self .vars = var
168+ super ().__init__ ()
169+ self .isInit = True
188170 self .dicts_ = [item for item in self .domain .domaindict if item ['identifier' ] != self .domain .time_var ]
189171 self .dict_ = next (item for item in self .domain .domaindict if item ["identifier" ] == self .domain .time_var )
190172 self .compile ()
191173 self .create_target ()
192174
193175 def create_input (self ):
194176 dims = self .get_not_dims (self .domain .time_var )
195- # vals = np.reshape(fun_vals, (-1, len(self.vars)))
196177 mesh = flatten_and_stack (multimesh (dims ))
197178 t_repeat = np .repeat (0.0 , len (mesh ))
198179
0 commit comments