Commit 1abe53e: add nnx guide

guides/keras_nnx_guide.py (304 additions, 0 deletions)
"""
Title: How to use Keras with NNX backend
Author: [Divyashree Sreepathihalli](https://github.com/divyashreepathihalli)
Date created: 2025/08/07
Last modified: 2025/08/07
Description: How to use Keras with the NNX backend
Accelerator: CPU
"""

"""
# A Guide to the Keras & Flax NNX Integration

This tutorial will guide you through the integration of Keras with Flax's NNX
(Neural Networks JAX) module system, demonstrating how it significantly
enhances variable handling and opens up advanced training capabilities within
the JAX ecosystem. Whether you love the simplicity of model.fit() or the
fine-grained control of a custom training loop, this integration lets you have
the best of both worlds. Let's dive in!

# Why Keras and NNX Integration?

Keras is known for its user-friendliness and high-level API, making deep
learning accessible. JAX, on the other hand, provides high-performance
numerical computation, especially suited for machine learning research due to
its JIT compilation and automatic differentiation capabilities. NNX is Flax's
module system built on JAX, offering explicit state management on top of JAX's
powerful functional programming paradigm.

NNX is designed for simplicity. It is characterized by its Pythonic approach,
where modules are standard Python classes, promoting ease of use and
familiarity. NNX prioritizes user-friendliness and offers fine-grained control
over JAX transformations through typed Variable collections.

The integration of Keras with NNX allows you to leverage the best of both
worlds: the simplicity and modularity of Keras for model construction,
combined with the power and explicit control of NNX and JAX for variable
management and sophisticated training loops.

# Getting Started: Setting Up Your Environment
```
pip install -q -U keras
pip install -U -q flax==0.11.0
```
"""

"""# Enabling NNX Mode

To activate the integration, we must set two environment variables before
importing Keras. This tells Keras to use the JAX backend and enables NNX as an
opt-in feature.
"""

import os

os.environ["KERAS_BACKEND"] = "jax"
os.environ["KERAS_NNX_ENABLED"] = "true"
from flax import nnx
import keras
import jax.numpy as jnp

print("✅ Keras is now running on JAX with NNX enabled!")

"""# The Core Integration: Keras Variables in NNX

The heart of this integration is the new keras.Variable, which is designed to
be a native citizen of the Flax NNX ecosystem. This means you can mix Keras
and NNX components freely, and NNX's tracing and state management tools will
understand your Keras variables.

Let's prove it. We'll create an nnx.Module that contains both a standard
nnx.Linear layer and a keras.Variable.
"""

from keras import Variable as KerasVariable


class MyNnxModel(nnx.Module):
    def __init__(self, rngs):
        self.linear = nnx.Linear(2, 3, rngs=rngs)
        self.custom_variable = KerasVariable(jnp.ones((1, 3)))

    def __call__(self, x):
        return self.linear(x) + self.custom_variable


# Instantiate the model
model = MyNnxModel(rngs=nnx.Rngs(0))

# --- Verification ---
# 1. Is the KerasVariable traced by NNX?
print(f"✅ Traced: {hasattr(model.custom_variable, '_trace_state')}")

# 2. Does NNX see the KerasVariable in the model's state?
print("✅ Variables:", nnx.variables(model))

# 3. Can we access its value directly?
print("✅ Value:", model.custom_variable.value)
"""What this shows:
100+
The KerasVariable is successfully traced by NNX, just like any native
101+
nnx.Variable.
102+
The nnx.variables() function correctly identifies and lists our
103+
custom_variable as part of the model's state.
104+
This confirms that Keras state and NNX state can live together in perfect
105+
harmony.
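"""

"""
As an extra check (a minimal sketch, not part of the original guide), you can
split the module with nnx.split; the Keras variable travels with the rest of
the NNX state.
"""

# Split the module into its static definition and its state. The
# keras.Variable is carried in the returned state alongside the
# nnx.Linear parameters.
graphdef, state = nnx.split(model)
print("✅ Split state:", state)

"""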

# The Best of Both Worlds: Training Workflows

Now for the exciting part: training models. This integration unlocks two
powerful workflows.

## Workflow 1: The Classic Keras Experience (model.fit)
"""

import numpy as np

"""
1. Create a Keras Model
"""
model = keras.Sequential(
    [keras.layers.Dense(units=1, input_shape=(10,), name="my_dense_layer")]
)

print("--- Initial Model Weights ---")
initial_weights = model.get_weights()
print(f"Initial Kernel: {initial_weights[0].T}")  # .T for better display
print(f"Initial Bias: {initial_weights[1]}")

"""
2. Create Dummy Data
"""
X_dummy = np.random.rand(100, 10)
y_dummy = np.random.rand(100, 1)

"""
3. Compile and Fit
"""
model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=0.01),
    loss="mean_squared_error",
)

print("\n--- Training with model.fit() ---")
history = model.fit(X_dummy, y_dummy, epochs=5, batch_size=32, verbose=1)

"""
4. Verify a change
"""
print("\n--- Weights After Training ---")
updated_weights = model.get_weights()
print(f"Updated Kernel: {updated_weights[0].T}")
print(f"Updated Bias: {updated_weights[1]}")

# Verification
if not np.array_equal(initial_weights[1], updated_weights[1]):
    print("\n✅ SUCCESS: Model variables were updated during training.")
else:
    print("\n❌ FAILURE: Model variables were not updated.")
"""As you can see, your existing Keras code works out-of-the-box, giving you a
160+
high-level, productive experience powered by JAX and NNX under the hood.
161+
162+
## Workflow 2: The Power of NNX: Custom Training Loops
163+
164+
For maximum flexibility, you can treat any Keras layer or model as an
165+
nnx.Module and write your own training loop using libraries like Optax.
166+
This is perfect when you need fine-grained control over the gradient and
167+
update process.
168+
"""

import numpy as np
import optax

# Dummy regression data: y is a noisy linear function of x.
X = np.linspace(-jnp.pi, jnp.pi, 100)[:, None]
Y = 0.8 * X + 0.1 + np.random.normal(0, 0.1, size=X.shape)


class MySimpleKerasModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Define the layers of your model
        self.dense_layer = keras.layers.Dense(1)

    def call(self, inputs):
        # Define the forward pass. The 'inputs' argument receives the input
        # tensor when the model is called.
        return self.dense_layer(inputs)


model = MySimpleKerasModel()
model(X)  # Call once on real data to build the model's variables.

# Optax optimizer and an NNX filter that selects only the trainable
# Keras variables for optimization and differentiation.
tx = optax.sgd(1e-3)
trainable_var = nnx.All(keras.Variable, lambda path, x: x.trainable)
optimizer = nnx.Optimizer(model, tx, wrt=trainable_var)


@nnx.jit
def train_step(model, optimizer, batch):
    x, y = batch

    def loss_fn(model_):
        y_pred = model_(x)
        return jnp.mean((y - y_pred) ** 2)

    # Differentiate the loss with respect to the trainable variables only.
    diff_state = nnx.DiffState(0, trainable_var)
    grads = nnx.grad(loss_fn, argnums=diff_state)(model)
    optimizer.update(model, grads)


@nnx.jit
def test_step(model, batch):
    x, y = batch
    y_pred = model(x)
    loss = jnp.mean((y - y_pred) ** 2)
    return {"loss": loss}


def dataset(batch_size=10):
    while True:
        idx = np.random.choice(len(X), size=batch_size)
        yield X[idx], Y[idx]


for step, batch in enumerate(dataset()):
    train_step(model, optimizer, batch)

    if step % 100 == 0:
        logs = test_step(model, (X, Y))
        print(f"step: {step}, loss: {logs['loss']}")

    if step >= 500:
        break

"""This example shows how a Keras model object is seamlessly passed to
nnx.Optimizer and differentiated by nnx.grad. This composition allows you
to integrate Keras components into sophisticated JAX/NNX workflows. The same
approach works with sequential, functional, and subclassed Keras models, or
even with individual layers.
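"""

"""
For instance, here is a minimal sketch (not part of the original guide) of the
same pattern applied to a single keras.layers.Dense layer instead of a full
model, reusing the data, variable filter, and step logic defined above.
"""

# A bare Keras layer can be optimized the same way as a Keras model.
layer = keras.layers.Dense(1)
layer(X)  # Build the layer's variables with a first call.
layer_optimizer = nnx.Optimizer(layer, optax.sgd(1e-3), wrt=trainable_var)


@nnx.jit
def layer_train_step(layer, optimizer, batch):
    x, y = batch

    def loss_fn(layer_):
        return jnp.mean((y - layer_(x)) ** 2)

    grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, trainable_var))(layer)
    optimizer.update(layer, grads)


layer_train_step(layer, layer_optimizer, (X, Y))
print("✅ Trained a bare Keras layer with a custom NNX step.")

"""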

# Saving and Loading

Your investment in the Keras ecosystem is safe. Standard features like model
serialization work exactly as you'd expect.
"""

# Create a simple model
model = keras.Sequential([keras.layers.Dense(units=1, input_shape=(10,))])
dummy_input = np.random.rand(1, 10)

# Test call
print("Original model output:", model(dummy_input))

# Save and load
model.save("my_nnx_model.keras")
restored_model = keras.models.load_model("my_nnx_model.keras")

print("Restored model output:", restored_model(dummy_input))

# Verification
np.testing.assert_allclose(model(dummy_input), restored_model(dummy_input))
print("\n✅ SUCCESS: Restored model output matches original model output.")

"""# Real-World Application: Training Gemma

Before trying out this KerasHub model, please make sure you have set up your
Kaggle credentials in Colab secrets. The Colab pulls in `KAGGLE_KEY` and
`KAGGLE_USERNAME` to authenticate and download the models.
"""

import keras_hub

# Set a float16 policy for memory efficiency
keras.config.set_dtype_policy("float16")

# Load Gemma from KerasHub
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset(
    "gemma_1.1_instruct_2b_en"
)

# --- 1. Inference / Generation ---
print("--- Gemma Generation ---")
output = gemma_lm.generate("Keras is a", max_length=30)
print(output)

# --- 2. Fine-tuning ---
print("\n--- Gemma Fine-tuning ---")
# Dummy data for demonstration
features = np.array(["The quick brown fox jumped.", "I forgot my homework."])
# The model.fit() API works seamlessly!
gemma_lm.fit(x=features, batch_size=2)
print("\n✅ Gemma fine-tuning step completed successfully!")

"""# Conclusion

The Keras-NNX integration represents a significant step forward, offering a
unified framework for both rapid prototyping and high-performance,
customizable research. You can now:

- Use familiar Keras APIs (Sequential, Model, fit, save) on a JAX backend.
- Integrate Keras layers and models directly into Flax NNX modules and
  training loops.
- Integrate Keras code and models with NNX ecosystem tools such as Qwix and
  Tunix.
- Leverage the entire JAX ecosystem (e.g., nnx.jit, optax) with your Keras
  models.
- Seamlessly work with large models from KerasHub.
"""
