Skip to content

Commit 2d2639a

Browse files
authored
Add JAX backend support
2 parents b2c1c94 + 45a133a commit 2d2639a

File tree

1 file changed

+221
-0
lines changed

1 file changed

+221
-0
lines changed

sbinn/sbinn_jax.py

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
import numpy as np
2+
import deepxde as dde
3+
import variable_to_parameter_transform
4+
import jax.numpy as jnp
5+
import jax
6+
7+
8+
def sbinn(data_t, data_y, meal_t, meal_q):
9+
def get_variable(v, var):
10+
var = var
11+
low, up = v * 0.2, v * 1.8
12+
l = (up - low) / 2
13+
v1 = l * jnp.tanh(var) + l + low
14+
return v1
15+
16+
E_ = dde.Variable(0.0)
17+
tp_ = dde.Variable(0.0)
18+
ti_ = dde.Variable(0.0)
19+
td_ = dde.Variable(0.0)
20+
k_ = dde.Variable(0.0)
21+
Rm_ = dde.Variable(0.0)
22+
a1_ = dde.Variable(0.0)
23+
C1_ = dde.Variable(0.0)
24+
C2_ = dde.Variable(0.0)
25+
C4_ = dde.Variable(0.0)
26+
C5_ = dde.Variable(0.0)
27+
Ub_ = dde.Variable(0.0)
28+
U0_ = dde.Variable(0.0)
29+
Um_ = dde.Variable(0.0)
30+
Rg_ = dde.Variable(0.0)
31+
alpha_ = dde.Variable(0.0)
32+
beta_ = dde.Variable(0.0)
33+
34+
var_list_ = [
35+
E_,
36+
tp_,
37+
ti_,
38+
td_,
39+
k_,
40+
Rm_,
41+
a1_,
42+
C1_,
43+
C2_,
44+
C4_,
45+
C5_,
46+
Ub_,
47+
U0_,
48+
Um_,
49+
Rg_,
50+
alpha_,
51+
beta_,
52+
]
53+
54+
def ODE(t, y, unknowns=[var.value for var in var_list_]):
55+
(
56+
E_,
57+
tp_,
58+
ti_,
59+
td_,
60+
k_,
61+
Rm_,
62+
a1_,
63+
C1_,
64+
C2_,
65+
C4_,
66+
C5_,
67+
Ub_,
68+
U0_,
69+
Um_,
70+
Rg_,
71+
alpha_,
72+
beta_,
73+
) = unknowns
74+
if len(y[0].shape) == 1:
75+
Ip = y[0][0:1]
76+
Ii = y[0][1:2]
77+
G = y[0][2:3]
78+
h1 = y[0][3:4]
79+
h2 = y[0][4:5]
80+
h3 = y[0][5:6]
81+
else:
82+
Ip = y[0][:, 0:1]
83+
Ii = y[0][:, 1:2]
84+
G = y[0][:, 2:3]
85+
h1 = y[0][:, 3:4]
86+
h2 = y[0][:, 4:5]
87+
h3 = y[0][:, 5:6]
88+
89+
Vp = 3
90+
Vi = 11
91+
Vg = 10
92+
E = (jnp.tanh(E_) + 1) * 0.1 + 0.1
93+
tp = (jnp.tanh(tp_) + 1) * 2 + 4
94+
ti = (jnp.tanh(ti_) + 1) * 40 + 60
95+
td = (jnp.tanh(td_) + 1) * 25 / 6 + 25 / 3
96+
k = get_variable(0.0083, k_)
97+
Rm = get_variable(209, Rm_)
98+
a1 = get_variable(6.6, a1_)
99+
C1 = get_variable(300, C1_)
100+
C2 = get_variable(144, C2_)
101+
C3 = 100
102+
C4 = get_variable(80, C4_)
103+
C5 = get_variable(26, C5_)
104+
Ub = get_variable(72, Ub_)
105+
U0 = get_variable(4, U0_)
106+
Um = get_variable(90, Um_)
107+
Rg = get_variable(180, Rg_)
108+
alpha = get_variable(7.5, alpha_)
109+
beta = get_variable(1.772, beta_)
110+
111+
f1 = Rm * jax.nn.sigmoid(G / (Vg * C1) - a1)
112+
f2 = Ub * (1 - jnp.exp(-G / (Vg * C2)))
113+
kappa = (1 / Vi + 1 / (E * ti)) / C4
114+
f3 = (U0 + Um / (1 + jnp.pow(jnp.maximum(kappa * Ii, 1e-3), -beta))) / (Vg * C3)
115+
f4 = Rg * jax.nn.sigmoid(alpha * (1 - h3 / (Vp * C5)))
116+
dt = t - meal_t
117+
IG = jnp.sum(
118+
0.5 * meal_q * k * jnp.exp(-k * dt) * (jnp.sign(dt) + 1),
119+
axis=1,
120+
keepdims=True,
121+
)
122+
tmp = E * (Ip / Vp - Ii / Vi)
123+
dIP_dt = dde.grad.jacobian(y, t, i=0, j=0)[0]
124+
dIi_dt = dde.grad.jacobian(y, t, i=1, j=0)[0]
125+
dG_dt = dde.grad.jacobian(y, t, i=2, j=0)[0]
126+
dh1_dt = dde.grad.jacobian(y, t, i=3, j=0)[0]
127+
dh2_dt = dde.grad.jacobian(y, t, i=4, j=0)[0]
128+
dh3_dt = dde.grad.jacobian(y, t, i=5, j=0)[0]
129+
return [
130+
dIP_dt - (f1 - tmp - Ip / tp),
131+
dIi_dt - (tmp - Ii / ti),
132+
dG_dt - (f4 + IG - f2 - f3 * G),
133+
dh1_dt - (Ip - h1) / td,
134+
dh2_dt - (h1 - h2) / td,
135+
dh3_dt - (h2 - h3) / td,
136+
]
137+
138+
geom = dde.geometry.TimeDomain(data_t[0, 0], data_t[-1, 0])
139+
140+
# Observes
141+
n = len(data_t)
142+
idx = np.append(
143+
np.random.choice(np.arange(1, n - 1), size=n // 5, replace=False), [0, n - 1]
144+
)
145+
observe_y2 = dde.PointSetBC(data_t[idx], data_y[idx, 2:3], component=2)
146+
147+
np.savetxt("glucose_input.dat", np.hstack((data_t[idx], data_y[idx, 2:3])))
148+
149+
data = dde.data.PDE(geom, ODE, [observe_y2], anchors=data_t)
150+
151+
net = dde.maps.FNN([1] + [128] * 3 + [6], "swish", "Glorot normal")
152+
153+
def feature_transform(t):
154+
t = 0.01 * t
155+
return jnp.concat(
156+
(
157+
t,
158+
jnp.sin(t),
159+
jnp.sin(2 * t),
160+
jnp.sin(3 * t),
161+
jnp.sin(4 * t),
162+
jnp.sin(5 * t),
163+
),
164+
axis=1,
165+
)
166+
167+
net.apply_feature_transform(feature_transform)
168+
169+
def output_transform(t, y):
170+
idx = 1799
171+
k = (data_y[idx] - data_y[0]) / (data_t[idx] - data_t[0])
172+
b = (data_t[idx] * data_y[0] - data_t[0] * data_y[idx]) / (
173+
data_t[idx] - data_t[0]
174+
)
175+
linear = k * t + b
176+
factor = jnp.tanh(t) * jnp.tanh(idx - t)
177+
return linear + factor * jnp.array([1, 1, 1e2, 1, 1, 1]) * y
178+
179+
net.apply_output_transform(output_transform)
180+
181+
model = dde.Model(data, net)
182+
183+
firsttrain = 10000
184+
callbackperiod = 1000
185+
maxepochs = 1000000
186+
187+
model.compile("adam", lr=1e-3, loss_weights=[0, 0, 0, 0, 0, 0, 1e-2])
188+
model.train(iterations=firsttrain, display_every=1000)
189+
model.compile(
190+
"adam",
191+
lr=1e-3,
192+
loss_weights=[1, 1, 1e-2, 1, 1, 1, 1e-2],
193+
external_trainable_variables=var_list_,
194+
)
195+
variablefilename = "variables.csv"
196+
variable = dde.callbacks.VariableValue(
197+
var_list_, period=callbackperiod, filename=variablefilename
198+
)
199+
losshistory, train_state = model.train(
200+
iterations=maxepochs, display_every=1000, callbacks=[variable]
201+
)
202+
203+
dde.saveplot(losshistory, train_state, issave=True, isplot=True)
204+
205+
206+
gluc_data = np.hsplit(np.loadtxt("glucose.dat"), [1])
207+
meal_data = np.hsplit(np.loadtxt("meal.dat"), [4])
208+
209+
t = gluc_data[0]
210+
y = gluc_data[1]
211+
meal_t = meal_data[0]
212+
meal_q = meal_data[1]
213+
214+
sbinn(
215+
t[:1800],
216+
y[:1800],
217+
meal_t,
218+
meal_q,
219+
)
220+
221+
variable_to_parameter_transform.variable_file(10000, 1000, 1000000, "variables.csv")

0 commit comments

Comments
 (0)