| 1 | +""" |
| 2 | +Title: How to use Keras with NNX backend |
| 3 | +Author: [Divyashree Sreepathihalli](https://github.com/divyashreepathihalli) |
| 4 | +Date created: 2025/08/07 |
| 5 | +Last modified: 2025/08/07 |
| 6 | +Description: How to use Keras with NNX backend |
| 7 | +Accelerator: CPU |
| 8 | +""" |
| 9 | + |
| 10 | +""" |
| 11 | +
|
| 12 | +# A Guide to the Keras & Flax NNX Integration |
| 13 | +
|
| 14 | +This tutorial will guide you through the integration of Keras with Flax's NNX |
| 15 | +(Neural Networks JAX) module system, demonstrating how it significantly |
| 16 | +enhances variable handling and opens up advanced training capabilities within |
| 17 | +the JAX ecosystem. Whether you love the simplicity of model.fit() or the |
| 18 | +fine-grained control of a custom training loop, this integration lets you have |
| 19 | +the best of both worlds. Let's dive in! |
| 20 | +
|
| 21 | +# Why Keras and NNX Integration? |
| 22 | +
|
| 23 | +Keras is known for its user-friendliness and high-level API, making deep |
| 24 | +learning accessible. JAX, on the other hand, provides high-performance |
| 25 | +numerical computation, especially suited for machine learning research due to |
| 26 | +its JIT compilation and automatic differentiation capabilities. NNX is Flax's |
| 27 | +functional module system built on JAX, offering explicit state management and |
| 28 | +powerful functional programming paradigms |
| 29 | +
|
| 30 | +NNX is designed for simplicity. It is characterized by its Pythonic approach, |
| 31 | +where modules are standard Python classes, promoting ease of use and |
| 32 | +familiarity. NNX prioritizes user-friendliness and offers fine-grained control |
| 33 | +over JAX transformations through typed Variable collections. |
| 34 | +
|
| 35 | +The integration of Keras with NNX allows you to leverage the best of both |
| 36 | +worlds: the simplicity and modularity of Keras for model construction, |
| 37 | +combined with the power and explicit control of NNX and JAX for variable |
| 38 | +management and sophisticated training loops. |
| 39 | +
|
| 40 | +# Getting Started: Setting Up Your Environment |
| 41 | +``` |
| 42 | +pip install -q -U keras |
| 43 | +pip install -U -q flax==0.11.0 |
| 44 | +``` |
| 45 | +""" |
| 46 | + |
| 47 | +"""# Enabling NNX Mode |
| 48 | +
|
| 49 | +To activate the integration, we must set two environment variables before |
| 50 | +importing Keras. This tells Keras to use the JAX backend and switch to NNX as |
| 51 | +an opt in feature. |
| 52 | +""" |
| 53 | + |
import os

os.environ["KERAS_BACKEND"] = "jax"
os.environ["KERAS_NNX_ENABLED"] = "true"

from flax import nnx
import keras
import jax.numpy as jnp

print("✅ Keras is now running on JAX with NNX enabled!")

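# Optional sanity check (illustrative; not part of the original steps):
# confirm that the JAX backend is the one actually in use.
print("Backend in use:", keras.config.backend())
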
| 64 | +"""# The Core Integration: Keras Variables in NNX |
| 65 | +
|
| 66 | +The heart of this integration is the new keras.Variable, which is designed to |
| 67 | +be a native citizen of the Flax NNX ecosystem. This means you can mix Keras |
| 68 | +and NNX components freely, and NNX's tracing and state management tools will |
| 69 | +understand your Keras variables. |
| 70 | +Let's prove it. We'll create an nnx.Module that contains both a standard |
| 71 | +nnx.Linear layer and a keras.Variable. |
| 72 | +""" |
| 73 | + |
from keras import Variable as KerasVariable


class MyNnxModel(nnx.Module):
    def __init__(self, rngs):
        self.linear = nnx.Linear(2, 3, rngs=rngs)
        self.custom_variable = KerasVariable(jnp.ones((1, 3)))

    def __call__(self, x):
        return self.linear(x) + self.custom_variable


# Instantiate the model
model = MyNnxModel(rngs=nnx.Rngs(0))

# --- Verification ---

# 1. Is the KerasVariable traced by NNX?
print(f"✅ Traced: {hasattr(model.custom_variable, '_trace_state')}")

# 2. Does NNX see the KerasVariable in the model's state?
print("✅ Variables:", nnx.variables(model))

# 3. Can we access its value directly?
print("✅ Value:", model.custom_variable.value)

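# 4. Optional extra check (a sketch beyond the original verification steps):
# round-trip the mixed module through NNX's standard split/merge API. The
# Keras variable travels with the rest of the NNX state.
graphdef, state = nnx.split(model)
restored = nnx.merge(graphdef, state)
print("✅ Round-trip value:", restored.custom_variable.value)
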
| 99 | +"""What this shows: |
| 100 | +The KerasVariable is successfully traced by NNX, just like any native |
| 101 | +nnx.Variable. |
| 102 | +The nnx.variables() function correctly identifies and lists our |
| 103 | +custom_variable as part of the model's state. |
| 104 | +This confirms that Keras state and NNX state can live together in perfect |
| 105 | +harmony. |
| 106 | +
|
| 107 | +# The Best of Both Worlds: Training Workflows |
| 108 | +
|
| 109 | +Now for the exciting part: training models. This integration unlocks two |
| 110 | +powerful workflows. |
| 111 | +
|
| 112 | +## Workflow 1: The Classic Keras Experience (model.fit) |
| 113 | +""" |
| 114 | + |
import numpy as np

| 117 | +""" |
| 118 | +1. Create a Keras Model |
| 119 | +""" |
model = keras.Sequential(
    [keras.layers.Dense(units=1, input_shape=(10,), name="my_dense_layer")]
)

print("--- Initial Model Weights ---")
initial_weights = model.get_weights()
print(f"Initial Kernel: {initial_weights[0].T}")  # .T for better display
print(f"Initial Bias: {initial_weights[1]}")

| 129 | +""" |
| 130 | +2. Create Dummy Data |
| 131 | +""" |
X_dummy = np.random.rand(100, 10)
y_dummy = np.random.rand(100, 1)

| 134 | +""" |
| 135 | +3. Compile and Fit |
| 136 | +""" |
model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=0.01),
    loss="mean_squared_error",
)

print("\n--- Training with model.fit() ---")
history = model.fit(X_dummy, y_dummy, epochs=5, batch_size=32, verbose=1)

| 145 | +""" |
| 146 | +4. Verify a change |
| 147 | +""" |
print("\n--- Weights After Training ---")
updated_weights = model.get_weights()
print(f"Updated Kernel: {updated_weights[0].T}")
print(f"Updated Bias: {updated_weights[1]}")

# Verification
if not np.array_equal(initial_weights[1], updated_weights[1]):
    print("\n✅ SUCCESS: Model variables were updated during training.")
else:
    print("\n❌ FAILURE: Model variables were not updated.")

| 159 | +"""As you can see, your existing Keras code works out-of-the-box, giving you a |
| 160 | +high-level, productive experience powered by JAX and NNX under the hood. |
| 161 | +
|
| 162 | +## Workflow 2: The Power of NNX: Custom Training Loops |
| 163 | +
|
| 164 | +For maximum flexibility, you can treat any Keras layer or model as an |
| 165 | +nnx.Module and write your own training loop using libraries like Optax. |
| 166 | +This is perfect when you need fine-grained control over the gradient and |
| 167 | +update process. |
| 168 | +""" |
| 169 | + |
import optax

# A simple noisy linear dataset for the custom loop.
X = np.linspace(-jnp.pi, jnp.pi, 100)[:, None]
Y = 0.8 * X + 0.1 + np.random.normal(0, 0.1, size=X.shape)

class MySimpleKerasModel(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Define the layers of your model
        self.dense_layer = keras.layers.Dense(1)

    def call(self, inputs):
        # Define the forward pass. The `inputs` argument receives the input
        # tensor when the model is called.
        return self.dense_layer(inputs)


# Calling the model once builds it and creates its variables.
model = MySimpleKerasModel()
model(X)

tx = optax.sgd(1e-3)
# Filter that selects only the trainable Keras variables.
trainable_var = nnx.All(keras.Variable, lambda path, x: x.trainable)
optimizer = nnx.Optimizer(model, tx, wrt=trainable_var)

@nnx.jit
def train_step(model, optimizer, batch):
    x, y = batch

    def loss_fn(model_):
        y_pred = model_(x)
        return jnp.mean((y - y_pred) ** 2)

    # Differentiate the loss only with respect to the trainable variables.
    diff_state = nnx.DiffState(0, trainable_var)
    grads = nnx.grad(loss_fn, argnums=diff_state)(model)
    optimizer.update(model, grads)

@nnx.jit
def test_step(model, batch):
    x, y = batch
    y_pred = model(x)
    loss = jnp.mean((y - y_pred) ** 2)
    return {"loss": loss}

def dataset(batch_size=10):
    while True:
        idx = np.random.choice(len(X), size=batch_size)
        yield X[idx], Y[idx]

for step, batch in enumerate(dataset()):
    train_step(model, optimizer, batch)

    if step % 100 == 0:
        logs = test_step(model, (X, Y))
        print(f"step: {step}, loss: {logs['loss']}")

    if step >= 500:
        break

| 235 | +"""This example shows how a keras model object is seamlessly passed to |
| 236 | +nnx.Optimizer and differentiated by nnx.grad. This composition allows you |
| 237 | +to integrate Keras components into sophisticated JAX/NNX workflows. This |
| 238 | +approach also works perfectly with sequential, functional, subclassed keras |
| 239 | +models are even just layers. |
| 240 | +
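
"""
To make that last point concrete, here is a minimal sketch (an illustration
beyond the original steps; the names are ours) that runs a single update step
on a bare `keras.layers.Dense` layer with the same Optax/NNX recipe.
"""

# Sketch: a lone Keras layer treated as an NNX module. Reuses the
# `trainable_var` filter defined above; assumes the integration registers a
# layer's variables with NNX exactly as it does for full models.
layer = keras.layers.Dense(1)
layer(X)  # Build the layer so its variables exist.
layer_optimizer = nnx.Optimizer(layer, optax.sgd(1e-3), wrt=trainable_var)


@nnx.jit
def layer_train_step(layer, optimizer, x, y):
    def loss_fn(layer_):
        return jnp.mean((y - layer_(x)) ** 2)

    grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, trainable_var))(layer)
    optimizer.update(layer, grads)


layer_train_step(layer, layer_optimizer, X, Y)
print("Single-layer loss after one step:", jnp.mean((Y - layer(X)) ** 2))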

"""
# Saving and Loading

Your investment in the Keras ecosystem is safe. Standard features like model
serialization work exactly as you'd expect.
"""

# Create a simple model
model = keras.Sequential([keras.layers.Dense(units=1, input_shape=(10,))])
dummy_input = np.random.rand(1, 10)

# Test call
print("Original model output:", model(dummy_input))

# Save and load
model.save("my_nnx_model.keras")
restored_model = keras.models.load_model("my_nnx_model.keras")

print("Restored model output:", restored_model(dummy_input))

# Verification
np.testing.assert_allclose(model(dummy_input), restored_model(dummy_input))
print("\n✅ SUCCESS: Restored model output matches original model output.")

| 264 | +"""# Real-World Application: Training Gemma |
| 265 | +
|
| 266 | +Before trying out this KerasHub model, please make sure you have set up your |
| 267 | +Kaggle credentials in colab secrets. The colab pulls in `KAGGLE_KEY` and |
| 268 | +`KAGGLE_USERNAME` to authenticate and download the models. |
| 269 | +""" |
| 270 | + |
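# Illustrative sketch (assumes you are running in Google Colab and have saved
# the two secrets named above): surface the Kaggle credentials as environment
# variables before loading the model. Uncomment to use.
# from google.colab import userdata
#
# os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
# os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")
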
import keras_hub

# Set a float16 policy for memory efficiency
keras.config.set_dtype_policy("float16")

# Load Gemma from KerasHub
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset(
    "gemma_1.1_instruct_2b_en"
)

# --- 1. Inference / Generation ---
print("--- Gemma Generation ---")
output = gemma_lm.generate("Keras is a", max_length=30)
print(output)

# --- 2. Fine-tuning ---
print("\n--- Gemma Fine-tuning ---")
# Dummy data for demonstration
features = np.array(["The quick brown fox jumped.", "I forgot my homework."])
# The model.fit() API works seamlessly!
gemma_lm.fit(x=features, batch_size=2)
print("\n✅ Gemma fine-tuning step completed successfully!")

| 294 | +"""# Conclusion |
| 295 | +
|
| 296 | +The Keras-NNX integration represents a significant step forward, offering a |
| 297 | +unified framework for both rapid prototyping and high-performance, |
| 298 | +customizable research. You can now: |
| 299 | +Use familiar Keras APIs (Sequential, Model, fit, save) on a JAX backend. |
| 300 | +Integrate Keras layers and models directly into Flax NNX modules and training |
| 301 | +loops.Integrate keras code/model with NNX ecosytem like Qwix, Tunix, etc. |
| 302 | +Leverage the entire JAX ecosystem (e.g., nnx.jit, optax) with your Keras models. |
| 303 | +Seamlessly work with large models from KerasHub. |
| 304 | +""" |