
SPMD pre-training of LLaMA 2: why is multi-machine training so slow? #6778

Open
@mars1248

Description

SPMD trains at normal speed with eight GPUs on a single machine, but communication overhead increases rapidly in the multi-machine case.
Device:
2 machines × 8 A100 GPUs (16 GPUs total)
SPMD sharding strategy:

import torch_xla.distributed.spmd as xs

# Shard every parameter along its first dimension across a flat 1-D mesh of all devices.
for name, param in model.named_parameters():
    shape = (num_devices,) + (1,) * (len(param.shape) - 1)
    mesh = xs.Mesh(device_ids, shape)
    xs.mark_sharding(param, mesh, tuple(range(len(param.shape))))
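
For comparison, here is a minimal sketch of a 2-D mesh that keeps the sharded axis inside a single host and uses the cross-host axis for data parallelism, so parameter all-gathers do not have to cross the slower inter-machine link. This is illustrative only: the 'data'/'model' axis names, the 2 × 8 layout, and the assumption that the global device ordering groups GPUs by host are not part of the original report.

import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

# Hypothetical 2-D mesh for the 2-host x 8-GPU setup above: the 'model' axis
# (size 8) stays within one host, the 'data' axis (size 2) spans the hosts.
num_devices = xr.global_runtime_device_count()  # 16 in this setup
mesh = xs.Mesh(np.arange(num_devices), (2, 8), ('data', 'model'))

for name, param in model.named_parameters():
    # Shard only dim 0 along the intra-host 'model' axis; replicate the rest.
    partition_spec = ('model',) + (None,) * (len(param.shape) - 1)
    xs.mark_sharding(param, mesh, partition_spec)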

Profile result:
[profiler trace screenshot attached to the issue]
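
For reference, a minimal sketch of how such a trace can be captured with the torch_xla profiler; the port and log directory below are arbitrary placeholders, not values from the original report.

import torch_xla.debug.profiler as xp

# In the training process: start the profiler server once (port is arbitrary).
server = xp.start_server(9012)

# From a separate terminal while training runs, capture ~30 s of trace:
# xp.trace('localhost:9012', '/tmp/xla_profile', duration_ms=30000)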
