Skip to content

Commit 073c52c

Browse files
committed
flx interop
1 parent a29a2a5 commit 073c52c

File tree

2 files changed

+70
-10
lines changed

2 files changed

+70
-10
lines changed

torchax/test/test_flax.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import unittest
2+
import torch
23
import torchax
34
from flax import linen as nn
45
from torchax.flax import FlaxNNModule
6+
from torchax.interop import jax_jit
57
import jax.numpy as jnp
68
import jax
79

@@ -12,6 +14,7 @@ def test_flax_simple(self):
1214

1315
class CNN(nn.Module):
1416
"""A simple CNN model."""
17+
1518
@nn.compact
1619
def __call__(self, x):
1720
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
@@ -34,13 +37,50 @@ def __call__(self, x):
3437
expected = flax_model.apply(state, inputs)
3538

3639
env = torchax.default_env()
37-
nn_module = FlaxNNModule(env, flax_model, (inputs, ), {} )
40+
nn_module = FlaxNNModule(env, flax_model, (inputs,), {})
3841
res = nn_module.forward(inputs)
3942

4043
self.assertTrue(jnp.allclose(res.jax(), expected))
4144

42-
45+
def test_flax_functional_call(self):
  """Check that a wrapped Flax model works with functional_call under jit.

  The output of the jitted `torch.func.functional_call` on the wrapper's
  `state_dict()` must match the output of the plain Flax `apply`.
  """

  class CNN(nn.Module):
    """Two conv/relu/avg-pool stages followed by two dense layers."""

    @nn.compact
    def __call__(self, x):
      # Submodules are created in the same order as the unrolled original:
      # Conv(32), Conv(64), Dense(256), Dense(10).
      for channels in (32, 64):
        x = nn.Conv(features=channels, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
      x = x.reshape((x.shape[0], -1))  # flatten
      x = nn.relu(nn.Dense(features=256)(x))
      return nn.Dense(features=10)(x)

  model = CNN()
  sample = jnp.ones((1, 28, 28, 1))

  # Reference result computed purely with Flax.
  env = torchax.default_env()
  variables = model.init(env.prng_key, sample)
  expected = model.apply(variables, sample)

  # Wrap the same Flax model as a torch module.
  env = torchax.default_env()
  wrapped = FlaxNNModule(env, model, (sample,), {})

  def call_with_weights(weights, args):
    return torch.func.functional_call(wrapped, weights, args)

  jitted = jax_jit(call_with_weights)

  with env:
    torch_sample = torch.ones((1, 28, 28, 1), device='jax')
    weights = wrapped.state_dict()
    actual = jitted(weights, torch_sample)
    self.assertTrue(jnp.allclose(actual.jax(), expected))
83+
84+
4385
# Run the unittest runner when this file is executed as a script.
if __name__ == '__main__':
  unittest.main()
45-
46-

torchax/torchax/flax.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,33 @@
77

88
class FlaxNNModule(torch.nn.Module):
  """A `torch.nn.Module` that wraps a Flax module.

  The variables returned by `flax_module.init` are stored as a tree of nested
  `torch.nn.Module` objects whose leaves are `torch.nn.Parameter`s, so they are
  exposed through the regular torch module APIs (`state_dict()`,
  `named_parameters()`, `torch.func.functional_call`, ...).
  """

  def __init__(self, env, flax_module, sample_args, sample_kwargs=None):
    """Initialize the Flax module and register its variables.

    Args:
      env: object providing `prng_key`, the key passed to `flax_module.init`.
      flax_module: the Flax module to wrap; its `init` and `apply` are used.
      sample_args: positional sample inputs forwarded to `flax_module.init`.
      sample_kwargs: optional keyword sample inputs forwarded to
        `flax_module.init`; `None` is treated as no keyword arguments.
    """
    super().__init__()
    prng = env.prng_key
    sample_kwargs = sample_kwargs or {}
    parameter_dict = tx.interop.call_jax(flax_module.init, prng, *sample_args,
                                         **sample_kwargs)

    self._params = self._encode_nested_dict(parameter_dict)

    self._flax_module = flax_module

  def _encode_nested_dict(self, nested_dict):
    """Convert a nested mapping of tensors into a tree of torch modules.

    Each nested mapping becomes a child `torch.nn.Module`; every other value
    is wrapped in `torch.nn.Parameter` and registered under its key.

    Args:
      nested_dict: mapping whose values are either mappings or tensors.

    Returns:
      A `torch.nn.Module` mirroring the structure of `nested_dict`.
    """
    # Local import: the file's import header is not modified by this block.
    from collections.abc import Mapping

    child_module = torch.nn.Module()
    for k, v in nested_dict.items():
      # Accept any Mapping, not only `dict`, so immutable/frozen mapping
      # containers are recursed into instead of being passed to Parameter.
      if isinstance(v, Mapping):
        child_module.add_module(k, self._encode_nested_dict(v))
      else:
        child_module.register_parameter(k, torch.nn.Parameter(v))
    return child_module

  def _decode_nested_dict(self, child_module):
    """Rebuild the nested dict encoded by `_encode_nested_dict`.

    Args:
      child_module: a module tree produced by `_encode_nested_dict`.

    Returns:
      A plain nested `dict` keyed like the original mapping, with the
      registered parameters as leaves.
    """
    # Direct parameters first; children are handled recursively below.
    result = dict(child_module.named_parameters(recurse=False))
    for k, v in child_module.named_children():
      result[k] = self._decode_nested_dict(v)
    return result

  def forward(self, *args, **kwargs):
    """Apply the wrapped Flax module using the currently registered parameters.

    Args:
      *args: positional inputs forwarded to `flax_module.apply`.
      **kwargs: keyword inputs forwarded to `flax_module.apply`.

    Returns:
      Whatever `tx.interop.call_jax` returns for the `apply` call.
    """
    # Decode on every call so parameters swapped in from outside (e.g. by
    # torch.func.functional_call) are the ones passed to `apply`.
    nested_dict_params = self._decode_nested_dict(self._params)
    return tx.interop.call_jax(self._flax_module.apply, nested_dict_params,
                               *args, **kwargs)

0 commit comments

Comments
 (0)