
Multi-host training does not work for distributed embeddings #143

@abheesht17


Code (sourced from this example): https://gist.github.com/abheesht17/0491a8d34756a208a000ff867ad7cc13

The resulting error: https://gist.github.com/abheesht17/694a07e33d217b23692105e281a91fe4

Sharding the dataset manually with JAX, instead of through Keras' `distribution.distribute_dataset`, fixed the forward pass, though I'm not sure this is the correct approach: https://github.com/AI-Hypercomputer/RecML/blob/dlrm_mlperf/examples/dlrm/dlrm_main.py#L347-L357. However, `fit()` then fails because it tries to shard the already-sharded dataset again.
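For context, here is a minimal sketch of manual per-host sharding in JAX. It assumes each process already loads only its own share of the global batch. This is not the code from the linked `dlrm_main.py`. The helper `shard_host_batch` and the mesh axis name `"batch"` are hypothetical. The sketch uses `jax.make_array_from_process_local_data` to stitch the per-host slices into one global, batch-sharded `jax.Array`.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def shard_host_batch(local_batch: np.ndarray, mesh: Mesh) -> jax.Array:
    """Build a global batch-sharded jax.Array from this host's slice (hypothetical helper).

    Assumes every process holds an equally sized, disjoint slice of the global batch.
    """
    # Split the leading (batch) axis across the mesh axis "batch"; other axes are replicated.
    sharding = NamedSharding(mesh, P("batch"))
    # The global batch is the concatenation of all hosts' local batches.
    global_shape = (local_batch.shape[0] * jax.process_count(),) + local_batch.shape[1:]
    return jax.make_array_from_process_local_data(sharding, local_batch, global_shape)


# 1-D mesh over every device in every process (assumed layout).
mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))

# Hypothetical per-host batch: 8 examples with 4 features each.
local = np.arange(32, dtype=np.float32).reshape(8, 4)
global_batch = shard_host_batch(local, mesh)
print(global_batch.shape, global_batch.sharding)
```

On a single process this simply places the local batch on the mesh. On several hosts the leading dimension grows by `jax.process_count()`. The local batch size must divide evenly across the devices on each host.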

Labels: bug (Something isn't working)
