Why use RowwiseParallel for nn.Embedding instead of ColwiseParallel? #785

Open
@corey-lambda

Description

Colwise would make the logic a bit clearer. Rowwise splits on the token (vocab) dimension, which raises the question of how the different shards handle token ids that are not present in their shard. From a bit of debugging it seems like there is a special case for this somewhere deep in the PyTorch source code, but I could not find it.

With colwise, the embedding weight matrix is split on the model dim, so every shard has rows for all tokens, just a different slice of the model dim.
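To make the contrast concrete, here is a tiny single-process sketch in plain PyTorch (not torchtitan code), assuming rowwise means Shard(0) on the embedding weight and colwise means Shard(1), which matches the description above:

    import torch

    vocab_size, dim, world_size = 8, 4, 2
    weight = torch.arange(vocab_size * dim, dtype=torch.float32).reshape(vocab_size, dim)

    # Rowwise: Shard(0) -> split the token/vocab dimension.
    # Each "rank" only holds the rows for a slice of the token ids.
    rowwise_shards = weight.chunk(world_size, dim=0)
    print(rowwise_shards[0].shape)   # torch.Size([4, 4]): rank 0 holds tokens 0..3 only

    # Colwise: Shard(1) -> split the model/hidden dimension.
    # Every "rank" holds all token ids, just part of each embedding vector.
    colwise_shards = weight.chunk(world_size, dim=1)
    print(colwise_shards[0].shape)   # torch.Size([8, 2]): all tokens, half of the dim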

https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L133

    parallelize_module(
        model,
        tp_mesh,
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            # ... (remaining entries of the TP plan omitted)
        },
    )

Can someone provide some insight?
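
For what it's worth, here is the mental model I ended up with for how a rowwise-sharded lookup can still serve every token: each shard masks out the ids it does not own, looks up the rest locally, and a sum across shards reconstructs the full output. This is a hand-rolled single-process sketch, not the actual DTensor code path:

    import torch

    vocab_size, dim, world_size = 8, 4, 2
    weight = torch.randn(vocab_size, dim)
    shards = weight.chunk(world_size, dim=0)   # Shard(0): each rank owns a vocab slice
    tokens = torch.tensor([1, 6, 3, 7])        # input is Replicate()d, every rank sees all ids

    partials = []
    for rank, shard in enumerate(shards):
        lo = rank * shard.shape[0]
        hi = lo + shard.shape[0]
        owned = (tokens >= lo) & (tokens < hi)
        local_ids = (tokens - lo).clamp(0, shard.shape[0] - 1)  # placeholder index for unowned ids
        out = shard[local_ids]                 # fancy indexing copies, so in-place edits are safe
        out[~owned] = 0.0                      # zero the rows this shard does not own
        partials.append(out)

    # The sum across ranks that tensor parallelism would perform (e.g. an all-reduce):
    full = torch.stack(partials).sum(dim=0)
    assert torch.equal(full, weight[tokens])

If that is roughly what happens, then with output_layouts=Shard(1) in the plan above the reduction would presumably be fused with a scatter over the sequence dimension (a reduce-scatter) rather than a plain all-reduce.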
