1
1
import unittest
2
+ import torch
2
3
import torchax
3
4
from flax import linen as nn
4
5
from torchax .flax import FlaxNNModule
6
+ from torchax .interop import jax_jit
5
7
import jax .numpy as jnp
6
8
import jax
7
9
@@ -12,6 +14,7 @@ def test_flax_simple(self):
12
14
13
15
class CNN (nn .Module ):
14
16
"""A simple CNN model."""
17
+
15
18
@nn .compact
16
19
def __call__ (self , x ):
17
20
x = nn .Conv (features = 32 , kernel_size = (3 , 3 ))(x )
@@ -34,13 +37,48 @@ def __call__(self, x):
34
37
expected = flax_model .apply (state , inputs )
35
38
36
39
env = torchax .default_env ()
37
- nn_module = FlaxNNModule (env , flax_model , (inputs , ), {} )
40
+ nn_module = FlaxNNModule (env , flax_model , (inputs ,), {})
38
41
res = nn_module .forward (inputs )
39
42
40
43
self .assertTrue (jnp .allclose (res .jax (), expected ))
41
44
42
-
45
+ def test_flax_functional_call (self ):
46
+
47
+ class CNN (nn .Module ):
48
+ """A simple CNN model."""
49
+
50
+ @nn .compact
51
+ def __call__ (self , x ):
52
+ x = nn .Conv (features = 32 , kernel_size = (3 , 3 ))(x )
53
+ x = nn .relu (x )
54
+ x = nn .avg_pool (x , window_shape = (2 , 2 ), strides = (2 , 2 ))
55
+ x = nn .Conv (features = 64 , kernel_size = (3 , 3 ))(x )
56
+ x = nn .relu (x )
57
+ x = nn .avg_pool (x , window_shape = (2 , 2 ), strides = (2 , 2 ))
58
+ x = x .reshape ((x .shape [0 ], - 1 )) # flatten
59
+ x = nn .Dense (features = 256 )(x )
60
+ x = nn .relu (x )
61
+ x = nn .Dense (features = 10 )(x )
62
+ return x
63
+
64
+ flax_model = CNN ()
65
+
66
+ inputs = jnp .ones ((1 , 28 , 28 , 1 ))
67
+ env = torchax .default_env ()
68
+ state = flax_model .init (env .prng_key , inputs )
69
+ expected = flax_model .apply (state , inputs )
70
+
71
+ env = torchax .default_env ()
72
+ nn_module = FlaxNNModule (env , flax_model , (inputs ,), {})
73
+
74
+ @jax_jit
75
+ def jitted (weights , args ):
76
+ return torch .func .functional_call (nn_module , weights , args )
77
+
78
+ state_dict = nn_module .state_dict ()
79
+ res = jitted (state_dict , inputs )
80
+ self .assertTrue (jnp .allclose (res .jax (), expected ))
81
+
82
+
43
83
if __name__ == '__main__' :
44
84
unittest .main ()
45
-
46
-
0 commit comments