diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..861b598 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,115 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: Python application + +on: + push: + branches: + - master + pull_request: + paths-ignore: + - 'docs/**' + +permissions: + contents: read + +# https://stackoverflow.com/a/72408109/6388696 +# https://docs.github.com/en/actions/using-jobs/using-concurrency#example-using-concurrency-to-cancel-any-in-progress-job-or-run +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + linting: + name: Run linting/pre-commit checks + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.12' + - run: pip install pre-commit + - run: pre-commit --version + - run: pre-commit install + - run: pre-commit run --all-files + + unit_tests: + needs: [linting] + runs-on: ${{ matrix.platform }} + strategy: + max-parallel: 4 + matrix: + platform: [ubuntu-latest] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + steps: + - uses: actions/checkout@v4 + - name: Install poetry + run: pipx install poetry + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: "poetry" + - name: Install dependencies + run: poetry install + - name: Run tests (CPU) + env: + JAX_PLATFORMS: cpu + run: poetry run pytest -v --cov=torch_jax_interop --cov-report=xml --cov-append + + - name: Store coverage report as an artifact + uses: actions/upload-artifact@v4 + with: + name: coverage-reports-unit-tests-${{ matrix.platform }}-${{ matrix.python-version }} + path: 
./coverage.xml + + gpu_tests: + needs: [unit_tests] + runs-on: self-hosted + strategy: + max-parallel: 1 + matrix: + python-version: ['3.12'] + steps: + - uses: actions/checkout@v4 + - run: pip install poetry # poetry must be installed before setup-python can resolve its cache + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: "poetry" + - name: Install dependencies + run: poetry install + + - name: Test with pytest + run: poetry run pytest -v --cov=torch_jax_interop --cov-report=xml --cov-append + + - name: Store coverage report as an artifact + uses: actions/upload-artifact@v4 + with: + name: coverage-reports-integration-tests-${{ matrix.python-version }} + path: ./coverage.xml + + # https://about.codecov.io/blog/uploading-code-coverage-in-a-separate-job-on-github-actions/ + upload-coverage-codecov: + needs: [unit_tests, gpu_tests] + runs-on: ubuntu-latest + name: Upload coverage reports to Codecov + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Download artifacts + uses: actions/download-artifact@v4 + with: + pattern: coverage-reports-* + merge-multiple: false + # download all the artifacts in this directory (each .coverage.xml will be in a subdirectory) + # Next step if this doesn't work would be to give the coverage files a unique name and use merge-multiple: true + path: coverage_reports + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} + directory: coverage_reports + fail_ci_if_error: true diff --git a/pyproject.toml b/pyproject.toml index db14b7f..741edba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,9 +36,7 @@ addopts = ["--doctest-modules"] [tool.pytest_env] CUBLAS_WORKSPACE_CONFIG = ":4096:8" -[tool.ruff] line-length = 99 - [tool.docformatter] wrap-summaries = 99 wrap-descriptions = 99 diff --git a/torch_jax_interop/conftest.py b/torch_jax_interop/conftest.py index c49f31d..0e0ea69 100644 --- a/torch_jax_interop/conftest.py 
+++ b/torch_jax_interop/conftest.py @@ -192,6 +192,7 @@ def torch_network( torch_device: torch.device, ): torch_network_type: type[torch.nn.Module] = request.param + with ( torch_device, torch.random.fork_rng([torch_device] if torch_device.type == "cuda" else []), diff --git a/torch_jax_interop/to_jax.py b/torch_jax_interop/to_jax.py index 908e0ff..fe41724 100644 --- a/torch_jax_interop/to_jax.py +++ b/torch_jax_interop/to_jax.py @@ -34,33 +34,27 @@ @overload -def torch_to_jax(value: torch.Tensor, /) -> jax.Array: - ... +def torch_to_jax(value: torch.Tensor, /) -> jax.Array: ... @overload -def torch_to_jax(value: torch.device, /) -> jax.Device: - ... +def torch_to_jax(value: torch.device, /) -> jax.Device: ... @overload -def torch_to_jax(value: tuple[torch.Tensor, ...], /) -> tuple[jax.Array, ...]: - ... +def torch_to_jax(value: tuple[torch.Tensor, ...], /) -> tuple[jax.Array, ...]: ... @overload -def torch_to_jax(value: list[torch.Tensor], /) -> list[jax.Array]: - ... +def torch_to_jax(value: list[torch.Tensor], /) -> list[jax.Array]: ... @overload -def torch_to_jax(value: NestedDict[K, torch.Tensor], /) -> NestedDict[K, jax.Array]: - ... +def torch_to_jax(value: NestedDict[K, torch.Tensor], /) -> NestedDict[K, jax.Array]: ... @overload -def torch_to_jax(value: Any, /) -> Any: - ... +def torch_to_jax(value: Any, /) -> Any: ... def torch_to_jax(value: Any, /) -> Any: @@ -102,7 +96,11 @@ def _direct_conversion(v: torch.Tensor) -> jax.Array: def _to_from_dlpack( v: torch.Tensor, ignore_deprecation_warning: bool = True ) -> jax.Array: - with warnings.catch_warnings() if ignore_deprecation_warning else contextlib.nullcontext(): + with ( + warnings.catch_warnings() + if ignore_deprecation_warning + else contextlib.nullcontext() + ): # Only way to get this to work for CPU seems to be with to/from dlpack... so we have to use this deprecated # conversion method for now. # todo: Should we let it though though? 
@@ -144,7 +142,29 @@ def torch_to_jax_tensor(value: torch.Tensor) -> jax.Array: return _direct_conversion(value.flatten()).reshape(value.shape) try: - return _direct_conversion(value) + # Try using the "new" way to convert using from_dlpack directly + return jax_from_dlpack( + value, device=torch_to_jax_device(value.device), copy=None + ) + except AssertionError as err: + if not err.args[0].startswith("Unexpected XLA layout override"): + raise + # Some "AssertionError: Unexpected XLA layout override" + # Try using the "old" way to convert using from_dlpack of a dlpack tensor. + try: + dlpack = torch_to_dlpack(value) + return jax_from_dlpack(dlpack, copy=False) + except jaxlib.xla_extension.XlaRuntimeError as err: + log_once( + logger, + message=( + f"Unable to view tensor of shape {tuple(value.shape)} as a jax.Array in-place:\n" + f"'{err}'\n" + f"Tensors of this shape will be flattened and unflattened (which may or " + f"may not involve making a copy of the tensor's data)." + ), + level=logging.WARNING, + ) except jaxlib.xla_extension.XlaRuntimeError as err: log_once( logger, diff --git a/torch_jax_interop/to_jax_module.py b/torch_jax_interop/to_jax_module.py index b2b89d2..8ca0d65 100644 --- a/torch_jax_interop/to_jax_module.py +++ b/torch_jax_interop/to_jax_module.py @@ -26,9 +26,7 @@ def make_functional( module_with_state: Module[P, Out_cov], disable_autograd_tracking=False -) -> tuple[ - Callable[Concatenate[Iterable[torch.Tensor], P], Out_cov], tuple[torch.Tensor, ...] -]: +) -> tuple[Callable[Concatenate[Iterable[torch.Tensor], P], Out_cov], tuple[torch.Tensor, ...]]: """Backward compatibility equivalent for `functorch.make_functional` in the new torch.func API. Adapted from https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf as suggested by @@ -117,6 +115,42 @@ def loss_function(params, x: jax.Array, y: jax.Array) -> jax.Array: Returns ------- the functional model and the model parameters (converted to jax arrays). 
+ + ## Example + + >>> import torch + >>> import jax + >>> torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + >>> torch.manual_seed(0) # doctest:+ELLIPSIS + <torch._C.Generator object at ...> + >>> model = torch.nn.Linear(3, 2, device=torch_device) + >>> wrapped_model, params = torch_module_to_jax(model) + >>> def loss_function(params, x: jax.Array, y: jax.Array) -> jax.Array: + ... y_pred = wrapped_model(params, x) + ... return jax.numpy.mean((y - y_pred) ** 2) + >>> x = jax.random.uniform(key=jax.random.key(0), shape=(1, 3)) + >>> y = jax.random.uniform(key=jax.random.key(1), shape=(1, 1)) + >>> loss, grad = jax.value_and_grad(loss_function)(params, x, y) + >>> loss # doctest: +SKIP + Array(0.05772376, dtype=float32) + >>> grad # doctest: +SKIP + (Array([[-0.32541627, -0.10608128, -0.2133986 ], + [-0.04103044, -0.01337536, -0.02690658]], dtype=float32), Array([-0.33710665, -0.04250443], dtype=float32)) + + To use `jax.jit` on the model, you need to pass an example of an output so we can + tell the JIT compiler the output shapes and dtypes to expect: + + >>> # here we reuse the same model as before: + >>> wrapped_model, params = torch_module_to_jax(model, example_output=torch.zeros(1, 2, device=torch_device)) + >>> def loss_function(params, x: jax.Array, y: jax.Array) -> jax.Array: + ... y_pred = wrapped_model(params, x) + ... return jax.numpy.mean((y - y_pred) ** 2) + >>> loss, grad = jax.jit(jax.value_and_grad(loss_function))(params, x, y) + >>> loss # doctest: +SKIP + Array(0.05772376, dtype=float32) + >>> grad # doctest: +SKIP + (Array([[-0.32541627, -0.10608128, -0.2133986 ], + [-0.04103044, -0.01337536, -0.02690658]], dtype=float32), Array([-0.33710665, -0.04250443], dtype=float32)) """ if example_output is not None: @@ -153,8 +187,7 @@ def apply(params, *args, **kwargs): # Apply the model function to the input data. 
if example_output is None: if any( - isinstance(v, jax.core.Tracer) - for v in jax.tree.leaves((params, args, kwargs)) + isinstance(v, jax.core.Tracer) for v in jax.tree.leaves((params, args, kwargs)) ): raise RuntimeError( "You need to pass `example_output` in order to JIT the torch function!" diff --git a/torch_jax_interop/utils.py b/torch_jax_interop/utils.py index 32a71de..206ed40 100644 --- a/torch_jax_interop/utils.py +++ b/torch_jax_interop/utils.py @@ -11,6 +11,7 @@ def log_once(logger: logging.Logger, message: str, level: int): logger.log(level=level, msg=message, stacklevel=2) +# NOTE: Done like this to preserve the original function signature. log_once = functools.cache(log_once)