New features
- keras_rs.losses.ListMLELoss: Added support for the ListMLE loss function.
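ListMLE is a listwise learning-to-rank loss. It is the negative log-likelihood of the ground-truth ordering of a list under a Plackett-Luce model of the predicted scores. The pure-Python sketch below computes it for a single list to illustrate the idea. It is not the keras_rs.losses.ListMLELoss implementation, and the helper name `listmle_loss` and its argument order are hypothetical. It assumes that a higher label means more relevant and that tied labels keep their input order.

```python
import math


def listmle_loss(labels, scores):
    """Hypothetical helper: ListMLE loss for one list of items.

    loss = sum_i [ log(sum_{j >= i} exp(s_pi(j))) - s_pi(i) ]
    where pi orders the items by descending label.
    """
    if len(labels) != len(scores):
        raise ValueError("labels and scores must have the same length")
    # Ground-truth permutation: indices sorted by label, highest first
    # (Python's sort is stable, so ties keep their input order).
    order = sorted(range(len(labels)), key=lambda i: labels[i], reverse=True)
    ordered = [scores[i] for i in order]
    loss = 0.0
    for i in range(len(ordered)):
        tail = ordered[i:]
        # Subtract the max before exponentiating for a stable log-sum-exp.
        m = max(tail)
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in tail))
        loss += log_sum_exp - ordered[i]
    return loss
```

A list whose scores already follow the label order, such as `listmle_loss([2, 1, 0], [3.0, 2.0, 1.0])`, gets a lower loss than the same list with the scores reversed.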
New example
- DLRM-DCNv2: Added an implementation of the DLRM-DCNv2 model architecture for the MLPerf performance benchmark, using the Criteo dataset.
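The "DCNv2" part of DLRM-DCNv2 refers to the DCN-V2 cross network, which builds explicit feature crosses with layers of the form x_{l+1} = x_0 ⊙ (W_l x_l + b_l) + x_l, where ⊙ is the element-wise product. The pure-Python sketch below implements one full-rank cross layer on plain lists. `cross_layer` is a hypothetical name, and this is not the code from the KerasRS example, which may differ, for example by using a low-rank weight matrix.

```python
def cross_layer(x0, xl, weight, bias):
    """Hypothetical DCN-V2 cross layer: x0 * (weight @ xl + bias) + xl.

    x0: the original input features; xl: the output of the previous layer;
    weight: a d x d matrix as a list of rows; bias: a length-d list.
    """
    d = len(x0)
    out = []
    for i in range(d):
        # i-th entry of the linear projection weight @ xl + bias.
        proj = sum(weight[i][j] * xl[j] for j in range(d)) + bias[i]
        # Element-wise product with x0, plus the residual connection to xl.
        out.append(x0[i] * proj + xl[i])
    return out
```

Layers are stacked by feeding each output back in as `xl` while keeping `x0` fixed, e.g. `x1 = cross_layer(x0, x0, W0, b0)` followed by `x2 = cross_layer(x0, x1, W1, b1)`.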
Bug fixes & Improvements
- DistributedEmbedding on JAX:
- Updated DistributedEmbedding to work with the latest jax-tpu-embedding and take advantage of its latest performance improvements, in particular with table stacking.
- Bug fixes and improvements to the synchronization of table statistics across multiple hosts.
- Made table statistics updates optional.
- Improved performance by removing redundant ones/ones_like calls when feature weights are not provided.
- JAX compatibility:
- Made KerasRS compatible with JAX >= 0.8.0, in particular with regard to jax.jit and shard_map.
- Infrastructure & CI/CD:
- Added TPU testing for JAX and TensorFlow.
- Added GPU testing for JAX, TensorFlow and Torch.
- Upgraded the PyTorch dependency to version 2.9.0.
- Enabled CUDA 13.0 support for GPU testing.
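The notes above do not show how the ones/ones_like change is implemented. One plausible reading is that, when no feature weights are given, the embedding combiner no longer materializes an all-ones weight tensor just to multiply by it. The pure-Python sketch below illustrates that idea only. It is not KerasRS or jax-tpu-embedding code; the name `combine_embeddings` is hypothetical and a sum combiner is assumed.

```python
def combine_embeddings(rows, weights=None):
    """Hypothetical sum combiner for the embedding rows looked up for one sample.

    rows: a non-empty list of equal-length embedding vectors.
    weights: optional per-row weights; None means every row counts once.
    """
    dim = len(rows[0])
    if weights is None:
        # Fast path: sum the rows directly instead of first building
        # a list of 1.0 weights (the analogue of an ones_like call).
        return [sum(row[k] for row in rows) for k in range(dim)]
    if len(weights) != len(rows):
        raise ValueError("need exactly one weight per looked-up row")
    # Weighted path: scale each row by its weight before summing.
    return [sum(w * row[k] for w, row in zip(weights, rows)) for k in range(dim)]
```

Both paths give the same result when every weight is 1.0; the unweighted path simply skips allocating and multiplying by those ones.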
New Contributors
- @adityagupta1089 made their first contribution in #156
- @LakshmiKalaKadali made their first contribution in #130
- @wenyi-guo made their first contribution in #170
Full Changelog: v0.3.0...v0.4.0