
Multi-host training does not work for distributed embeddings #143

@abheesht17

Description


Code (sourced from this example): https://gist.github.com/abheesht17/0491a8d34756a208a000ff867ad7cc13

This is the error that crops up: https://gist.github.com/abheesht17/694a07e33d217b23692105e281a91fe4

The following workaround fixed the forward pass, though I don't know if it's the correct way to do this: instead of sharding the dataset with Keras' `distribution.distribute_dataset`, shard it manually using JAX, as done here: https://github.com/AI-Hypercomputer/RecML/blob/dlrm_mlperf/examples/dlrm/dlrm_main.py#L347-L357. However, `fit()` then fails because it tries to shard the dataset again.
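For reference, a minimal sketch of the manual per-host sharding workaround (assuming a `tf.data` input pipeline; the helper name and surrounding setup are hypothetical and not taken from the linked RecML code):

```python
import jax
import tensorflow as tf


def shard_dataset_per_host(dataset: tf.data.Dataset) -> tf.data.Dataset:
    """Hypothetical helper: give each JAX host its own shard of the data.

    Instead of relying on Keras' distribution.distribute_dataset, each
    process (host) reads only every process_count()-th record, offset
    by its own process_index().
    """
    return dataset.shard(
        num_shards=jax.process_count(),
        index=jax.process_index(),
    )
```

With something like this in place the forward pass works, but handing the pre-sharded dataset to `fit()` still fails, since Keras attempts to distribute it a second time.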

Metadata


Labels

bug (Something isn't working)
