Code (sourced from this example): https://gist.github.com/abheesht17/0491a8d34756a208a000ff867ad7cc13
This is the error which crops up: https://gist.github.com/abheesht17/694a07e33d217b23692105e281a91fe4
One workaround fixed the forward pass, though I don't know whether it is the correct way to do it: instead of sharding the dataset with Keras' distribution.distribute_dataset, shard it manually using JAX, as done here: https://github.com/AI-Hypercomputer/RecML/blob/dlrm_mlperf/examples/dlrm/dlrm_main.py#L347-L357. However, fit() still fails because it tries to shard the dataset again.
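For reference, here is a minimal sketch of what manually sharding a batch with JAX can look like. It is not copied from the linked RecML code and may differ from it. The mesh axis name "batch", the helper shard_batch, and the feature names and shapes are all hypothetical and chosen only for illustration.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1D device mesh over all available devices; the axis name "batch" is an assumption.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("batch",))

# Split the leading (batch) dimension across the mesh; other dimensions are replicated.
data_sharding = NamedSharding(mesh, P("batch"))


def shard_batch(batch):
    """Place every array in a (possibly nested) batch on the mesh (hypothetical helper)."""
    return jax.tree_util.tree_map(
        lambda x: jax.device_put(x, data_sharding), batch
    )


# Dummy batch; the global batch size must be divisible by the number of devices.
per_device_batch_size = 8
global_batch_size = per_device_batch_size * len(devices)
batch = {
    "dense": np.ones((global_batch_size, 13), dtype=np.float32),
    "labels": np.zeros((global_batch_size,), dtype=np.float32),
}

sharded = shard_batch(batch)
```

Arrays placed this way can be passed directly to a jitted step function. Whether they can be handed to fit() without triggering a second sharding attempt is exactly the open question above.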