Description
SPMD trains at normal speed on a single machine with eight GPUs, but the communication overhead increases rapidly when training across multiple machines.
The devices are:
GPU: A100 * 8 * 2
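The SPMD strategy is:
for name, param in model.named_parameters():
    # Put every device on the first mesh axis; the remaining axes have size 1.
    shape = (num_devices,) + (1,) * (len(param.shape) - 1)
    mesh = xs.Mesh(device_ids, shape)
    # Map tensor dimension i to mesh axis i, so only dimension 0 is split.
    xs.mark_sharding(param, mesh, range(len(param.shape)))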
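To make the effect of this strategy concrete, here is a minimal pure-Python sketch of the mesh shape, partition spec, and per-device shard shape the loop produces for a given parameter. It does not call torch_xla. The helper names (`mesh_shape_for`, `partition_spec_for`, `shard_shape`), the example parameter shapes, and the reading of "8 * 2" as 16 devices in total are assumptions made for illustration.

```python
import math

def mesh_shape_for(param_shape, num_devices):
    # Same expression as in the loop: all devices on mesh axis 0.
    return (num_devices,) + (1,) * (len(param_shape) - 1)

def partition_spec_for(param_shape):
    # Same as range(len(param.shape)): tensor dim i -> mesh axis i.
    return tuple(range(len(param_shape)))

def shard_shape(param_shape, mesh_shape):
    # Approximate per-device shard: each tensor dim is divided by the size
    # of the mesh axis it maps to (rounded up for uneven splits).
    return tuple(math.ceil(d / m) for d, m in zip(param_shape, mesh_shape))

num_devices = 16  # assumption: 8 GPUs per machine on 2 machines

# Hypothetical parameter shapes, e.g. a linear weight and its bias.
for param_shape in [(4096, 1024), (4096,)]:
    mesh = mesh_shape_for(param_shape, num_devices)
    spec = partition_spec_for(param_shape)
    print(param_shape, "mesh:", mesh, "spec:", spec,
          "shard:", shard_shape(param_shape, mesh))
```

Under these assumptions every parameter is split along its first dimension across all 16 devices, so each device holds only 1/16 of every weight. Reassembling a full weight would then involve devices on both machines, which may be relevant to the cross-machine traffic described above.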
The profile result is: