Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 0 additions & 23 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,23 +0,0 @@
.DS_Store
*.pyc
.vscode-test
__pycache__
**/.vscode-test/**
**/.vscode test/**
**/.vscode-smoke/**
**/.venv*/
venv
bin/**
build/**
obj/**
.pytest_cache
tmp/**
.vs/
dist/**
**/*.egg-info/*
.vscode
examples/**/*.jpg
.python-version
.coverage
*coverage.xml
.ruff_cache
87 changes: 87 additions & 0 deletions CUSTOM_GRADIENT_JAX_FIX.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Fix for custom_gradient with JAX backend and Variables

## Issue
GitHub Issue [#21105](https://github.com/keras-team/keras/issues/21105)

When using `@ops.custom_gradient` with the JAX backend, passing Keras Variables as arguments would cause a `TypeError: 'NoneType' object is not callable` during training. This occurred because JAX's `custom_gradient` would capture the Variable object itself instead of extracting its underlying tensor value.

## Root Cause
The JAX backend's `custom_gradient` function was directly wrapping `jax.custom_gradient` without converting Variable objects to their values, unlike the `stop_gradient` function which already handled this correctly.

## Solution
Modified `keras/src/backend/jax/core.py` to add a wrapper that automatically extracts `.value` from Variable objects before passing them to the user's custom gradient function. This is done using `tree.map_structure` to recursively handle nested structures.

### Changes Made

**File: `keras/src/backend/jax/core.py`**

```python
def custom_gradient(fun):
def wrapper(*args, **kwargs):
# Convert Variable objects to their values
def _convert_arg(arg):
if isinstance(arg, Variable):
return arg.value
return arg

args = tree.map_structure(_convert_arg, args)
kwargs = tree.map_structure(_convert_arg, kwargs)
return fun(*args, **kwargs)

return jax.custom_gradient(fun=wrapper)
```

**File: `keras/src/ops/core_test.py`**

Added `test_custom_gradient_with_variable()` to verify that Variables can be passed directly to custom_gradient functions without needing to manually add `.value`.

## Testing

### Run the specific test:
```bash
pytest keras/src/ops/core_test.py::CoreOpsCorrectnessTest::test_custom_gradient_with_variable -v
```

### Run all core ops tests:
```bash
pytest keras/src/ops/core_test.py -v
```

## Example Usage

Before the fix, you needed to manually extract `.value`:

```python
@ops.custom_gradient
def roundpass(x, log_scaling):
scaling = ops.exp(log_scaling)
rounded = ops.round(x * scaling) / scaling

def grad(*args, upstream=None):
if upstream is None:
(upstream,) = args
return upstream, ops.zeros_like(log_scaling)

return rounded, grad

class QuantizedLayer(layers.Layer):
def call(self, x):
# Workaround: manually add .value
return roundpass(x, self.log_scaling.value)
```

After the fix, Variables work directly:

```python
class QuantizedLayer(layers.Layer):
def call(self, x):
# Works automatically now!
return roundpass(x, self.log_scaling)
```

## Impact
- ✅ Fixes the TypeError when Variables are passed to custom_gradient functions
- ✅ Makes JAX backend behavior consistent with user expectations
- ✅ Aligns with how `stop_gradient` already handles Variables
- ✅ Backward compatible - existing code using `.value` workaround still works
- ✅ No performance impact - conversion happens once at function decoration time
133 changes: 0 additions & 133 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,133 +0,0 @@
# Keras 3: Deep Learning for Humans

Keras 3 is a multi-backend deep learning framework, with support for JAX, TensorFlow, PyTorch, and OpenVINO (for inference-only).
Effortlessly build and train models for computer vision, natural language processing, audio processing,
timeseries forecasting, recommender systems, etc.

- **Accelerated model development**: Ship deep learning solutions faster thanks to the high-level UX of Keras
and the availability of easy-to-debug runtimes like PyTorch or JAX eager execution.
- **State-of-the-art performance**: By picking the backend that is the fastest for your model architecture (often JAX!),
leverage speedups ranging from 20% to 350% compared to other frameworks. [Benchmark here](https://keras.io/getting_started/benchmarks/).
- **Datacenter-scale training**: Scale confidently from your laptop to large clusters of GPUs or TPUs.

Join nearly three million developers, from burgeoning startups to global enterprises, in harnessing the power of Keras 3.


## Installation

### Install with pip

Keras 3 is available on PyPI as `keras`. Note that Keras 2 remains available as the `tf-keras` package.

1. Install `keras`:

```
pip install keras --upgrade
```

2. Install backend package(s).

To use `keras`, you should also install the backend of choice: `tensorflow`, `jax`, or `torch`. Additionally,
The `openvino` backend is available with support for model inference only.

### Local installation

#### Minimal installation

Keras 3 is compatible with Linux and macOS systems. For Windows users, we recommend using WSL2 to run Keras.
To install a local development version:

1. Install dependencies:

```
pip install -r requirements.txt
```

2. Run installation command from the root directory.

```
python pip_build.py --install
```

3. Run API generation script when creating PRs that update `keras_export` public APIs:

```
./shell/api_gen.sh
```

## Backend Compatibility Table

The following table lists the minimum supported versions of each backend for the latest stable release of Keras (v3.x):

| Backend | Minimum Supported Version |
|------------|---------------------------|
| TensorFlow | 2.16.1 |
| JAX | 0.4.20 |
| PyTorch | 2.1.0 |
| OpenVINO | 2025.3.0 |

#### Adding GPU support

The `requirements.txt` file will install a CPU-only version of TensorFlow, JAX, and PyTorch. For GPU support, we also
provide a separate `requirements-{backend}-cuda.txt` for TensorFlow, JAX, and PyTorch. These install all CUDA
dependencies via `pip` and expect a NVIDIA driver to be pre-installed. We recommend a clean Python environment for each
backend to avoid CUDA version mismatches. As an example, here is how to create a JAX GPU environment with `conda`:

```shell
conda create -y -n keras-jax python=3.10
conda activate keras-jax
pip install -r requirements-jax-cuda.txt
python pip_build.py --install
```

## Configuring your backend

You can export the environment variable `KERAS_BACKEND` or you can edit your local config file at `~/.keras/keras.json`
to configure your backend. Available backend options are: `"tensorflow"`, `"jax"`, `"torch"`, `"openvino"`. Example:

```
export KERAS_BACKEND="jax"
```

In Colab, you can do:

```python
import os
os.environ["KERAS_BACKEND"] = "jax"

import keras
```

**Note:** The backend must be configured before importing `keras`, and the backend cannot be changed after
the package has been imported.

**Note:** The OpenVINO backend is an inference-only backend, meaning it is designed only for running model
predictions using `model.predict()` method.

## Backwards compatibility

Keras 3 is intended to work as a drop-in replacement for `tf.keras` (when using the TensorFlow backend). Just take your
existing `tf.keras` code, make sure that your calls to `model.save()` are using the up-to-date `.keras` format, and you're
done.

If your `tf.keras` model does not include custom components, you can start running it on top of JAX or PyTorch immediately.

If it does include custom components (e.g. custom layers or a custom `train_step()`), it is usually possible to convert it
to a backend-agnostic implementation in just a few minutes.

In addition, Keras models can consume datasets in any format, regardless of the backend you're using:
you can train your models with your existing `tf.data.Dataset` pipelines or PyTorch `DataLoaders`.

## Why use Keras 3?

- Run your high-level Keras workflows on top of any framework -- benefiting at will from the advantages of each framework,
e.g. the scalability and performance of JAX or the production ecosystem options of TensorFlow.
- Write custom components (e.g. layers, models, metrics) that you can use in low-level workflows in any framework.
- You can take a Keras model and train it in a training loop written from scratch in native TF, JAX, or PyTorch.
- You can take a Keras model and use it as part of a PyTorch-native `Module` or as part of a JAX-native model function.
- Make your ML code future-proof by avoiding framework lock-in.
- As a PyTorch user: get access to power and usability of Keras, at last!
- As a JAX user: get access to a fully-featured, battle-tested, well-documented modeling and training library.


Read more in the [Keras 3 release announcement](https://keras.io/keras_3/).
56 changes: 56 additions & 0 deletions TORCH_JIT_COMPILE_LIMITATIONS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Torch Backend jit_compile Limitations

## Issue #21647: jit_compile=True with EfficientNetV2 on torch backend

### Problem
When using `jit_compile=True` with certain Keras models (especially EfficientNetV2) on the torch backend, you may encounter `InternalTorchDynamoError` or `RuntimeError` related to torch.compile being unable to trace optree operations.

### Root Cause
Keras uses tree operations (from optree or torch._pytree) for handling nested structures. When `jit_compile=True` is enabled, PyTorch's torch.compile attempts to trace through all Python operations, including these tree utilities. However, torch.compile has limitations with certain C/C++ extensions and symbolic operations.

### Error Messages
- **GPU**: `InternalTorchDynamoError: TypeError: '<' not supported between instances of 'NoneType' and 'int'`
- **CPU**: `RuntimeError: TypeError: cannot determine truth value of Relational`

### Workarounds

#### Option 1: Disable JIT Compilation (Recommended)
```python
model.compile(
optimizer=Adam(learning_rate=0.001),
loss=CategoricalCrossentropy(),
metrics=['accuracy'],
jit_compile=False # or omit this parameter
)
```

#### Option 2: Use a Different Backend
Switch to TensorFlow or JAX backend which have better jit_compile support:
```python
import os
os.environ["KERAS_BACKEND"] = "tensorflow" # or "jax"
```

#### Option 3: Use Fixed Input Shapes
If you must use jit_compile with torch, ensure all input shapes are fixed (no None dimensions):
```python
base_model = EfficientNetV2B2(
include_top=False,
input_shape=(224, 224, 3), # Fixed shape, no None
pooling='avg',
weights=None
)
```

### Status
This is a known limitation of torch.compile when working with complex nested structures. The PyTorch team is aware of limitations with certain patterns and continues to improve torch.compile support.

### Related Issues
- PyTorch Issue: torch.compile limitations with pytree operations
- Keras Issue #21647

### Future Improvements
Potential solutions being explored:
1. Add torch.compile skip decorators for tree operations
2. Use torch.compiler.disable() context for specific operations
3. Refactor to use pure torch operations where possible
Loading
Loading