Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Python application

on:
push:
branches:
- master
pull_request:
paths-ignore:
- 'docs/**'

permissions:
contents: read

# https://stackoverflow.com/a/72408109/6388696
# https://docs.github.com/en/actions/using-jobs/using-concurrency#example-using-concurrency-to-cancel-any-in-progress-job-or-run
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
cancel-in-progress: true

jobs:
linting:
name: Run linting/pre-commit checks
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.12'
- run: pip install pre-commit
- run: pre-commit --version
- run: pre-commit install
- run: pre-commit run --all-files

unit_tests:
needs: [linting]
runs-on: ${{ matrix.platform }}
strategy:
max-parallel: 4
matrix:
platform: [ubuntu-latest]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
- name: Install poetry
run: pipx install poetry
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: "poetry"
- name: Install dependencies
run: poetry install
- name: Run tests (CPU)
env:
JAX_PLATFORMS: cpu
run: poetry run pytest -v --cov=torch_jax_interop --cov-report=xml --cov-append

- name: Store coverage report as an artifact
uses: actions/upload-artifact@v4
with:
name: coverage-reports-unit-tests-${{ matrix.platform }}-${{ matrix.python-version }}
path: ./coverage.xml

gpu_tests:
needs: [unit_tests]
runs-on: self-hosted
strategy:
max-parallel: 1
matrix:
python-version: ['3.12']
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: "poetry"
- run: pip install poetry
- name: Install dependencies
run: poetry install

- name: Test with pytest
run: poetry run pytest -v --cov=torch_jax_interop --cov-report=xml --cov-append

- name: Store coverage report as an artifact
uses: actions/upload-artifact@v4
with:
name: coverage-reports-integration-tests-${{ matrix.python-version }}
path: ./coverage.xml

# https://about.codecov.io/blog/uploading-code-coverage-in-a-separate-job-on-github-actions/
upload-coverage-codecov:
needs: [unit_tests, gpu_tests]
runs-on: ubuntu-latest
name: Upload coverage reports to Codecov
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Download artifacts
uses: actions/download-artifact@v4
with:
pattern: coverage-reports-*
merge-multiple: false
# download all the artifacts in this directory (each .coverage.xml will be in a subdirectory)
# Next step if this doesn't work would be to give the coverage files a unique name and use merge-multiple: true
path: coverage_reports
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
directory: coverage_reports
fail_ci_if_error: true
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ addopts = ["--doctest-modules"]
[tool.pytest_env]
CUBLAS_WORKSPACE_CONFIG = ":4096:8"

[tool.ruff]
line-length = 99

[tool.docformatter]
wrap-summaries = 99
wrap-descriptions = 99
Expand Down
1 change: 1 addition & 0 deletions torch_jax_interop/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def torch_network(
torch_device: torch.device,
):
torch_network_type: type[torch.nn.Module] = request.param

with (
torch_device,
torch.random.fork_rng([torch_device] if torch_device.type == "cuda" else []),
Expand Down
48 changes: 34 additions & 14 deletions torch_jax_interop/to_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,33 +34,27 @@


@overload
def torch_to_jax(value: torch.Tensor, /) -> jax.Array:
...
def torch_to_jax(value: torch.Tensor, /) -> jax.Array: ...


@overload
def torch_to_jax(value: torch.device, /) -> jax.Device:
...
def torch_to_jax(value: torch.device, /) -> jax.Device: ...


@overload
def torch_to_jax(value: tuple[torch.Tensor, ...], /) -> tuple[jax.Array, ...]:
...
def torch_to_jax(value: tuple[torch.Tensor, ...], /) -> tuple[jax.Array, ...]: ...


@overload
def torch_to_jax(value: list[torch.Tensor], /) -> list[jax.Array]:
...
def torch_to_jax(value: list[torch.Tensor], /) -> list[jax.Array]: ...


@overload
def torch_to_jax(value: NestedDict[K, torch.Tensor], /) -> NestedDict[K, jax.Array]:
...
def torch_to_jax(value: NestedDict[K, torch.Tensor], /) -> NestedDict[K, jax.Array]: ...


@overload
def torch_to_jax(value: Any, /) -> Any:
...
def torch_to_jax(value: Any, /) -> Any: ...


def torch_to_jax(value: Any, /) -> Any:
Expand Down Expand Up @@ -102,7 +96,11 @@
def _to_from_dlpack(
v: torch.Tensor, ignore_deprecation_warning: bool = True
) -> jax.Array:
with warnings.catch_warnings() if ignore_deprecation_warning else contextlib.nullcontext():
with (
warnings.catch_warnings()
if ignore_deprecation_warning
else contextlib.nullcontext()
):
# Only way to get this to work for CPU seems to be with to/from dlpack... so we have to use this deprecated
# conversion method for now.
# todo: Should we let it though though?
Expand Down Expand Up @@ -144,7 +142,29 @@
return _direct_conversion(value.flatten()).reshape(value.shape)

try:
return _direct_conversion(value)
# Try using the "new" way to convert using from_dlpack directly
return jax_from_dlpack(

Check warning on line 146 in torch_jax_interop/to_jax.py

View check run for this annotation

Codecov / codecov/patch

torch_jax_interop/to_jax.py#L146

Added line #L146 was not covered by tests
value, device=torch_to_jax_device(value.device), copy=None
)
except AssertionError as err:
if not err.args[0].startswith("Unexpected XLA layout override"):
raise

Check warning on line 151 in torch_jax_interop/to_jax.py

View check run for this annotation

Codecov / codecov/patch

torch_jax_interop/to_jax.py#L149-L151

Added lines #L149 - L151 were not covered by tests
# Some "AssertionError: Unexpected XLA layout override"
# Try using the "old" way to convert using from_dlpack of a dlpack tensor.
try:
dlpack = torch_to_dlpack(value)
return jax_from_dlpack(dlpack, copy=False)
except jaxlib.xla_extension.XlaRuntimeError as err:
log_once(

Check warning on line 158 in torch_jax_interop/to_jax.py

View check run for this annotation

Codecov / codecov/patch

torch_jax_interop/to_jax.py#L154-L158

Added lines #L154 - L158 were not covered by tests
logger,
message=(
f"Unable to view tensor of shape {tuple(value.shape)} as a jax.Array in-place:\n"
f"'{err}'\n"
f"Tensors of this shape will be flattened and unflattened (which may or "
f"may not involve making a copy of the tensor's data)."
),
level=logging.WARNING,
)
except jaxlib.xla_extension.XlaRuntimeError as err:
log_once(
logger,
Expand Down
43 changes: 38 additions & 5 deletions torch_jax_interop/to_jax_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@

def make_functional(
module_with_state: Module[P, Out_cov], disable_autograd_tracking=False
) -> tuple[
Callable[Concatenate[Iterable[torch.Tensor], P], Out_cov], tuple[torch.Tensor, ...]
]:
) -> tuple[Callable[Concatenate[Iterable[torch.Tensor], P], Out_cov], tuple[torch.Tensor, ...]]:
"""Backward compatibility equivalent for `functorch.make_functional` in the new torch.func API.

Adapted from https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf as suggested by
Expand Down Expand Up @@ -117,6 +115,42 @@ def loss_function(params, x: jax.Array, y: jax.Array) -> jax.Array:
Returns
-------
the functional model and the model parameters (converted to jax arrays).

## Example

>>> import torch
>>> import jax
>>> torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> torch.manual_seed(0) # doctest:+ELLIPSIS
<torch._C.Generator object at ...>
>>> model = torch.nn.Linear(3, 2, device=torch_device)
>>> wrapped_model, params = torch_module_to_jax(model)
>>> def loss_function(params, x: jax.Array, y: jax.Array) -> jax.Array:
... y_pred = wrapped_model(params, x)
... return jax.numpy.mean((y - y_pred) ** 2)
>>> x = jax.random.uniform(key=jax.random.key(0), shape=(1, 3))
>>> y = jax.random.uniform(key=jax.random.key(1), shape=(1, 1))
>>> loss, grad = jax.value_and_grad(loss_function)(params, x, y)
>>> loss # doctest: +SKIP
Array(0.05772376, dtype=float32)
>>> grad # doctest: +SKIP
(Array([[-0.32541627, -0.10608128, -0.2133986 ],
[-0.04103044, -0.01337536, -0.02690658]], dtype=float32), Array([-0.33710665, -0.04250443], dtype=float32))

To use `jax.jit` on the model, you need to pass an example of an output so we can
tell the JIT compiler the output shapes and dtypes to expect:

>>> # here we reuse the same model as before:
>>> wrapped_model, params = torch_module_to_jax(model, example_output=torch.zeros(1, 2, device=torch_device))
>>> def loss_function(params, x: jax.Array, y: jax.Array) -> jax.Array:
... y_pred = wrapped_model(params, x)
... return jax.numpy.mean((y - y_pred) ** 2)
>>> loss, grad = jax.jit(jax.value_and_grad(loss_function))(params, x, y)
>>> loss # doctest: +SKIP
Array(0.05772376, dtype=float32)
>>> grad # doctest: +SKIP
(Array([[-0.32541627, -0.10608128, -0.2133986 ],
[-0.04103044, -0.01337536, -0.02690658]], dtype=float32), Array([-0.33710665, -0.04250443], dtype=float32))
"""

if example_output is not None:
Expand Down Expand Up @@ -153,8 +187,7 @@ def apply(params, *args, **kwargs):
# Apply the model function to the input data.
if example_output is None:
if any(
isinstance(v, jax.core.Tracer)
for v in jax.tree.leaves((params, args, kwargs))
isinstance(v, jax.core.Tracer) for v in jax.tree.leaves((params, args, kwargs))
):
raise RuntimeError(
"You need to pass `example_output` in order to JIT the torch function!"
Expand Down
1 change: 1 addition & 0 deletions torch_jax_interop/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def log_once(logger: logging.Logger, message: str, level: int):
logger.log(level=level, msg=message, stacklevel=2)


# NOTE: Done like this to preserve the original function signature.
log_once = functools.cache(log_once)


Expand Down
Loading