Description
OpenMPI launch script (CLI):
```bash
mpirun \
    -np $TOTAL_NUM_GPUS \
    -H \$MPI_HOST_STRING \
    -x PATH \
    -bind-to none \
    -map-by slot \
    --mca pml ob1 --mca btl ^openib \
    --display-allocation \
    --display-map \
    python3 src/full_finetune_distributed.py \
    --config config_files/8B_full_distributed.yaml \
    optimizer_in_bwd=False
```
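Note that `mpirun` only exports Open MPI's own `OMPI_COMM_WORLD_*` variables for each rank; it does not set the torchrun-style `RANK`/`WORLD_SIZE` that the recipe below reads, nor `NUM_NODES` (which is a convention of this launch script, like `MPI_HOST_STRING`). A minimal sketch, assuming the standard Open MPI environment variables, of how they could be mapped before `torch.distributed` is initialized:

```python
# Sketch only: map Open MPI's per-rank environment onto the torchrun-style
# variables the recipe expects. Assumes the mpirun command shown above.
import os


def map_mpi_env() -> None:
    # Open MPI exports these for every launched process.
    os.environ.setdefault("RANK", os.environ["OMPI_COMM_WORLD_RANK"])
    os.environ.setdefault("WORLD_SIZE", os.environ["OMPI_COMM_WORLD_SIZE"])
    os.environ.setdefault("LOCAL_RANK", os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])

    # NUM_NODES is not an MPI variable; derive it if the launch script
    # did not export it explicitly.
    local_size = int(os.environ["OMPI_COMM_WORLD_LOCAL_SIZE"])
    world_size = int(os.environ["WORLD_SIZE"])
    os.environ.setdefault("NUM_NODES", str(world_size // local_size))

    # MASTER_ADDR / MASTER_PORT must also reach every rank (e.g. via -x)
    # before init_process_group is called.
```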
full_finetune_distributed.py
```python
if int(os.environ.get("NUM_NODES")) > 1:
    from torch.distributed._tensor import init_device_mesh

    # 2D mesh: dim 0 ("dp") spans the nodes, dim 1 ("tp") is WORLD_SIZE // 2,
    # i.e. the GPUs per node when running on exactly 2 nodes.
    mesh_2d = init_device_mesh(
        "cuda",
        mesh_shape=(
            int(os.environ.get("NUM_NODES")),
            int(os.environ["WORLD_SIZE"]) // 2,
        ),
        mesh_dim_names=("dp", "tp"),
    )
else:
    mesh_2d = None

training.shard_model(
    model=model,
    shard_conditions=fsdp_shard_conditions,
    cpu_offload=fsdp_cpu_offload,
    reshard_after_forward=reshard_after_forward,
    mesh=mesh_2d,
)
```
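For reference, a hedged sketch of the same mesh construction with the second dimension derived from the environment instead of the hard-coded `// 2`, so the shape also holds for node counts other than 2. This uses the public `torch.distributed.device_mesh.init_device_mesh` entry point; the dimension names follow the snippet above and are only illustrative:

```python
# Sketch: same 2D mesh, but the intra-node dimension is computed as
# GPUs per node rather than WORLD_SIZE // 2 (which only equals
# GPUs-per-node when exactly 2 nodes are used).
import os

from torch.distributed.device_mesh import init_device_mesh

num_nodes = int(os.environ.get("NUM_NODES", "1"))
world_size = int(os.environ["WORLD_SIZE"])
gpus_per_node = world_size // num_nodes
assert num_nodes * gpus_per_node == world_size, "mesh must cover every rank"

mesh_2d = init_device_mesh(
    "cuda",
    mesh_shape=(num_nodes, gpus_per_node),
    mesh_dim_names=("dp", "tp"),
)
```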
_distributed.py
```python
def shard_model(
    model: TransformerDecoder,
    shard_conditions: List[Callable[[str, nn.Module], bool]],
    *,
    cpu_offload: bool,
    reshard_after_forward: bool = True,
    mesh: Optional[DeviceMesh] = None,  # <-- Add this line
) -> None:
    ...
    if mesh is not None:  # <-- Add this line
        fsdp_kwargs["mesh"] = mesh  # <-- Add this line
```
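The patch above only shows where the new kwarg is stored; `fsdp_kwargs` is ultimately forwarded to FSDP2's `fully_shard`, which accepts a `mesh` argument. A minimal standalone sketch (not torchtune's actual code) of what passing a 2D mesh to `fully_shard` does: parameters are sharded along the second mesh dimension and replicated along the first (hybrid sharding). The model, mesh shape, and kwargs here are illustrative:

```python
# Sketch only: handing a 2D DeviceMesh to FSDP2's fully_shard.
# With a 2D mesh, fully_shard replicates across mesh dim 0 and shards
# across mesh dim 1 (hybrid sharding). Assumes 8 ranks (2 nodes x 4 GPUs).
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh

mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))

fsdp_kwargs = {"mesh": mesh_2d, "reshard_after_forward": True}
for layer in model:
    fully_shard(layer, **fsdp_kwargs)  # shard each block first
fully_shard(model, **fsdp_kwargs)      # then the root module
```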
Originally posted by @fabiogeraci in #2018 (comment)