Skip to content

Commit 8fa99f9

Browse files
authored
Merge pull request #15 from levimcclenny/main
Added FunctionNeumannBC
2 parents f0758d1 + 9dd352a commit 8fa99f9

File tree

2 files changed

+65
-4
lines changed

2 files changed

+65
-4
lines changed

tensordiffeq/boundaries.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ class BC(DomainND):
1313
def __init__(self):
1414
self.isPeriodic = False
1515
self.isInit = False
16+
self.isNeumann = False
1617

1718
def compile(self):
1819
self.input = self.create_input()
@@ -113,6 +114,62 @@ def create_target(self):
113114
print(fun_vals)
114115
self.val = convertTensor(np.reshape(fun_vals, (-1, 1))[self.nums])
115116

117+
class FunctionNeumannBC(BC):
118+
def __init__(self, domain, fun, var, target, deriv_model, func_inputs, n_values=None):
119+
self.n_values = n_values
120+
self.domain = domain
121+
self.fun = fun
122+
self.var = var
123+
self.target = target
124+
super().__init__()
125+
self.deriv_model = [get_tf_model(model) for model in deriv_model]
126+
self.isNeumann = True
127+
self.func_inputs = func_inputs
128+
self.compile()
129+
self.create_target()
130+
131+
def get_input_upper_lower(self, var):
132+
self.repeat = self.create_target_input_repeat(var, self.target)
133+
#self.lower_repeat = self.create_target_input_repeat(var, self.dict_["range"][0])
134+
135+
def compile(self):
136+
self.input = []
137+
for var in self.var:
138+
self.dicts_ = [item for item in self.domain.domaindict if item["identifier"] != var]
139+
self.dict_ = next(item for item in self.domain.domaindict if item["identifier"] == var)
140+
self.get_input_upper_lower(var)
141+
mesh = flatten_and_stack(multimesh(self.get_not_dims(var)))
142+
self.input.append(np.insert(mesh, self.domain.vars.index(var), self.repeat.flatten(), axis=1))
143+
# self.lower.append(np.insert(mesh, self.domain.vars.index(var), self.lower_repeat.flatten(), axis=1))
144+
145+
if self.n_values is not None:
146+
self.nums = np.random.randint(0, high=len(self.input[0]), size=self.n_values)
147+
else:
148+
self.nums = np.random.randint(0, high=len(self.input[0]), size=len(self.input[0]))
149+
150+
self.input = self.unroll(self.input)
151+
# self.lower = self.unroll(self.lower)
152+
153+
def u_x_model(self, u_model, inputs):
154+
return [model(u_model, *inputs) for model in self.deriv_model]
155+
156+
def unroll(self, inp):
157+
outer = []
158+
for _, lst in enumerate(inp):
159+
tmp = [np.reshape(vec, (-1, 1))[self.nums] for vec in lst.T]
160+
outer.append(np.asarray(tmp))
161+
return outer
162+
163+
def create_target(self):
164+
fun_vals = []
165+
for i, var_ in enumerate(self.func_inputs):
166+
arg_list = []
167+
for j, var in enumerate(var_):
168+
var_dict = self.get_dict(var)
169+
arg_list.append(get_linspace(var_dict))
170+
inp = flatten_and_stack(multimesh(arg_list))
171+
fun_vals.append(self.fun[i](*inp.T))
172+
self.val = convertTensor(np.reshape(fun_vals, (-1, 1))[self.nums])
116173

117174
def get_function_out(func, var, dict_):
118175
linspace = get_linspace(dict_)
@@ -122,6 +179,7 @@ def get_function_out(func, var, dict_):
122179
class IC(BC):
123180
def __init__(self, domain, fun, var, n_values=None):
124181
self.isPeriodic = False
182+
self.isNeumann = False
125183
self.isInit = True
126184
self.n_values = n_values
127185
self.domain = domain
@@ -206,7 +264,3 @@ def unroll(self, inp):
206264

207265

208266

209-
# TODO Add Neumann BC
210-
211-
212-

tensordiffeq/models.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,13 @@ def update_loss(self):
9090
loss_tmp = tf.math.add(loss_tmp, MSE(self.u_model(bc.input), bc.val))
9191
# Dirichlect BC, will need to add more cases for Neumann BC, etc as more
9292
# BC types are added
93+
if bc.isNeumann:
94+
for i, dim in enumerate(bc.var):
95+
for j, lst in enumerate(dim):
96+
for k, tup in enumerate(lst):
97+
target = tf.cast(bc.u_x_model(self.u_model, bc.input[i])[j][k], dtype=tf.float32)
98+
msq = MSE(bc.val, target)
99+
loss_tmp = tf.math.add(loss_tmp, msq)
93100
# This is true unless the BC loss can be evaluated using the MSE function explicitly
94101
else:
95102
loss_tmp = tf.math.add(loss_tmp, MSE(self.u_model(bc.input), bc.val))

0 commit comments

Comments
 (0)