Skip to content

Commit 4417e0f

Browse files
committed
Add FunctionDirichlectBC
1 parent 393620c commit 4417e0f

File tree

1 file changed

+50
-4
lines changed

1 file changed

+50
-4
lines changed

tensordiffeq/boundaries.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ def __init__(self):
1414
self.isPeriodic = False
1515
self.isInit = False
1616

17-
1817
def compile(self):
1918
self.input = self.create_input()
2019

20+
# TODO Cleanup
2121
def preds_init(self, model):
2222
self.preds = model(self.input)
2323

@@ -39,7 +39,7 @@ def create_target_input_repeat(self, var, target):
3939
res = [val for key, val in dict_.items() if fidelity_key in key]
4040
fids.append(res)
4141
reps = np.prod(fids)
42-
if target is str:
42+
if type(target) is str:
4343
return np.repeat(self.dict_[(var + target)], reps)
4444
else:
4545
return np.repeat(target, reps)
@@ -64,8 +64,54 @@ def create_input(self):
6464
mesh = np.insert(mesh, self.domain.vars.index(self.var), repeated_value.flatten(), axis=1)
6565
return mesh
6666

67-
def loss(self):
68-
return MSE(self.preds, self.val)
67+
# def matching_dicts(obj, vars):
68+
# ret_dicts=[]
69+
# for variable in vars:
70+
# ret_dicts.extend(next(item for item in obj.domain.domaindict if item["identifier"] == variable))
71+
# return ret_dicts
72+
73+
74+
class FunctionDirichletBC(BC):
75+
def __init__(self, domain, fun, var, target, func_inputs, n_values=None):
76+
self.domain = domain
77+
self.fun = fun
78+
self.var = var
79+
self.target = target
80+
self.func_inputs = func_inputs
81+
self.n_values = n_values
82+
self.dicts_ = [item for item in self.domain.domaindict if item['identifier'] != self.var]
83+
self.dict_ = next(item for item in self.domain.domaindict if item["identifier"] == self.var)
84+
print(self.dict_)
85+
super().__init__()
86+
self.targets = self.dict_[var+target]
87+
self.compile()
88+
self.create_target()
89+
90+
def create_input(self):
91+
dims = self.get_not_dims(self.var)
92+
#dims = [get_linspace(dim) for dim in self.vars]
93+
# vals = np.reshape(fun_vals, (-1, len(self.vars)))
94+
mesh = flatten_and_stack(multimesh(dims))
95+
# dim_repeat = np.repeat(0.0, len(mesh))
96+
dim_repeat = self.create_target_input_repeat(self.var, self.target)
97+
mesh = np.insert(mesh, self.domain.vars.index(self.var), dim_repeat.flatten(), axis=1)
98+
if self.n_values is not None:
99+
self.nums = np.random.randint(0, high=len(mesh), size=self.n_values)
100+
mesh = mesh[self.nums]
101+
return mesh
102+
103+
def create_target(self):
104+
fun_vals = []
105+
for i, var_ in enumerate(self.func_inputs):
106+
arg_list = []
107+
for j, var in enumerate(var_):
108+
var_dict = self.get_dict(var)
109+
arg_list.append(get_linspace(var_dict))
110+
inp = flatten_and_stack(multimesh(arg_list))
111+
print(*inp.T)
112+
fun_vals.append(self.fun[i](*inp.T))
113+
print(fun_vals)
114+
self.val = convertTensor(np.reshape(fun_vals, (-1, 1))[self.nums])
69115

70116

71117
def get_function_out(func, var, dict_):

0 commit comments

Comments
 (0)