Skip to content

Commit bf3bb6f

Browse files
authored
Merge pull request #208 from KuangYu/devel
Devel
2 parents 22845c9 + e634dc4 commit bf3bb6f

4 files changed

Lines changed: 52 additions & 25 deletions

File tree

dmff/torch_tools.py

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -41,37 +41,64 @@ def j2t_pytree(v):
4141

4242
def wrap_torch_potential_kernel(potential_t):
4343

44-
@partial(jax.custom_jvp, nondiff_argnums=(2,))
44+
# jvp, good for push-forward mode
45+
# @partial(jax.custom_jvp, nondiff_argnums=(2,))
46+
# def potential(positions, box, pairs, params):
47+
# res = potential_t(j2t_pytree(positions), \
48+
# j2t_pytree(box), \
49+
# np.array(pairs), \
50+
# j2t_pytree(params))
51+
# return res
52+
53+
# @potential.defjvp
54+
# def potential_jvp(pairs, primals, tangents):
55+
# positions, box, params = primals
56+
# dpositions, dbox, dparams = tangents
57+
# # convert inputs to torch
58+
# positions_t = j2t_pytree(positions)
59+
# box_t = j2t_pytree(box)
60+
# params_t = j2t_pytree(params)
61+
# # do fwd and bwd in torch
62+
# primal_out_torch = potential_t(positions_t, box_t, np.array(pairs), params_t)
63+
# primal_out_torch.backward()
64+
# # read gradient in torch
65+
# g_positions = t2j_extract_grad(positions_t)
66+
# g_box = t2j_extract_grad(box_t)
67+
# g_params = t2j_extract_grad(params_t)
68+
# # prepare output
69+
# primal_out = t2j(primal_out_torch)
70+
# tangent_out = jnp.sum(g_positions * dpositions) + jnp.sum(g_box * box)
71+
# tangents_leaves = jax.tree.leaves(dparams)
72+
# grad_leaves = jax.tree.leaves(g_params)
73+
# for x, y in zip(tangents_leaves, grad_leaves):
74+
# tangent_out += jnp.sum(x * y)
75+
# return primal_out, tangent_out
76+
77+
# vjp: good for backward
78+
@partial(jax.custom_vjp, nondiff_argnums=(2,))
4579
def potential(positions, box, pairs, params):
46-
res = potential_t(j2t_pytree(positions), \
47-
j2t_pytree(box), \
48-
np.array(pairs), \
80+
res = potential_t(j2t_pytree(positions),
81+
j2t_pytree(box),
82+
np.array(pairs),
4983
j2t_pytree(params))
5084
return res
5185

52-
@potential.defjvp
53-
def potential_jvp(pairs, primals, tangents):
54-
positions, box, params = primals
55-
dpositions, dbox, dparams = tangents
56-
# convert inputs to torch
57-
positions_t = j2t_pytree(positions)
86+
def potential_fwd(positions, box, pairs, params):
87+
pos_t = j2t_pytree(positions)
5888
box_t = j2t_pytree(box)
89+
pairs = np.array(pairs)
5990
params_t = j2t_pytree(params)
60-
# do fwd and bwd in torch
61-
primal_out_torch = potential_t(positions_t, box_t, np.array(pairs), params_t)
62-
primal_out_torch.backward()
63-
# read gradient in torch
64-
g_positions = t2j_extract_grad(positions_t)
65-
g_box = t2j_extract_grad(box_t)
66-
g_params = t2j_extract_grad(params_t)
67-
# prepare output
68-
primal_out = t2j(primal_out_torch)
69-
tangent_out = jnp.sum(g_positions * dpositions) + jnp.sum(g_box * box)
70-
tangents_leaves = jax.tree.leaves(dparams)
71-
grad_leaves = jax.tree.leaves(g_params)
72-
for x, y in zip(tangents_leaves, grad_leaves):
73-
tangent_out += jnp.sum(x * y)
74-
return primal_out, tangent_out
91+
energy = potential_t(pos_t, box_t, pairs, params_t)
92+
energy.backward()
93+
grads = (t2j_extract_grad(pos_t),
94+
t2j_extract_grad(box_t),
95+
t2j_extract_grad(params_t))
96+
return t2j(energy), grads
97+
98+
def potential_bwd(pairs, res, g):
99+
return res[0]*g, res[1]*g, jax.tree.map(lambda x: x*g, res[2])
100+
101+
potential.defvjp(potential_fwd, potential_bwd)
75102

76103
return potential
77104

examples/eann/eann_model.pickle

-329 Bytes
Binary file not shown.

tests/data/eann_model.pickle

-329 Bytes
Binary file not shown.

tests/data/water_eann.pickle

-228 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)