Skip to content

Commit 196f6fe

Browse files
committed
support multiple inputs that require grad
1 parent fdf94c1 commit 196f6fe

3 files changed

Lines changed: 5 additions & 4 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## jax2torch
22

3-
Use Jax functions in Pytorch with DLPack, as outlined <a href="https://gist.github.com/mattjj/e8b51074fed081d765d2f3ff90edf0e9">in a gist</a> by <a href="https://github.com/mattjj">@mattjj</a>. Right now only supports one tensor input (with optional non-tensor input arguments) to one tensor output, for the purposes of <a href="https://github.com/spetti/SMURF">differentiable alignment</a>.
3+
Use Jax functions in Pytorch with DLPack, as outlined <a href="https://gist.github.com/mattjj/e8b51074fed081d765d2f3ff90edf0e9">in a gist</a> by <a href="https://github.com/mattjj">@mattjj</a>. The repository was made for the purposes of making the <a href="https://github.com/spetti/SMURF">differentiable alignment</a> work here interoperable with Pytorch.
44

55
## Install
66

jax2torch/jax2torch.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ def backward(ctx, *grad_args):
4343
grad_args = tree_t2j(grad_args)
4444
else:
4545
grad_args = t2j(grad_args[0])
46-
grads, *_ = ctx.fun_vjp(grad_args)
47-
ret = tree_j2t(grads), *((None,) * (ctx.num_args - 1))
46+
grads = ctx.fun_vjp(grad_args)
47+
grads = tuple(map(lambda t: t if isinstance(t, jnp.ndarray) else None, grads))
48+
ret = tree_j2t(grads)
4849
return ret
4950

5051
sig = signature(fn)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'jax2torch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.6',
6+
version = '0.0.7',
77
license='MIT',
88
description = 'Jax 2 Torch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)