@@ -14,10 +14,10 @@ def __init__(self):
1414 self .isPeriodic = False
1515 self .isInit = False
1616
17-
1817 def compile (self ):
1918 self .input = self .create_input ()
2019
20+ # TODO Cleanup
2121 def preds_init (self , model ):
2222 self .preds = model (self .input )
2323
@@ -39,7 +39,7 @@ def create_target_input_repeat(self, var, target):
3939 res = [val for key , val in dict_ .items () if fidelity_key in key ]
4040 fids .append (res )
4141 reps = np .prod (fids )
42- if target is str :
42+ if type ( target ) is str :
4343 return np .repeat (self .dict_ [(var + target )], reps )
4444 else :
4545 return np .repeat (target , reps )
@@ -64,8 +64,54 @@ def create_input(self):
6464 mesh = np .insert (mesh , self .domain .vars .index (self .var ), repeated_value .flatten (), axis = 1 )
6565 return mesh
6666
67- def loss (self ):
68- return MSE (self .preds , self .val )
67+ # def matching_dicts(obj, vars):
68+ # ret_dicts=[]
69+ # for variable in vars:
70+ # ret_dicts.extend(next(item for item in obj.domain.domaindict if item["identifier"] == variable))
71+ # return ret_dicts
72+
73+
74+ class FunctionDirichletBC (BC ):
75+ def __init__ (self , domain , fun , var , target , func_inputs , n_values = None ):
76+ self .domain = domain
77+ self .fun = fun
78+ self .var = var
79+ self .target = target
80+ self .func_inputs = func_inputs
81+ self .n_values = n_values
82+ self .dicts_ = [item for item in self .domain .domaindict if item ['identifier' ] != self .var ]
83+ self .dict_ = next (item for item in self .domain .domaindict if item ["identifier" ] == self .var )
84+ print (self .dict_ )
85+ super ().__init__ ()
86+ self .targets = self .dict_ [var + target ]
87+ self .compile ()
88+ self .create_target ()
89+
90+ def create_input (self ):
91+ dims = self .get_not_dims (self .var )
92+ #dims = [get_linspace(dim) for dim in self.vars]
93+ # vals = np.reshape(fun_vals, (-1, len(self.vars)))
94+ mesh = flatten_and_stack (multimesh (dims ))
95+ # dim_repeat = np.repeat(0.0, len(mesh))
96+ dim_repeat = self .create_target_input_repeat (self .var , self .target )
97+ mesh = np .insert (mesh , self .domain .vars .index (self .var ), dim_repeat .flatten (), axis = 1 )
98+ if self .n_values is not None :
99+ self .nums = np .random .randint (0 , high = len (mesh ), size = self .n_values )
100+ mesh = mesh [self .nums ]
101+ return mesh
102+
103+ def create_target (self ):
104+ fun_vals = []
105+ for i , var_ in enumerate (self .func_inputs ):
106+ arg_list = []
107+ for j , var in enumerate (var_ ):
108+ var_dict = self .get_dict (var )
109+ arg_list .append (get_linspace (var_dict ))
110+ inp = flatten_and_stack (multimesh (arg_list ))
111+ print (* inp .T )
112+ fun_vals .append (self .fun [i ](* inp .T ))
113+ print (fun_vals )
114+ self .val = convertTensor (np .reshape (fun_vals , (- 1 , 1 ))[self .nums ])
69115
70116
71117def get_function_out (func , var , dict_ ):
0 commit comments