Skip to content

Commit 60e54ba

Browse files
committed
cleanup boundaries.py comments and fix some OOP
1 parent 9dd352a commit 60e54ba

File tree

1 file changed

+2
-21
lines changed

1 file changed

+2
-21
lines changed

tensordiffeq/boundaries.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,6 @@ def __init__(self):
1818
def compile(self):
1919
self.input = self.create_input()
2020

21-
# TODO Cleanup
22-
def preds_init(self, model):
23-
self.preds = model(self.input)
24-
25-
@tf.function
26-
def update_values(self, model):
27-
self.preds = model(self.input)
28-
2921
def get_dict(self, var):
3022
return next(item for item in self.domain.domaindict if item["identifier"] == var)
3123

@@ -65,12 +57,6 @@ def create_input(self):
6557
mesh = np.insert(mesh, self.domain.vars.index(self.var), repeated_value.flatten(), axis=1)
6658
return mesh
6759

68-
# def matching_dicts(obj, vars):
69-
# ret_dicts=[]
70-
# for variable in vars:
71-
# ret_dicts.extend(next(item for item in obj.domain.domaindict if item["identifier"] == variable))
72-
# return ret_dicts
73-
7460

7561
class FunctionDirichletBC(BC):
7662
def __init__(self, domain, fun, var, target, func_inputs, n_values=None):
@@ -109,9 +95,7 @@ def create_target(self):
10995
var_dict = self.get_dict(var)
11096
arg_list.append(get_linspace(var_dict))
11197
inp = flatten_and_stack(multimesh(arg_list))
112-
print(*inp.T)
11398
fun_vals.append(self.fun[i](*inp.T))
114-
print(fun_vals)
11599
self.val = convertTensor(np.reshape(fun_vals, (-1, 1))[self.nums])
116100

117101
class FunctionNeumannBC(BC):
@@ -130,7 +114,6 @@ def __init__(self, domain, fun, var, target, deriv_model, func_inputs, n_values=
130114

131115
def get_input_upper_lower(self, var):
132116
self.repeat = self.create_target_input_repeat(var, self.target)
133-
#self.lower_repeat = self.create_target_input_repeat(var, self.dict_["range"][0])
134117

135118
def compile(self):
136119
self.input = []
@@ -140,7 +123,6 @@ def compile(self):
140123
self.get_input_upper_lower(var)
141124
mesh = flatten_and_stack(multimesh(self.get_not_dims(var)))
142125
self.input.append(np.insert(mesh, self.domain.vars.index(var), self.repeat.flatten(), axis=1))
143-
# self.lower.append(np.insert(mesh, self.domain.vars.index(var), self.lower_repeat.flatten(), axis=1))
144126

145127
if self.n_values is not None:
146128
self.nums = np.random.randint(0, high=len(self.input[0]), size=self.n_values)
@@ -178,21 +160,20 @@ def get_function_out(func, var, dict_):
178160

179161
class IC(BC):
180162
def __init__(self, domain, fun, var, n_values=None):
181-
self.isPeriodic = False
182-
self.isNeumann = False
183163
self.isInit = True
184164
self.n_values = n_values
185165
self.domain = domain
186166
self.fun = fun
187167
self.vars = var
168+
super().__init__()
169+
self.isInit = True
188170
self.dicts_ = [item for item in self.domain.domaindict if item['identifier'] != self.domain.time_var]
189171
self.dict_ = next(item for item in self.domain.domaindict if item["identifier"] == self.domain.time_var)
190172
self.compile()
191173
self.create_target()
192174

193175
def create_input(self):
194176
dims = self.get_not_dims(self.domain.time_var)
195-
# vals = np.reshape(fun_vals, (-1, len(self.vars)))
196177
mesh = flatten_and_stack(multimesh(dims))
197178
t_repeat = np.repeat(0.0, len(mesh))
198179

0 commit comments

Comments
 (0)