
orbax#1

Merged
jerryxyj merged 442 commits into master from
orbax
Feb 14, 2026
Conversation

@jerryxyj
Owner

No description provided.

shashaka and others added 30 commits September 29, 2025 10:55
* Add logaddexp2 method for ops

* rollback gptq_test

* correct code by gemini

* Add example for logaddexp2

* Correct example for logaddexp2

* correct excluded_concrete_tests.txt
* feat: numpy sort in openvino backend

* feat: included tests for numpy sort

* fix: code format and content ov dtype
…m#21667)

* feat: numpy median for openvino backend

* feat: included tests for numpy median

* fix: code format
A memory leak related to the executor in `CallbackList` was fixed in keras-team#20779

However, calling `Executor.shutdown` within `__del__` is intrinsically unsafe and can create deadlocks because the garbage collector can be invoked in different contexts.

This new approach uses the `on_train/test/predict_begin` and `on_train/test/predict_end` callbacks to detect when we're done with the executor.
- it counts the number of "begin"s and "end"s to handle the case of `evaluate` within `fit` (we do not shut down the executor at the end of `evaluate` but instead keep it around for the rest of the training)
- it also handles `CallbackList` being reused between calls to `fit`, `evaluate` or `predict`, even though Keras doesn't reuse them.

Also renamed `_clear_futures` to `_flush_futures` to make it clear futures are not discarded, but executed.
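The begin/end counting can be sketched as follows (a minimal standalone illustration, not the actual `CallbackList` code; the class and method names here are hypothetical):

```python
from concurrent.futures import ThreadPoolExecutor


class ExecutorLifecycle:
    """Hypothetical sketch: shut the executor down only when the
    outermost begin/end pair closes."""

    def __init__(self):
        self._executor = None
        self._depth = 0  # nesting count of begin/end pairs

    def on_begin(self):
        if self._executor is None:
            self._executor = ThreadPoolExecutor(max_workers=1)
        self._depth += 1

    def on_end(self):
        self._depth -= 1
        if self._depth == 0:
            # Outermost loop finished (e.g. `fit`); now safe to shut down.
            self._executor.shutdown(wait=True)
            self._executor = None
```

With this scheme, an `evaluate` nested inside `fit` only decrements the counter back to 1, so the executor survives until `fit` itself ends.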
…#21706)

`overwrite_with_gradient` would be ineffective on JAX in real-world conditions, i.e. within `model.fit`.

This is because in the training loop, `stateless_apply` is passed `trainable_variables` as arrays containing the values of the trainable variables, not the variables themselves. Instead, we have to inspect the variables.

`apply(grads)` without the `trainable_variables` argument passed in would not apply anything.

This is because the code uses `self._trainable_variables`. But this was an empty array for `LossScaleOptimizer`. This was fixed by adding `super().build(...)`.

Also fail when other arguments from the base optimizer are passed to `LossScaleOptimizer.__init__` since they are not actually supported. They are also no longer returned by `get_config`.
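The distinction between variables and their values can be sketched with a hypothetical helper (the real logic lives inside the optimizer): the `overwrite_with_gradient` flag must be read from the optimizer's stored variable references, because the plain arrays passed into `stateless_apply` carry no such metadata.

```python
def overwrite_flags(optimizer_variables):
    # Hypothetical sketch: inspect the stored variable references, not the
    # raw arrays passed to `stateless_apply`, since arrays carry no
    # `overwrite_with_gradient` attribute.
    return [
        bool(getattr(v, "overwrite_with_gradient", False))
        for v in optimizer_variables
    ]
```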
The `ProgBar` was using `backend.numpy.mean` causing it to use the accelerator (TPU or GPU) therefore causing a synchronization which defeated the async callback mechanism. This in turn was slowing down training.

All values provided in the `logs` are already Python floats and are all single values. There is therefore no need to do a `mean`, computing the average is simply a division of the running sum by the running count.
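The replacement is plain Python arithmetic on floats, along these lines (an illustrative sketch, not the actual `Progbar` code):

```python
class RunningAverage:
    """Average of already-Python floats: no backend op, no accelerator
    sync, so the async callback mechanism is preserved."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += value
        self.count += 1

    @property
    def mean(self):
        # The average is simply the running sum over the running count.
        return self.total / self.count
```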
…zation()` (keras-team#21716)

* Fix the Doc of the combination relation in func keras.layers.Normalization()

* Update keras/src/layers/preprocessing/normalization.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
…1719)

JAX uses `__jax_array__` to handle non-JAX types. For instance when doing `a * v` where `a` is a `jax.Array` and `v` is a `keras.Variable`, the `jax.Array.__mul__` implementation calls `v.__jax_array__()` because `v` is not a JAX type.

However, `__jax_array__` did not work in all contexts, and the next version of JAX further restricts which contexts it works in.

The fix rarely involves explicitly calling `v.value`. Instead, we rely on existing mechanisms that are already in place to unwrap variables in a lot of contexts:
- ops are always supposed to call `convert_to_tensor` on tensor inputs and `convert_to_tensor` extracts values from variables
- using `keras.ops` instead of native ops (+ - * / < > & etc.) unwraps variables. It is already a best practice to use `keras.ops` instead of native ops:
    - to support the creation of functional models via `KerasTensor`s and their serialization
    - to have consistent type promotion between backends
    - to support sparse tensors and ragged tensors

This was tested via a separate PR keras-team#21702 that won't be submitted because of https://github.com/keras-team/keras/pull/21702/files#diff-900deadc65fc119ce93fb813e340dcb644b8eab9e7c0207bf37cdc05b8e8796e .
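The unwrapping mechanism the fix relies on can be illustrated with a toy stand-in (the names here are illustrative, not the real Keras internals):

```python
import numpy as np


class Variable:
    # Toy stand-in for `keras.Variable`, holding a concrete array.
    def __init__(self, value):
        self.value = np.asarray(value)


def convert_to_tensor(x):
    # Ops unwrap variables before computing, mirroring the mechanism
    # described above: `convert_to_tensor` extracts the variable's value.
    return x.value if isinstance(x, Variable) else np.asarray(x)


def multiply(a, b):
    # A `keras.ops`-style op: both inputs pass through convert_to_tensor,
    # so a Variable never reaches the native `*` operator and
    # `__jax_array__` is never needed.
    return convert_to_tensor(a) * convert_to_tensor(b)
```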
Bumps the github-actions group with 6 updates:

| Package | From | To |
| --- | --- | --- |
| [actions/checkout](https://github.com/actions/checkout) | `4` | `5` |
| [actions/setup-python](https://github.com/actions/setup-python) | `5` | `6` |
| [actions/github-script](https://github.com/actions/github-script) | `7` | `8` |
| [ossf/scorecard-action](https://github.com/ossf/scorecard-action) | `2.4.2` | `2.4.3` |
| [github/codeql-action](https://github.com/github/codeql-action) | `3.29.7` | `3.30.5` |
| [actions/stale](https://github.com/actions/stale) | `9` | `10` |


Updates `actions/checkout` from 4 to 5
- [Release notes](https://github.com/actions/checkout/releases)
- [Commits](actions/checkout@v4...v5)

Updates `actions/setup-python` from 5 to 6
- [Release notes](https://github.com/actions/setup-python/releases)
- [Commits](actions/setup-python@v5...v6)

Updates `actions/github-script` from 7 to 8
- [Release notes](https://github.com/actions/github-script/releases)
- [Commits](actions/github-script@v7...v8)

Updates `ossf/scorecard-action` from 2.4.2 to 2.4.3
- [Release notes](https://github.com/ossf/scorecard-action/releases)
- [Changelog](https://github.com/ossf/scorecard-action/blob/main/RELEASE.md)
- [Commits](ossf/scorecard-action@05b42c6...4eaacf0)

Updates `github/codeql-action` from 3.29.7 to 3.30.5
- [Release notes](https://github.com/github/codeql-action/releases)
- [Changelog](https://github.com/github/codeql-action/blob/main/CHANGELOG.md)
- [Commits](github/codeql-action@51f7732...3599b3b)

Updates `actions/stale` from 9 to 10
- [Release notes](https://github.com/actions/stale/releases)
- [Changelog](https://github.com/actions/stale/blob/main/CHANGELOG.md)
- [Commits](actions/stale@v9...v10)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '5'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: actions/setup-python
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: actions/github-script
  dependency-version: '8'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: ossf/scorecard-action
  dependency-version: 2.4.3
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: github-actions
- dependency-name: github/codeql-action
  dependency-version: 3.30.5
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: github-actions
- dependency-name: actions/stale
  dependency-version: '10'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
…eras-team#21613)

* Added logspace and linspace numpy implementation

* Minor changes

* Added Changes suggested by BOT

* Minor changes

* Minor changes

* Minor changes

* Minor Changes

* Minor Changes

* Performed the Suggested Changes

* Performed the Suggested Changes

---------

Co-authored-by: hertschuh <1091026+hertschuh@users.noreply.github.com>
* add jvp op

* add jvp op

* bug fix

* add symbolic call

* fix doc.
* add unfold op

* fix.

* fix error.

* fix error.

* fix error.

* fix jax and tf backends, add numpy implementation.

* fix document
…ayers.dot()` (keras-team#21718)

* Add the description that `0` should not in the arg `axes` in `keras.layers.dot()`

* Update keras/src/layers/merging/dot.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Update dot.py to fix the doc of `Dot` and `batch_dot`

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
…21566)

* add daily Python 3.13 CPU-only tests to nightly workflow

* dynamic job names in nightly workflow
* Fix histogram op for symbolic inputs

* Skip NumPy test

* Remove debugging statements

* Fix reason

* Fix jit_compile = True

* Skip NumPy tests

* Better test name

* Use scatter_nd

* Fix comment
…am#21734)

The new function is available starting in JAX v0.8.0, and the old experimental aliases are deprecated and will be removed in JAX v0.9.0.
…eam#21738)

* ensure eye behavior is consistent across backends

* Update keras/src/ops/numpy_test.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* simplify per pr review

* pre-commit

* fix test for torch backend + add comments

* update implementation to raise TypeError for consistency

* add case for M being the only float

* improve naming of inner function for type check

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* add eye for openvino backend

* Update keras/src/backend/openvino/numpy.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* feature: include eye dtype

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
…as-team#21732)

Replacement for keras-team#21704

Also:
- Disabled an ONNX export test for Torch that was already disabled on GPU with both JAX and Tensorflow.
- Moved install for `tf_keras` from `requirements.txt` to `action.yml` using the `--no-deps` option because `tf_keras` depends on `tensorflow`, which installs the non-CPU version of TensorFlow and causes issues with CPU tests.
* Add initial version

* Add tensorflow version

* Update numpy.py for ops

* Update numpy_test.py

* Add method for openvino

* clean code

* update code by gemini review

* update test case for non-complex type
* Remove the unused jax x64 context.

* Fix ops.trace dtype.
)

* correct implementation of `floor` and enable testing

* correct implementation for mean, min, max + consolidate duplicate logic

* correct impl for cumsum with bools

* make gemini corrections

* remove disabling of initial

* typo as i committed because incompetence

* fix dynamic shape behavior

* fix dynamic shape behavior
…team#21748)

* Sets is_gptq_calibrated flag when deserializing GPTQ models

* move flag initialization to load_own_variables

* Added tests
)

* fix absolute, enable all/any

* fix ceil func

* enable input/output tests for any/all, fix axis resolution

* fix sqrt for int32

* fix sqrt and sum, consolidate duplicated logic

* fix squeeze

* add pos/neg/regular inf

* fix nan, add finite support

* add a note about why we're not using openvino's existing capabilities for finite/infinite
…t=='channels_first'` (keras-team#21750)

* Fix the bug in func `preprocess_input` when `x` is 3D and `data_format=='channels_first'`

* Update keras/src/applications/imagenet_utils.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
…keras-team#21751)

in the `.keras` archive when they are initialized with a path to a vocabulary file. This makes the `.keras` archive fully self-contained.

This was already the behavior when using either `set_vocabulary` or `adapt`. Simply, this behavior was extended to the case when `__init__` is called with a vocabulary file.

Note that this is technically a breaking change. Previously, upon doing `keras.saving.load_model`, it would be looking up the vocabulary file at the exact same path as when originally constructed.

Also disallow loading an arbitrary vocabulary file during model loading with `safe_mode=True` since the vocabulary file should now come from the archive.
Shi-pra-19 and others added 28 commits February 3, 2026 09:45
)

* Implement cross product operation for OpenVINO backend

* Refactor function to improve axis handling.

---------

Co-authored-by: Shipra <Shi-pra-19@users.noreply.github.com>
…ion (keras-team#22092)

* Fix conv symbolic execution to fail on empty output shapes

* Refactor conv symbolic invalid configuration tests

* Refactor conv symbolic invalid configuration tests

* Move conv output shape validation to compute_conv_output_shape

* Clarify error message for non-positive conv output shape
…nce (keras-team#22068)

* Fix Normalization broadcasting for scalar and multidim mean/variance

* fix: comment

* fix: handle multi-dim mean and variance via axis-aware expansion and added unit tests.

* fix:test for type-error

* fix: corrected right-to-left mean alignment, shape validation,added explicit broadcasting unit tests.

* fix: line duplication
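The axis-aware expansion can be illustrated with a hypothetical helper (not the layer's actual code), covering both scalar and per-axis statistics:

```python
import numpy as np


def align_for_broadcast(stat, ndim, axis):
    # Hypothetical sketch: place a per-axis statistic (e.g. `mean`) on the
    # normalized axis and insert size-1 dims everywhere else, so it
    # broadcasts against an input of rank `ndim`. A scalar statistic
    # simply becomes an all-ones shape.
    stat = np.asarray(stat)
    shape = [1] * ndim
    if stat.ndim:
        shape[axis] = stat.shape[0]
    return stat.reshape(shape)
```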
…r. (keras-team#22085)

- Unify the way tests are skipped by turning `if ... pytest.skip()` into `@pytest.mark.skipif` wherever possible.
- Also do not skip the tests that actually pass.
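The transformation looks like this (schematic; the real conditions query the active Keras backend):

```python
import pytest

# Before: imperative skip inside the test body.
#
#     def test_feature():
#         if some_condition:
#             pytest.skip("reason")
#         ...
#
# After: the condition becomes a collection-time marker, so skipped tests
# are visible before any test code runs.
SOME_CONDITION = False  # stand-in for a real backend check

@pytest.mark.skipif(SOME_CONDITION, reason="reason")
def test_feature():
    assert 1 + 1 == 2
```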
…22107)

- `get_metrics_result` is called within `train_step`, which means that it gets jitted. In a jitted context, we should only manipulate backend tensors.
- `pythonify_logs` is already called within `CallbackList` before dispatching to the callbacks. `CallbackList` is also what takes care of calling `pythonify` logs asynchronously when possible.
- `pythonify_logs` was added for the last logs in `evaluate` for every backend to turn the logs to python floats before returning the result.
- `pythonify_logs` was used instead of `np.array` in `train_on_batch` and `test_on_batch` as `np.array` doesn't work out of the box on torch on GPU for instance.
…am#22054)

* Fix gaussian_blur padding calculation for even kernel sizes

The gaussian_blur function in the NumPy backend was using incorrect symmetric
padding that caused shape mismatches when convolving with even-sized kernels.

* Address PR review: Add test and fix kernel dimension ordering

- Add unit test for even kernel sizes (numpy backend specific)
- Fix kernel dimension ordering in _create_gaussian_kernel
- Fix asymmetric padding calculation in test helper

* Fixed CI failures: line length and use actual kernel shape

* Reformat `image_test.py`.

---------

Co-authored-by: SamareshSingh <ssam3003@gmail.com>
Co-authored-by: Fabien Hertschuh <1091026+hertschuh@users.noreply.github.com>
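The asymmetric split needed for even kernel sizes can be sketched in one dimension (an illustrative sketch, not the backend's actual code):

```python
import numpy as np


def same_padding(kernel_size):
    # A "same" convolution needs k - 1 total padding. For odd k this
    # splits symmetrically; for even k it must be asymmetric:
    # (k - 1) // 2 before and k // 2 after.
    return (kernel_size - 1) // 2, kernel_size // 2


def conv1d_same(x, kernel):
    lo, hi = same_padding(len(kernel))
    xp = np.pad(x, (lo, hi))
    # "valid" convolution on the padded signal preserves the input length.
    return np.convolve(xp, kernel, mode="valid")
```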
Variable initializers are considered for jitting when the variable is sharded.

- Fully replicated variables are now also considered for jitting.
- The size calculation for the threshold now takes into account the mesh size as a worst case scenario.
…eras-team#22122)

In the code that extracts the new variable values from the `StatelessScope`, `BaseOptimizer.stateless_apply` has logic to fallback to returning the variable itself if the variable value is not in the scope.

This is incorrect for two reasons:
- A variable is not a JAX object, `jax.jit` will fail if we ever return a variable.
- At that point, the variable has a `None` value and is not usable: https://github.com/keras-team/keras/blob/master/keras/src/backend/jax/trainer.py#L949-L951

Luckily, this is dead code because the scope has the full mapping of all the variables and this fallback is never used.
Co-authored-by: Shipra <Shi-pra-19@users.noreply.github.com>
* feat: add depth_to_space and space_to_depth ops

* Update keras/src/backend/jax/nn.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Update keras/src/backend/numpy/nn.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Update keras/src/backend/torch/nn.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* fix: use explicit default for data_format in depth/space ops

* fix: update class defaults for depth/space ops consistency

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Numpy 2.4 removed support for `newshape` as the argument for `np.reshape`.
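Concretely:

```python
import numpy as np

a = np.arange(6)

# `np.reshape(a, newshape=(2, 3))` no longer works on NumPy 2.4; pass the
# shape positionally (the `shape=` keyword also works on recent NumPy):
b = np.reshape(a, (2, 3))
```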
…in IndexLookup.load_assets (keras-team#22119)

* Fix vocabulary reload bug caused by trailing newline handling in IndexLookup.load_assets

* Fix vocabulary reload bug caused by trailing newline handling in IndexLookup.load_assets

* Fix vocabulary reload bug caused by trailing newline handling in IndexLookup.load_assets
keras-team#22154)

Also added verification that the shape of the inputs, the `shape` parameter and the `indices` have the same length.
Problem 1:
- the current validation does not allow functional models where some inputs are passed unmodified as outputs (this was allowed in Keras 2).

Problem 2:
- the current validation does not allow functional models where some inputs are not used (this was allowed in Keras 2).

Problem 3:
- the current validation only verifies that the inputs are used; nothing is verified on the outputs. In particular, it is possible to create a model with some outputs disconnected from the inputs, and the construction will succeed. However, calling such a model will fail with a nondescript `KeyError`.

Solution:
- instead of verifying how inputs are used, the new validation verifies that all the outputs are reachable from the inputs.

Also improved the error message by specifying which output is not connected to the inputs.
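The output-reachability check can be sketched as a walk over a toy producer graph (schematic; not the actual Functional-model code):

```python
from collections import deque


def unreachable_outputs(inputs, outputs, producers):
    """Return the outputs whose ancestry never reaches `inputs`.

    `producers` maps each tensor name to the tensors it was computed
    from (absent or empty for sources). An input passed straight through
    as an output is trivially reachable, matching Problem 1 above.
    """
    bad = []
    for out in outputs:
        queue, seen, reachable = deque([out]), set(), False
        while queue:
            t = queue.popleft()
            if t in seen:
                continue
            seen.add(t)
            if t in inputs:
                reachable = True
                break
            queue.extend(producers.get(t, ()))
        if not reachable:
            bad.append(out)
    return bad
```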
* Fix draw_bounding_boxes float32 to uint8 conversion

* fix: simplifed logic to ops.clip

* fix: handle empty array

* Update keras/src/visualization/draw_bounding_boxes.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
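The core of the fix, sketched with NumPy (the actual code uses `ops.clip`, per the commit above):

```python
import numpy as np


def to_uint8(images):
    # Clip float pixel values into the valid range before casting, so
    # out-of-range floats don't wrap around when converted to uint8.
    return np.clip(images, 0, 255).astype("uint8")
```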
* Implement dstack function across all backends

* Refactor dstack function to improve dtype handling and add examples in docstring

---------

Co-authored-by: Shipra <Shi-pra-19@users.noreply.github.com>
Implements numpy.exp2 for calculating 2**x element-wise.
Enables exp2 tests by removing exclusions from excluded_concrete_tests.txt.
Implements numpy.trunc for truncation towards zero.
Enables trunc tests by removing exclusions from excluded_concrete_tests.txt.
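Reference semantics for the two ops, shown with NumPy (the backend implementations mirror these):

```python
import numpy as np

x = np.array([-1.5, -0.5, 0.0, 1.0, 2.5])

y = np.exp2(x)   # element-wise 2**x, e.g. 2**-1.5 and 2**2.5
z = np.trunc(x)  # round toward zero: note trunc(-0.5) is -0.0, not -1.0
```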
…22130)

* fix: add missing validation for output padding < strides

* fix: add any()

* Update keras/src/layers/convolutional/base_conv_transpose.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* fix: lint fix

* fix: invalid test output_padding stride combination

* fix: test fix

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
…ras-team#22014)

* docs: Add guide on resuming training from weight-only checkpoints

* fix: Define lr_schedule in resume training example

* refactor: Move resume training guide to ModelCheckpoint docstring

Address reviewer feedback to move documentation from guides file
to ModelCheckpoint class docstring for better discoverability.
The guide will now appear in the API docs at keras.io/api/callbacks/model_checkpoint/

* style: Fix line length violation in ModelCheckpoint docstring

* fix: Update ModelCheckpoint docstring examples for consistency

- Add lr_schedule definition to first example
- Update model.compile() calls for consistency
- Fix model.fit() calls to include x_train, y_train and epochs=10
- Remove unused EPOCHS constant

Addresses reviewer feedback from hertschuh

* style: Fix line length violations in ModelCheckpoint docstring

- Break long model.fit() calls across multiple lines
- Comply with 80 character line limit
- Fixes CI formatting check

* Apply reviewer feedback to keras.device docstring

* Revert docstring changes to device function in backend/__init__.py

---------

Co-authored-by: Sikandar <ma5161310@gmail.com>
…eras-team#22142)

* Fix TrackedDict constructor to support iterable (key, value) inputs and add regression tests

* Fix TrackedDict constructor to support iterable (key, value) inputs and add regression tests
…eras-team#22155)

* Implement numpy.gcd using Euclidean algorithm for OpenVINO backend

* Fix scalar shape issue in numpy.gcd for OpenVINO backend

Co-authored-by: andersendsa <199610634+andersendsa@users.noreply.github.com>

* Implement numpy.gcd for OpenVINO backend with robust type and shape handling

Co-authored-by: andersendsa <199610634+andersendsa@users.noreply.github.com>

---------

Co-authored-by: andersendsa <199610634+andersendsa@users.noreply.github.com>
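The scalar form of the Euclidean algorithm named in the commits above looks like this (the backend version operates element-wise on tensors, with the extra type and shape handling the commits mention):

```python
def gcd(a, b):
    # Iterative Euclidean algorithm: repeatedly replace (a, b) with
    # (b, a mod b) until b reaches zero.
    a, b = abs(a), abs(b)
    while b:
        a, b = b, a % b
    return a
```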
@jerryxyj jerryxyj merged commit 0922805 into master Feb 14, 2026
4 of 8 checks passed
@jerryxyj jerryxyj deleted the orbax branch February 14, 2026 09:10