v0.4.0

sachinprasadhs released this on 12 Jan 19:49
ef98322

New features

  • keras_rs.losses.ListMLELoss: Added support for the ListMLE loss function.
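
ListMLE frames ranking as maximum-likelihood estimation under the Plackett-Luce model. The loss is the negative log-likelihood of the ground-truth ordering π (items sorted by label, most relevant first) given the scores s: loss = Σ_i [ log Σ_{j≥i} exp(s_π(j)) − s_π(i) ]. The NumPy sketch below computes this for a single list. It is not the keras_rs.losses.ListMLELoss implementation. The name list_mle_loss is hypothetical, and breaking label ties by original position is an assumption.

```python
import numpy as np


def list_mle_loss(labels, scores):
    """ListMLE loss for one list (hypothetical sketch, not the keras_rs code).

    labels: 1-D relevance labels, higher means more relevant.
    scores: 1-D model scores for the same items.
    """
    labels = np.asarray(labels, dtype=float)
    scores = np.asarray(scores, dtype=float)
    # Ground-truth permutation: most relevant first (ties keep original order).
    order = np.argsort(-labels, kind="stable")
    s = scores[order]
    # Shift by the max for numerical stability; the loss is shift-invariant.
    s = s - s.max()
    # log(sum_{j>=i} exp(s_j)) for every i, via a reversed cumulative sum.
    suffix_lse = np.log(np.cumsum(np.exp(s)[::-1])[::-1])
    # Negative log-likelihood of the ground-truth permutation.
    return float(np.sum(suffix_lse - s))
```

With labels [1, 0], equal scores give a loss of log 2 (about 0.693). Scoring the relevant item higher drives the loss toward 0, and scoring it lower increases the loss.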

New example

  • DLRM-DCNv2: Added an implementation of the DLRM-DCNv2 model architecture, together with the MLPerf performance benchmark on the Criteo dataset.
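
The DCNv2 part of this architecture learns explicit feature interactions with cross layers of the form x_{l+1} = x_0 ⊙ (W_l x_l + b_l) + x_l. The NumPy sketch below applies that formula to a batch of inputs. It is a hypothetical illustration, not the code from the Keras RS example: cross_layer and cross_network are made-up names, and the sketch omits the low-rank variant as well as the DLRM embedding and MLP components.

```python
import numpy as np


def cross_layer(x0, xl, W, b):
    """One full-rank DCNv2-style cross layer (hypothetical sketch).

    x0, xl: arrays of shape [batch, d]; W: [d, d]; b: [d].
    Computes x0 * (W @ xl + b) + xl row by row.
    """
    return x0 * (xl @ W.T + b) + xl


def cross_network(x0, weights, biases):
    """Stack cross layers; every layer reuses the original input x0."""
    x = x0
    for W, b in zip(weights, biases):
        x = cross_layer(x0, x, W, b)
    return x
```

Because each layer multiplies by x0 and keeps the residual term x_l, each additional layer raises the maximum interaction degree by one while preserving the lower-order terms.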

Bug fixes & Improvements

  • DistributedEmbedding on JAX
    • Updated DistributedEmbedding to work with the latest jax-tpu-embedding and take advantage of the latest performance improvements, in particular with table stacking.
    • Fixed bugs in, and improved, the synchronization of table statistics across multiple hosts.
    • Made table statistics updates optional.
    • Improved performance by removing redundant ones/ones_like calls when feature weights are not provided.
  • JAX Compatibility
    • Made KerasRS compatible with JAX >= 0.8.0, in particular with regard to jax.jit and shard_map.
  • Infrastructure & CI/CD
    • Added TPU testing for JAX and TensorFlow.
    • Added GPU testing for JAX, TensorFlow and Torch.
    • Upgraded PyTorch dependency to version 2.9.0.
    • Enabled CUDA 13.0 support for GPU testing.
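
Two of the DistributedEmbedding changes above are table stacking and skipping redundant ones/ones_like calls when no feature weights are given. Roughly, table stacking combines several embedding tables of the same width into one larger table, so lookups into several logical tables can be served from a single table by offsetting the ids. The NumPy sketch below illustrates both ideas conceptually on a single host. It is not how jax-tpu-embedding actually stacks or shards tables on TPU. stack_tables and lookup are hypothetical names, and a sum combiner is assumed.

```python
import numpy as np


def stack_tables(tables):
    """Concatenate tables that share an embedding width (hypothetical sketch).

    Returns the stacked table and the starting row of each original table.
    """
    offsets = np.cumsum([0] + [t.shape[0] for t in tables[:-1]])
    return np.concatenate(tables, axis=0), offsets


def lookup(stacked, offsets, table_index, ids, weights=None):
    """Look up ids of shape [batch, k] from one logical table and sum them."""
    # Shift ids into the stacked table's row space: result is [batch, k, dim].
    rows = stacked[np.asarray(ids) + offsets[table_index]]
    if weights is None:
        # No feature weights: plain sum, without materializing an all-ones array.
        return rows.sum(axis=1)
    # Weighted sum: weights of shape [batch, k] broadcast over the embedding dim.
    return (rows * np.asarray(weights)[..., None]).sum(axis=1)
```

For example, stacking a 3-row and a 4-row table gives offsets [0, 3], so id 0 of the second table is read from row 3 of the stacked table.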

New Contributors

Full Changelog: v0.3.0...v0.4.0