TorchTitan e2e test on torchcomms device mesh #1847
Conversation
Summary: Composability testing with TorchComms and distributed training in TorchTitan.
- Training with `torchcomms.new_comm`
- Device mesh initialization with `torchcomms.init_device_mesh`
- Integration and testing with `fully_shard`

Differential Revision: D82171763
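For context, a minimal sketch of how those three pieces could fit together. The argument lists for `torchcomms.new_comm` and `torchcomms.init_device_mesh` are placeholders (their real signatures are not shown in this PR), so treat this as illustration only:

```python
# Illustrative sketch only: the exact torchcomms signatures are assumptions,
# not the real API; only the function names come from this PR.
import os

import torch
import torchcomms
from torch.distributed.fsdp import fully_shard


def shard_with_torchcomms(model: torch.nn.Module):
    backend = os.environ.get("TEST_BACKEND", "nccl")
    # Assumed: create a communicator for the chosen backend.
    comm = torchcomms.new_comm(backend)  # placeholder arguments
    # Assumed: build a 1-D device mesh backed by that communicator.
    mesh = torchcomms.init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
    # fully_shard (FSDP2) is then applied on the torchcomms-backed mesh.
    fully_shard(model, mesh=mesh)
    return comm, mesh
```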
Thanks for the PR! The change looks interesting!
Is this PR for exploration, or is it ready to ship to the community? If it's the former, could we start with a branch / fork instead of `experiments`?
Sorry to put a hold on this before we get more context.
OK, got clarifications offline. I think it's OK to host this experiment. To land, we'll need to:
- simplify the code, as I believe a lot of existing components could be reused
- set up a PoC for this folder (let's work together on this)
Sorry for the late reply. It's planned to ship to the community as a use case for torchcomms.
I was cleaning up the code; the main change here is the `ParallelDimsForComms` initialization.
How can I set up the PoC?
I'll do this with a PR shortly. Should we assign you as the PoC?
Yeah, please. There will be some further changes to enable other parallelisms and related tests.
It is reasonable to duplicate `ParallelDims._build_mesh_without_ep`, but `Trainer.__init__` seems to be mostly the same, and it is very long, so it is not easy to spot the difference. Can you point out what changes in `Trainer.__init__`? We can brainstorm how to further minimize the duplication.
---
#### Example
```bash
TEST_BACKEND=nccl ./run_train.sh --model.name torchcomms
```
This doesn't seem to be correct. You will at least need to specify CONFIG_FILE.
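For reference, a hedged version of the command with `CONFIG_FILE` added; the config path below follows the Llama 3 debug config used in the main README and is an assumption here, not part of this PR:

```bash
# Assumed config path (Llama 3 debug config from the main README); adjust as needed.
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" \
TEST_BACKEND=nccl \
./run_train.sh --model.name torchcomms
```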
- Training with `torchcomms.new_comm`
- Device mesh initialization with `torchcomms.init_device_mesh`
- **Composability Testing**
- Integration and testing with `fully_shard` (FSDP)
Is this FSDP2 only? I thought you also verified it with TP. cc @fduwjj
We are working on N-D now; we will update the README later.
@fegin there are still some gaps on the N-D side, so we aim to first merge this PR with 1D only. This scales down the scope of this PR, and we will have more PRs down the road.
```python
# init distributed and build meshes
dist_utils.init_distributed(
    job_config.comm,
    enable_cpu_backend=job_config.training.enable_cpu_offload,
    base_folder=job_config.job.dump_folder,
)
world_size = int(os.environ["WORLD_SIZE"])
parallelism_config = job_config.parallelism
self.parallel_dims = parallel_dims = ParallelDimsForComms(
    dp_shard=parallelism_config.data_parallel_shard_degree,
    dp_replicate=parallelism_config.data_parallel_replicate_degree,
    cp=parallelism_config.context_parallel_degree,
    tp=parallelism_config.tensor_parallel_degree,
    pp=parallelism_config.pipeline_parallel_degree,
    ep=parallelism_config.expert_parallel_degree,
    etp=parallelism_config.expert_tensor_parallel_degree,
    world_size=world_size,
)
```
iiuc, only this part of the initialization is changed. Is this correct? Or can you point out some other things you changed?
Yeah, I had some other changes before, but now that's the only change. I will try to find a way to call `ParallelDimsForComms` here while avoiding copying `Trainer.__init__`.
You can let the original `Trainer` have a class variable called `parallel_dims_cls` and use that variable in `__init__` to construct `self.parallel_dims = parallel_dims`. Then you can just create a `CommTrainer` and replace that class variable.

Another approach is to turn the following code into a method, `def create_parallel_dims(self, config) -> None:`.
```python
self.parallel_dims = parallel_dims = ParallelDimsForComms(
    dp_shard=parallelism_config.data_parallel_shard_degree,
    dp_replicate=parallelism_config.data_parallel_replicate_degree,
    cp=parallelism_config.context_parallel_degree,
    tp=parallelism_config.tensor_parallel_degree,
    pp=parallelism_config.pipeline_parallel_degree,
    ep=parallelism_config.expert_parallel_degree,
    etp=parallelism_config.expert_tensor_parallel_degree,
    world_size=world_size,
)
```
Sounds OK. I prefer the second option, as it's a bit more straightforward. Maybe it should be called `_create_parallel_dims`, since it's not supposed to be called from outside.
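A minimal sketch of what that second option could look like, using the names from this thread (`_create_parallel_dims`, `CommsTrainer`, `ParallelDimsForComms`). The import paths and the assumption that the base `Trainer.__init__` calls this hook are mine, not the PR's final code:

```python
# Sketch of the refactor discussed above; names follow this thread and are
# assumptions, not the merged implementation.
from dataclasses import dataclass

from torchtitan.distributed import ParallelDims
from torchtitan.train import Trainer


@dataclass
class ParallelDimsForComms(ParallelDims):
    """Hypothetical torchcomms-aware ParallelDims; the real version would
    override mesh construction to use torchcomms.init_device_mesh."""


class CommsTrainer(Trainer):
    parallel_dims: ParallelDimsForComms

    def _create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
        # The only hook that differs from the base Trainer (assumed to call
        # this method during __init__): swap in the torchcomms-aware
        # ParallelDims; everything else stays shared.
        return ParallelDimsForComms(
            dp_shard=parallelism_config.data_parallel_shard_degree,
            dp_replicate=parallelism_config.data_parallel_replicate_degree,
            cp=parallelism_config.context_parallel_degree,
            tp=parallelism_config.tensor_parallel_degree,
            pp=parallelism_config.pipeline_parallel_degree,
            ep=parallelism_config.expert_parallel_degree,
            etp=parallelism_config.expert_tensor_parallel_degree,
            world_size=world_size,
        )
```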
Can you run a job and paste the loss curve from TensorBoard here?
Also, we will have more convergence and perf tests down the road as follow-up PRs.
@@ -0,0 +1,20 @@
# TorchTitan & TorchComms Composability Testing

This repository provides a framework for composability testing with **TorchComms** and distributed training in **TorchTitan**. The goal is to enable flexible experimentation with distributed communication primitives and parallelism strategies in PyTorch.
This is currently in bold font and looks a bit obtrusive. Could you adjust it to use a plain font?
```diff
-This repository provides a framework for composability testing with **TorchComms** and distributed training in **TorchTitan**. The goal is to enable flexible experimentation with distributed communication primitives and parallelism strategies in PyTorch.
+This folder provides a framework for composability testing with **TorchComms** and distributed training in **TorchTitan**. The goal is to enable flexible experimentation with distributed communication primitives and parallelism strategies in PyTorch.
```
---
#### Example
mention that the command below uses Llama 3 as an example, but should work on all models.
---
#### Example
```bash
TEST_BACKEND={backend} TRAIN_FILE=torchtitan.experiments.torchcomms.train ./run_train.sh --model.name torchcomms
```
`TEST_BACKEND={backend}`
What should this be?
Users can input the backend they want to use, e.g. nccl or another backend.
It's a bit confusing here; I will change it to `TEST_BACKEND=nccl`.
Can we mention all the available backends? From the readme it's hard to tell what people should put here.
Let's mention nccl, gloo, or any other user-defined custom backend for now. Also, let's mention that a user's custom backend needs to implement the TorchComm wrapper. (We just don't mention the backend that cannot be named at this moment.)
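To make that concrete, a hedged illustration of the two backends that can be named today; only the nccl line appears in the PR's test plan, and the gloo line assumes it works the same way:

```bash
# As in the test plan (NCCL on GPU):
TEST_BACKEND=nccl TRAIN_FILE=torchtitan.experiments.torchcomms.train \
  ./run_train.sh --model.name torchcomms

# Assumed to work the same way with Gloo (e.g. CPU-only runs); any other
# custom backend would need to implement the TorchComm wrapper first.
TEST_BACKEND=gloo TRAIN_FILE=torchtitan.experiments.torchcomms.train \
  ./run_train.sh --model.name torchcomms
```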
```python
from torchtitan.models.llama3.infra.parallelize import parallelize_llama
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

register_train_spec(
```
Why do you need to register this `TrainSpec`?
---
#### Example
```bash
TEST_BACKEND={backend} TRAIN_FILE=torchtitan.experiments.torchcomms.train ./run_train.sh --model.name torchcomms
```
Let's set `CONFIG_FILE` here, too. You can refer to the examples in the main README.md.
torchtitan/train.py (Outdated)
f"(warmup {job_config.lr_scheduler.warmup_steps})" | ||
) | ||
|
||
def create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims: |
```diff
-def create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
+def _create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
```
```python
class CommsTrainer(Trainer):
    parallel_dims: ParallelDimsForComms

    def create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
```
```diff
-def create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
+def _create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
```
```python
from .parallel_dims import ParallelDimsForComms


class CommsTrainer(Trainer):
```
```diff
-class CommsTrainer(Trainer):
+class TorchCommsTrainer(Trainer):
```
```python
@dataclass
class ParallelDimsForComms(ParallelDims):
```
```diff
-class ParallelDimsForComms(ParallelDims):
+class TorchCommsParallelDims(ParallelDims):
```
shall we remove this file?
I think we can remove this file for now.
@@ -0,0 +1,22 @@
# TorchTitan & TorchComms Composability Testing

This folder provides a framework for composability testing with TorchComms and distributed training in TorchTitan. The goal is to enable flexible experimentation with distributed communication primitives and parallelism strategies in PyTorch.
Seems the font hasn't been fixed.
torchcomms (Outdated)
what's this module?
For now, we cannot mention too many details. We will add more context when it goes public. We need to merge this PR first so that the titan integration can go out with the release of torchcomms.
@mori360 let's add a TODO here to add more explanation once torchcomms goes public.
Looks like you have a lint error as well?
Thanks for doing this, looks good to me now.
- Integration and testing with `fully_shard` (FSDP)
---
### To Be Added
- Integration and testing with additional parallelism strategies (e.g., tensor, pipeline, model parallelism) other than fully_shard
can you remove model parallelism or replace it with context parallelism? Thanks
LGTM
Summary:
Composability testing with TorchComms and distributed training in TorchTitan.
- Training with `torchcomms.new_comm`
- Device mesh initialization with `torchcomms.init_device_mesh`
- Integration and testing with `fully_shard`

Differential Revision: D82171763

Test plan:
```bash
TEST_BACKEND=nccl TRAIN_FILE=torchtitan.experiments.torchcomms.train ./run_train.sh --model.name torchcomms
```
Loss curve: running 1000 steps on `llama3_8b.toml`.