@@ -41,37 +41,64 @@ def j2t_pytree(v):
4141
4242def wrap_torch_potential_kernel (potential_t ):
4343
44- @partial (jax .custom_jvp , nondiff_argnums = (2 ,))
44+ # jvp, good for push-forward mode
45+ # @partial(jax.custom_jvp, nondiff_argnums=(2,))
46+ # def potential(positions, box, pairs, params):
47+ # res = potential_t(j2t_pytree(positions), \
48+ # j2t_pytree(box), \
49+ # np.array(pairs), \
50+ # j2t_pytree(params))
51+ # return res
52+
53+ # @potential.defjvp
54+ # def potential_jvp(pairs, primals, tangents):
55+ # positions, box, params = primals
56+ # dpositions, dbox, dparams = tangents
57+ # # convert inputs to torch
58+ # positions_t = j2t_pytree(positions)
59+ # box_t = j2t_pytree(box)
60+ # params_t = j2t_pytree(params)
61+ # # do fwd and bwd in torch
62+ # primal_out_torch = potential_t(positions_t, box_t, np.array(pairs), params_t)
63+ # primal_out_torch.backward()
64+ # # read gradient in torch
65+ # g_positions = t2j_extract_grad(positions_t)
66+ # g_box = t2j_extract_grad(box_t)
67+ # g_params = t2j_extract_grad(params_t)
68+ # # prepare output
69+ # primal_out = t2j(primal_out_torch)
70+ # tangent_out = jnp.sum(g_positions * dpositions) + jnp.sum(g_box * box)
71+ # tangents_leaves = jax.tree.leaves(dparams)
72+ # grad_leaves = jax.tree.leaves(g_params)
73+ # for x, y in zip(tangents_leaves, grad_leaves):
74+ # tangent_out += jnp.sum(x * y)
75+ # return primal_out, tangent_out
76+
77+ # vjp: good for backward
78+ @partial (jax .custom_vjp , nondiff_argnums = (2 ,))
4579 def potential (positions , box , pairs , params ):
46- res = potential_t (j2t_pytree (positions ), \
47- j2t_pytree (box ), \
48- np .array (pairs ), \
80+ res = potential_t (j2t_pytree (positions ),
81+ j2t_pytree (box ),
82+ np .array (pairs ),
4983 j2t_pytree (params ))
5084 return res
5185
52- @potential .defjvp
53- def potential_jvp (pairs , primals , tangents ):
54- positions , box , params = primals
55- dpositions , dbox , dparams = tangents
56- # convert inputs to torch
57- positions_t = j2t_pytree (positions )
86+ def potential_fwd (positions , box , pairs , params ):
87+ pos_t = j2t_pytree (positions )
5888 box_t = j2t_pytree (box )
89+ pairs = np .array (pairs )
5990 params_t = j2t_pytree (params )
60- # do fwd and bwd in torch
61- primal_out_torch = potential_t (positions_t , box_t , np .array (pairs ), params_t )
62- primal_out_torch .backward ()
63- # read gradient in torch
64- g_positions = t2j_extract_grad (positions_t )
65- g_box = t2j_extract_grad (box_t )
66- g_params = t2j_extract_grad (params_t )
67- # prepare output
68- primal_out = t2j (primal_out_torch )
69- tangent_out = jnp .sum (g_positions * dpositions ) + jnp .sum (g_box * box )
70- tangents_leaves = jax .tree .leaves (dparams )
71- grad_leaves = jax .tree .leaves (g_params )
72- for x , y in zip (tangents_leaves , grad_leaves ):
73- tangent_out += jnp .sum (x * y )
74- return primal_out , tangent_out
91+ energy = potential_t (pos_t , box_t , pairs , params_t )
92+ energy .backward ()
93+ grads = (t2j_extract_grad (pos_t ),
94+ t2j_extract_grad (box_t ),
95+ t2j_extract_grad (params_t ))
96+ return t2j (energy ), grads
97+
98+ def potential_bwd (pairs , res , g ):
99+ return res [0 ]* g , res [1 ]* g , jax .tree .map (lambda x : x * g , res [2 ])
100+
101+ potential .defvjp (potential_fwd , potential_bwd )
75102
76103 return potential
77104
0 commit comments