Skip to content

Add interop with flax (Part 1) #9176

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torchax/dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
torch==2.6.0 ; sys_platform == 'darwin' # macOS
torch==2.6.0+cpu; sys_platform != 'darwin' # Non-macOS (CPU-only), like on TPU
yapf==0.40.2 # N.B.: keep in sync with `infra/ansible/config/pip.yaml`, `.github/workflows/lintercheck.yml`
flax==0.10.6
97 changes: 97 additions & 0 deletions torchax/test/test_flax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import unittest
import torch
import torchax
from flax import linen as nn
from torchax.flax import FlaxNNModule
from torchax.interop import jax_jit
import jax.numpy as jnp
import jax


class CNN(nn.Module):
"""A simple CNN model."""

@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
return x


class FlaxTest(unittest.TestCase):

def test_flax_simple(self):
flax_model = CNN()

inputs = jnp.ones((1, 28, 28, 1))
env = torchax.default_env()
state = flax_model.init(env.prng_key, inputs)
expected = flax_model.apply(state, inputs)

env = torchax.default_env()
nn_module = FlaxNNModule(env, flax_model, (inputs,), {})
res = nn_module.forward(inputs)

self.assertTrue(jnp.allclose(res.jax(), expected))

def test_flax_functional_call(self):
flax_model = CNN()

inputs = jnp.ones((1, 28, 28, 1))
env = torchax.default_env()
state = flax_model.init(env.prng_key, inputs)
expected = flax_model.apply(state, inputs)

env = torchax.default_env()
nn_module = FlaxNNModule(env, flax_model, (inputs,), {})

@jax_jit
def jitted(weights, args):
return torch.func.functional_call(nn_module, weights, args)

with env:
inputs_torch = torch.ones((1, 28, 28, 1), device='jax')
state_dict = nn_module.state_dict()
res = jitted(state_dict, inputs_torch)
self.assertTrue(jnp.allclose(res.jax(), expected))

def test_flax_module_nested(self):
env = torchax.default_env()

class Parent(torch.nn.Module):

def __init__(self):
super().__init__()
self.a = torch.nn.Linear(28, 28)
sample_cnn_inputs = torch.ones((1, 28, 28, 1), device='jax')
self.cnn = FlaxNNModule(env, CNN(), (sample_cnn_inputs,), {})

def forward(self, x):
y = self.a(x)
y = y.reshape((-1, 28, 28, 1))
res = self.cnn(y)
return res

with env:
nn_module = Parent()

@jax_jit
def jitted(weights, args):
return torch.func.functional_call(nn_module, weights, args)

inputs_torch = torch.ones((1, 28, 28), device='jax')
state_dict = nn_module.state_dict()
res = jitted(state_dict, inputs_torch)
print(res)


if __name__ == '__main__':
unittest.main()
39 changes: 39 additions & 0 deletions torchax/torchax/flax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Flax interop."""

import torch
import torchax as tx
import torchax.interop


class FlaxNNModule(torch.nn.Module):

def __init__(self, env, flax_module, sample_args, sample_kwargs=None):
super().__init__()
prng = env.prng_key
sample_kwargs = sample_kwargs or {}
parameter_dict = tx.interop.call_jax(flax_module.init, prng, *sample_args,
**sample_kwargs)

self._params = self._encode_nested_dict(parameter_dict)

self._flax_module = flax_module

def _encode_nested_dict(self, nested_dict):
child_module = torch.nn.Module()
for k, v in nested_dict.items():
if isinstance(v, dict):
child_module.add_module(k, self._encode_nested_dict(v))
else:
child_module.register_parameter(k, torch.nn.Parameter(v))
return child_module

def _decode_nested_dict(self, child_module):
result = dict(child_module.named_parameters(recurse=False))
for k, v in child_module.named_children():
result[k] = self._decode_nested_dict(v)
return result

def forward(self, *args, **kwargs):
nested_dict_params = self._decode_nested_dict(self._params)
return tx.interop.call_jax(self._flax_module.apply, nested_dict_params,
*args, **kwargs)