Description
Colwise would make the logic a bit clearer. Rowwise splits the embedding weight on the token (vocab) dimension, which makes it confusing how the different shards handle tokens that are not present in their shard. From a bit of debugging it seems like there is a special case for this somewhere deep in the PyTorch source code, but I could not find it.
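If it helps, the standard trick for vocab-sharded embeddings (e.g. Megatron-LM's VocabParallelEmbedding; I assume DTensor's rowwise embedding path does something equivalent) is: each shard masks the tokens it does not own, zeroes those output rows, and the partial results are summed across the TP group with an all-reduce. A minimal single-process sketch of that idea (plain tensors standing in for two shards, not real TP code):

    import torch

    vocab_size, model_dim, num_shards = 8, 4, 2
    weight = torch.randn(vocab_size, model_dim)
    tokens = torch.tensor([1, 5, 7, 0])  # replicated input, as with RowwiseParallel

    shard_rows = vocab_size // num_shards
    partials = []
    for rank in range(num_shards):
        lo, hi = rank * shard_rows, (rank + 1) * shard_rows
        local_weight = weight[lo:hi]               # this shard's slice of the vocab
        in_range = (tokens >= lo) & (tokens < hi)  # which tokens this shard owns
        local_ids = (tokens - lo).clamp(0, shard_rows - 1)  # keep indices valid
        out = local_weight[local_ids]
        out[~in_range] = 0.0                       # zero rows for tokens we don't own
        partials.append(out)

    # in a real TP run this sum would be an all-reduce across the TP group
    result = sum(partials)
    assert torch.equal(result, weight[tokens])

If that is what happens, the out-of-shard tokens simply contribute zeros and the reduction fills them in.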
With colwise, the embedding weight matrix would be split on the model dimension instead, so every shard has all the tokens, just a different slice of the model dim.
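In other words, something like this (same single-process sketch style as above):

    import torch

    vocab_size, model_dim, num_shards = 8, 4, 2
    weight = torch.randn(vocab_size, model_dim)
    tokens = torch.tensor([1, 5, 7, 0])

    shard_cols = model_dim // num_shards
    pieces = []
    for rank in range(num_shards):
        local_weight = weight[:, rank * shard_cols:(rank + 1) * shard_cols]
        # every shard can look up every token; it just produces a slice of model dim
        pieces.append(local_weight[tokens])

    # output is sharded on the last dim; the concat stands in for an all-gather
    assert torch.equal(torch.cat(pieces, dim=-1), weight[tokens])

The output then ends up sharded on the model dim, so it would need an all-gather or a redistribute to match the Shard(1) sequence layout used below.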
https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L133
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import RowwiseParallel, parallelize_module

parallelize_module(
    model,
    tp_mesh,
    {
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),
        ),
        # ... rest of the TP plan
    },
)
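For reference, the colwise variant I had in mind would look roughly like this (hypothetical, not what torchtitan does; ColwiseParallel shards nn.Embedding on the last weight dim, so its natural output is Shard(-1), and asking for Shard(1) here would force a redistribute):

    from torch.distributed._tensor import Replicate, Shard
    from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

    # model and tp_mesh as in the snippet above
    parallelize_module(
        model,
        tp_mesh,
        {
            "tok_embeddings": ColwiseParallel(
                input_layouts=Replicate(),
                # redistribute from the natural Shard(-1) to the Shard(1)
                # sequence-parallel layout the rest of the plan expects
                output_layouts=Shard(1),
            ),
        },
    )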
Can someone provide some insight?