Skip to content

Commit 83f00b1

Browse files
authored
Make compatible with more recent Jax versions (#11)
* Make compatible with more recent Jax versions Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Upgrade to cuda13 and fix torch/setuptools bug Signed-off-by: Fabrice Normandin <normandf@mila.quebec> * Fix mkdocs upgrade issue Signed-off-by: Fabrice Normandin <normandf@mila.quebec> --------- Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
1 parent fbbf998 commit 83f00b1

7 files changed

Lines changed: 1334 additions & 1070 deletions

File tree

.github/workflows/build.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ jobs:
6060
max-parallel: 4
6161
matrix:
6262
platform: ["ubuntu-latest", "macos-latest"]
63-
python-version: ["3.12"]
63+
python-version: ["3.11", "3.12", "3.13"]
6464
steps:
6565
- uses: actions/checkout@v4
6666
- name: Install the latest version of uv

mkdocs.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ plugins:
1818
- mkdocstrings:
1919
handlers:
2020
python:
21-
import:
21+
inventories:
2222
- https://docs.python.org/3/objects.inv
2323
- https://docs.pytest.org/en/stable/objects.inv
2424
- https://flax.readthedocs.io/en/latest/objects.inv

pyproject.toml

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,17 @@ readme = "README.md"
55
authors = [
66
{ name = "Fabrice Normandin", email = "fabrice.normandin@gmail.com" },
77
]
8-
requires-python = ">=3.10"
9-
dependencies = ["jax>=0.4.28", "torch>=2.0.0"]
8+
requires-python = ">=3.11"
9+
dependencies = [
10+
"jax>=0.6.0",
11+
"torch",
12+
# note: This is because of a weird bug where torch wants setuptools to build cpp extensions (seems related to torch.compile).
13+
"setuptools; python_version == '3.11'",
14+
]
1015
dynamic = ["version"]
1116

1217
[project.optional-dependencies]
13-
gpu = ["jax[cuda12]>=0.4.28; sys_platform == 'linux'"]
18+
gpu = ["jax[cuda13]; sys_platform == 'linux'"]
1419

1520

1621
[dependency-groups]

torch_jax_interop/to_jax.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,10 @@
1111

1212
import jax
1313
import jax.core
14-
import jaxlib
15-
import jaxlib.xla_extension
1614
import torch
1715
import torch.func
1816
import torch.utils._pytree
1917
from jax.dlpack import from_dlpack as jax_from_dlpack # type: ignore
20-
from torch.utils.dlpack import to_dlpack as torch_to_dlpack # type: ignore
2118

2219
from .types import (
2320
Dataclass,
@@ -34,33 +31,27 @@
3431

3532

3633
@overload
37-
def torch_to_jax(value: torch.Tensor, /) -> jax.Array:
38-
...
34+
def torch_to_jax(value: torch.Tensor, /) -> jax.Array: ...
3935

4036

4137
@overload
42-
def torch_to_jax(value: torch.device, /) -> jax.Device:
43-
...
38+
def torch_to_jax(value: torch.device, /) -> jax.Device: ...
4439

4540

4641
@overload
47-
def torch_to_jax(value: tuple[torch.Tensor, ...], /) -> tuple[jax.Array, ...]:
48-
...
42+
def torch_to_jax(value: tuple[torch.Tensor, ...], /) -> tuple[jax.Array, ...]: ...
4943

5044

5145
@overload
52-
def torch_to_jax(value: list[torch.Tensor], /) -> list[jax.Array]:
53-
...
46+
def torch_to_jax(value: list[torch.Tensor], /) -> list[jax.Array]: ...
5447

5548

5649
@overload
57-
def torch_to_jax(value: NestedDict[K, torch.Tensor], /) -> NestedDict[K, jax.Array]:
58-
...
50+
def torch_to_jax(value: NestedDict[K, torch.Tensor], /) -> NestedDict[K, jax.Array]: ...
5951

6052

6153
@overload
62-
def torch_to_jax(value: Any, /) -> Any:
63-
...
54+
def torch_to_jax(value: Any, /) -> Any: ...
6455

6556

6657
def torch_to_jax(value: Any, /) -> Any:
@@ -99,16 +90,14 @@ def _direct_conversion(v: torch.Tensor) -> jax.Array:
9990
return jax_from_dlpack(v, copy=False)
10091

10192

102-
def _to_from_dlpack(
103-
v: torch.Tensor, ignore_deprecation_warning: bool = True
104-
) -> jax.Array:
93+
def _to_from_dlpack(v: torch.Tensor, ignore_deprecation_warning: bool = True) -> jax.Array:
10594
with warnings.catch_warnings() if ignore_deprecation_warning else contextlib.nullcontext():
10695
# Only way to get this to work for CPU seems to be with to/from dlpack... so we have to use this deprecated
10796
# conversion method for now.
10897
# todo: Should we let it though though?
10998
if ignore_deprecation_warning:
11099
warnings.filterwarnings("ignore", category=DeprecationWarning)
111-
return jax_from_dlpack(torch_to_dlpack(v), copy=False)
100+
return jax_from_dlpack(v, copy=False)
112101

113102

114103
def torch_to_jax_tensor(value: torch.Tensor) -> jax.Array:
@@ -130,7 +119,7 @@ def torch_to_jax_tensor(value: torch.Tensor) -> jax.Array:
130119
# return _direct_conversion(value)
131120
return _to_from_dlpack(value, ignore_deprecation_warning=True)
132121

133-
except jaxlib.xla_extension.XlaRuntimeError as err:
122+
except RuntimeError as err:
134123
log_once(
135124
logger,
136125
message=(
@@ -145,7 +134,7 @@ def torch_to_jax_tensor(value: torch.Tensor) -> jax.Array:
145134

146135
try:
147136
return _direct_conversion(value)
148-
except jaxlib.xla_extension.XlaRuntimeError as err:
137+
except RuntimeError as err:
149138
log_once(
150139
logger,
151140
message=(

torch_jax_interop/to_jax_module.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@
2626

2727
def make_functional(
2828
module_with_state: Module[P, Out_cov], disable_autograd_tracking=False
29-
) -> tuple[
30-
Callable[Concatenate[Iterable[torch.Tensor], P], Out_cov], tuple[torch.Tensor, ...]
31-
]:
29+
) -> tuple[Callable[Concatenate[Iterable[torch.Tensor], P], Out_cov], tuple[torch.Tensor, ...]]:
3230
"""Backward compatibility equivalent for `functorch.make_functional` in the new torch.func API.
3331
3432
Adapted from https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf as suggested by
@@ -128,15 +126,21 @@ def j2t(v: JaxPyTree) -> TorchPyTree:
128126
if any(isinstance(v_i, jax.core.Tracer) for v_i in jax.tree.leaves(v)):
129127
# running inside JIT.
130128
return jax.pure_callback(
131-
functools.partial(jax.tree.map, jax_to_torch), v, v, vectorized=True
129+
functools.partial(jax.tree.map, jax_to_torch),
130+
v,
131+
v,
132+
vmap_method="legacy_vectorized",
132133
)
133134
return jax.tree.map(jax_to_torch, v)
134135

135136
def t2j(v: TorchPyTree) -> JaxPyTree:
136137
if any(isinstance(v_i, jax.core.Tracer) for v_i in jax.tree.leaves(v)):
137138
# running inside JIT.
138139
return jax.pure_callback(
139-
functools.partial(jax.tree.map, torch_to_jax), v, v, vectorized=True
140+
functools.partial(jax.tree.map, torch_to_jax),
141+
v,
142+
v,
143+
vmap_method="legacy_vectorized",
140144
)
141145
return jax.tree.map(torch_to_jax, v)
142146

@@ -153,8 +157,7 @@ def apply(params, *args, **kwargs):
153157
# Apply the model function to the input data.
154158
if example_output is None:
155159
if any(
156-
isinstance(v, jax.core.Tracer)
157-
for v in jax.tree.leaves((params, args, kwargs))
160+
isinstance(v, jax.core.Tracer) for v in jax.tree.leaves((params, args, kwargs))
158161
):
159162
raise RuntimeError(
160163
"You need to pass `example_output` in order to JIT the torch function!"
@@ -186,7 +189,7 @@ def pytorch_model_callback(params, *args, **kwargs):
186189
params,
187190
*args,
188191
**kwargs,
189-
vectorized=True,
192+
vmap_method="legacy_vectorized",
190193
)
191194
# Convert the output data from JAX to PyTorch representations
192195
out = t2j(out)
@@ -224,7 +227,7 @@ def _pytorch_model_backward_callback(params, grads, *args, **kwargs):
224227
grads,
225228
*args,
226229
**kwargs,
227-
vectorized=True,
230+
vmap_method="legacy_vectorized",
228231
)
229232
in_grads = t2j(in_grads)
230233
return in_grads

torch_jax_interop/to_torch.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import jax
1111
import torch
12-
from jax.dlpack import to_dlpack as jax_to_dlpack # type: ignore (not exported there?)
1312
from torch.utils import dlpack as torch_dlpack
1413

1514
from .types import Dataclass, DataclassType, K, NestedDict, NestedMapping
@@ -19,33 +18,27 @@
1918

2019

2120
@overload
22-
def jax_to_torch(value: jax.Array, /) -> torch.Tensor:
23-
...
21+
def jax_to_torch(value: jax.Array, /) -> torch.Tensor: ...
2422

2523

2624
@overload
27-
def jax_to_torch(value: jax.Device, /) -> torch.device:
28-
...
25+
def jax_to_torch(value: jax.Device, /) -> torch.device: ...
2926

3027

3128
@overload
32-
def jax_to_torch(value: tuple[jax.Array, ...], /) -> tuple[torch.Tensor, ...]:
33-
...
29+
def jax_to_torch(value: tuple[jax.Array, ...], /) -> tuple[torch.Tensor, ...]: ...
3430

3531

3632
@overload
37-
def jax_to_torch(value: list[jax.Array], /) -> list[torch.Tensor]:
38-
...
33+
def jax_to_torch(value: list[jax.Array], /) -> list[torch.Tensor]: ...
3934

4035

4136
@overload
42-
def jax_to_torch(value: NestedDict[K, jax.Array], /) -> NestedDict[K, torch.Tensor]:
43-
...
37+
def jax_to_torch(value: NestedDict[K, jax.Array], /) -> NestedDict[K, torch.Tensor]: ...
4438

4539

4640
@overload
47-
def jax_to_torch(value: Any, /) -> Any:
48-
...
41+
def jax_to_torch(value: Any, /) -> Any: ...
4942

5043

5144
def jax_to_torch(value: Any, /) -> Any:
@@ -88,7 +81,7 @@ def jax_to_torch_tensor(value: jax.Array, /) -> torch.Tensor:
8881
try:
8982
return torch_dlpack.from_dlpack(value)
9083
except Exception:
91-
return torch_dlpack.from_dlpack(jax_to_dlpack(value))
84+
return torch_dlpack.from_dlpack(value.__dlpack__())
9285

9386

9487
# Register it like this so the type hints are preserved on the functions (which are also called

0 commit comments

Comments
 (0)