
Conversation

Contributor

@mori360 commented Oct 9, 2025

Summary:
Composability testing with TorchComms and distributed training in TorchTitan.

  - Training with `torchcomms.new_comm`
  - Device mesh initialization with `torchcomms.init_device_mesh`
  - Integration and testing with `fully_shard`

Differential Revision: D82171763

Test plan:
TEST_BACKEND=nccl TRAIN_FILE=torchtitan.experiments.torchcomms.train ./run_train.sh --model.name torchcomms

Loss curve:
Running 1000 steps on llama3_8b.toml:

[Screenshot 2025-10-13 at 4:14:46 PM]


meta-codesync bot commented Oct 9, 2025

@mori360 has exported this pull request. If you are a Meta employee, you can view the originating Diff in D82171763.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 9, 2025
@fduwjj fduwjj requested review from fegin, tianyu-l and wwwjn October 9, 2025 18:36
mori360 added a commit to mori360/torchtitan that referenced this pull request Oct 9, 2025
Summary:

Composability testing with TorchComms and distributed training in TorchTitan.
  - Training with `torchcomms.new_comm`
  - Device mesh initialization with `torchcomms.init_device_mesh`
  - Integration and testing with `fully_shard`

Differential Revision: D82171763
@mori360 mori360 force-pushed the export-D82171763 branch 2 times, most recently from 36ab517 to 7c6435e Compare October 9, 2025 18:45
mori360 added a commit to mori360/torchtitan that referenced this pull request Oct 9, 2025
mori360 added a commit to mori360/torchtitan that referenced this pull request Oct 9, 2025
Contributor

@tianyu-l tianyu-l left a comment


Thanks for the PR! The change looks interesting!

Is this PR for exploration or is it ready to ship to the community? If it's the former, could we start with a branch / fork, instead of experiments?

Sorry to put a hold before we get more context.

Contributor

@tianyu-l tianyu-l left a comment


OK got clarifications offline. I think it's OK to host this experiment. To land, we'll need

  1. simplify the code, as I believe a lot of existing components could be reused
  2. set up PoC for this folder (let's work together on this)

Contributor Author

@mori360 commented Oct 9, 2025

Thanks for the PR! The change looks interesting!

Is this PR for exploration or is it ready to ship to the community? If it's the former, could we start with a branch / fork, instead of experiments?

Sorry for the late reply. It's planned to ship to the community as a use case for torchcomms.

simplify the code, as I believe a lot of existing components could be reused

I was cleaning up the code; the main changes here are

  1. init communication and device mesh in (class ParallelDimsForComms in) parallel_dims.py
  2. call ParallelDimsForComms in train.py
     For train.py, all the other parts are reused except the init
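As a rough illustration of that structure (the classes below are simplified stand-ins, not TorchTitan's actual `ParallelDims`, and the torchcomms calls named in the comments are only what this thread describes, not verified signatures), the subclass would swap out only the communicator/mesh initialization:

```python
from dataclasses import dataclass


@dataclass
class ParallelDims:
    """Simplified stand-in for torchtitan's ParallelDims."""
    dp_shard: int = 1
    tp: int = 1
    world_size: int = 1

    def build_mesh(self) -> str:
        # Baseline path: this is where the default device mesh
        # initialization would happen.
        return "default_mesh"


@dataclass
class ParallelDimsForComms(ParallelDims):
    """Everything is inherited except the communicator/mesh init."""

    def build_mesh(self) -> str:
        # TorchComms path: per the PR summary, this is where
        # torchcomms.new_comm and torchcomms.init_device_mesh
        # would be called instead (details omitted here).
        return "torchcomms_mesh"


print(ParallelDims(world_size=8).build_mesh())          # default_mesh
print(ParallelDimsForComms(world_size=8).build_mesh())  # torchcomms_mesh
```

This keeps all other trainer logic shared; only the mesh-construction path differs between the two classes.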

set up PoC for this folder

How can I set up the PoC?

mori360 added a commit to mori360/torchtitan that referenced this pull request Oct 9, 2025
@mori360 mori360 marked this pull request as draft October 9, 2025 21:25
mori360 added a commit to mori360/torchtitan that referenced this pull request Oct 9, 2025
mori360 added a commit to mori360/torchtitan that referenced this pull request Oct 9, 2025
Contributor

@tianyu-l commented Oct 9, 2025

How can I set up the PoC?

I'll do this with a PR shortly. Should we assign you as the PoC?

Contributor Author

@mori360 commented Oct 9, 2025

How can I set up the PoC?

I'll do this with a PR shortly. Should we assign you as the PoC?

Yeah, please. There will be some further changes to enable other parallelisms and related tests.

@mori360 mori360 marked this pull request as ready for review October 9, 2025 21:41
Contributor

@fegin fegin left a comment


It is reasonable to duplicate `ParallelDims._build_mesh_without_ep`, but `Trainer.__init__` seems to be mostly the same, and it is very long, so it is not easy to debug the difference. Can you point out what changes in `Trainer.__init__`? We can brainstorm how to further minimize the duplication.

---
#### Example
```bash
TEST_BACKEND=nccl ./run_train.sh --model.name torchcomms
```

Contributor

This doesn't seem to be correct. You will at least need to specify CONFIG_FILE.

- Training with `torchcomms.new_comm`
- Device mesh initialization with `torchcomms.init_device_mesh`
- **Composability Testing**
- Integration and testing with `fully_shard` (FSDP)
Contributor

Is this FSDP2 only? I thought you also verified it with TP. cc., @fduwjj

Contributor Author

We are working on N-D now; will update the README later.

Contributor

@fegin there are still some gaps on the N-D side, so we aim to first merge this PR with 1D only. This scales down the scope of this PR; we will have more PRs down the road.

Comment on lines 54 to 71

```python
# init distributed and build meshes
dist_utils.init_distributed(
    job_config.comm,
    enable_cpu_backend=job_config.training.enable_cpu_offload,
    base_folder=job_config.job.dump_folder,
)
world_size = int(os.environ["WORLD_SIZE"])
parallelism_config = job_config.parallelism
self.parallel_dims = parallel_dims = ParallelDimsForComms(
    dp_shard=parallelism_config.data_parallel_shard_degree,
    dp_replicate=parallelism_config.data_parallel_replicate_degree,
    cp=parallelism_config.context_parallel_degree,
    tp=parallelism_config.tensor_parallel_degree,
    pp=parallelism_config.pipeline_parallel_degree,
    ep=parallelism_config.expert_parallel_degree,
    etp=parallelism_config.expert_tensor_parallel_degree,
    world_size=world_size,
)
```
Contributor

IIUC, only this part of the initialization is changed. Is this correct? Or can you point out some other things you changed?

Contributor Author

Yeah, there were some other changes before, but now that's the only change.
Will try to find a way to call ParallelDimsForComms here while avoiding copying `Trainer.__init__`.

Contributor

You can let the original `Trainer` have a class variable called `parallel_dims_cls` and use that variable in `__init__` to construct `self.parallel_dims = parallel_dims`. Then you can just create a `CommsTrainer` and override that class variable.

Another approach is to make the following code a method, `def create_parallel_dims(self, config) -> None:`.

```python
self.parallel_dims = parallel_dims = ParallelDimsForComms(
    dp_shard=parallelism_config.data_parallel_shard_degree,
    dp_replicate=parallelism_config.data_parallel_replicate_degree,
    cp=parallelism_config.context_parallel_degree,
    tp=parallelism_config.tensor_parallel_degree,
    pp=parallelism_config.pipeline_parallel_degree,
    ep=parallelism_config.expert_parallel_degree,
    etp=parallelism_config.expert_tensor_parallel_degree,
    world_size=world_size,
)
```

Both approaches should work. cc., @tianyu-l @wwwjn
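Both patterns could be sketched roughly like this (minimal stand-in classes following the names used in this thread, not TorchTitan's real `Trainer` or `ParallelDims`):

```python
from dataclasses import dataclass


@dataclass
class ParallelDims:
    world_size: int = 1


@dataclass
class ParallelDimsForComms(ParallelDims):
    """Would build its mesh via torchcomms (details omitted)."""


# Approach 1: a class variable that the subclass overrides.
class Trainer:
    parallel_dims_cls = ParallelDims

    def __init__(self, world_size: int = 8):
        # __init__ stays shared; only the class it instantiates changes.
        self.parallel_dims = self.parallel_dims_cls(world_size=world_size)


class CommsTrainer(Trainer):
    parallel_dims_cls = ParallelDimsForComms


# Approach 2: a factory method that the subclass overrides.
class TrainerV2:
    def __init__(self, world_size: int = 8):
        self.parallel_dims = self._create_parallel_dims(world_size)

    def _create_parallel_dims(self, world_size: int) -> ParallelDims:
        return ParallelDims(world_size=world_size)


class CommsTrainerV2(TrainerV2):
    def _create_parallel_dims(self, world_size: int) -> ParallelDims:
        return ParallelDimsForComms(world_size=world_size)


print(type(CommsTrainer().parallel_dims).__name__)    # ParallelDimsForComms
print(type(CommsTrainerV2().parallel_dims).__name__)  # ParallelDimsForComms
```

Either way, the subclass avoids duplicating the long shared `__init__`; it only swaps which `ParallelDims` variant gets constructed.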

Contributor

Sounds OK. I prefer the second option as it sounds a bit more straightforward. Maybe we should call it `_create_parallel_dims`, as it's not supposed to be called from outside.

Contributor

@tianyu-l tianyu-l left a comment


@mori360
Once #1859 lands, please rebase and add owner to the experiment

Contributor Author

@mori360 commented Oct 13, 2025

@mori360 Once #1859 lands, please rebase and add owner to the experiment

Thanks for the reminder. Can I add @d4l3k and @fduwjj as owners as well?

Contributor

@fduwjj fduwjj left a comment


Can you run a job and paste the loss curve from TensorBoard here?

Contributor

@fduwjj fduwjj left a comment


The TorchComms integration side looks good to me; I will let @tianyu-l and @fegin decide on the Titan integration part.

Contributor

@fduwjj commented Oct 13, 2025

Also, we will have more convergence and perf tests down the road as follow-up PRs.

@@ -0,0 +1,20 @@
# TorchTitan & TorchComms Composability Testing

This repository provides a framework for composability testing with **TorchComms** and distributed training in **TorchTitan**. The goal is to enable flexible experimentation with distributed communication primitives and parallelism strategies in PyTorch.
Contributor


This is currently in bold font and looks a bit obtrusive. Could you adjust it to use plain font?

@@ -0,0 +1,20 @@
# TorchTitan & TorchComms Composability Testing

This repository provides a framework for composability testing with **TorchComms** and distributed training in **TorchTitan**. The goal is to enable flexible experimentation with distributed communication primitives and parallelism strategies in PyTorch.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change:
- This repository provides a framework for composability testing with **TorchComms** and distributed training in **TorchTitan**. The goal is to enable flexible experimentation with distributed communication primitives and parallelism strategies in PyTorch.
+ This folder provides a framework for composability testing with **TorchComms** and distributed training in **TorchTitan**. The goal is to enable flexible experimentation with distributed communication primitives and parallelism strategies in PyTorch.
---
#### Example
Contributor

Mention that the command below uses Llama 3 as an example, but it should work on all models.

---
#### Example
```bash
TEST_BACKEND={backend} TRAIN_FILE=torchtitan.experiments.torchcomms.train ./run_train.sh --model.name torchcomms
```
Contributor


TEST_BACKEND={backend}

What should this be?

Contributor Author

Users can input the backend they want to use, e.g. nccl or another backend.
It's a bit confusing here; will change it to TEST_BACKEND=nccl.


Contributor

Can we mention all the available backends? From the readme it's hard to tell what people should put here.

Contributor

Let's mention nccl, gloo, or any other user-defined customized backend for now. Also, let's mention that a user-customized backend needs to implement the TorchComms wrapper. (We just don't mention the backends which cannot be mentioned at this moment.)

```python
from torchtitan.models.llama3.infra.parallelize import parallelize_llama
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

register_train_spec(
```
Contributor

why do you need to register this TrainSpec?

---
#### Example
```bash
TEST_BACKEND={backend} TRAIN_FILE=torchtitan.experiments.torchcomms.train ./run_train.sh --model.name torchcomms
```
Contributor


Let's set CONFIG_FILE here, too. You can refer to the examples in the main README.md.
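A full invocation with CONFIG_FILE set might look like the following sketch; the .toml path is an assumption based on the Llama 3 config named in the test plan, not a path confirmed in this thread:

```shell
# Hypothetical full command combining CONFIG_FILE with the
# TEST_BACKEND/TRAIN_FILE variables shown earlier in this PR.
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" \
TEST_BACKEND=nccl \
TRAIN_FILE=torchtitan.experiments.torchcomms.train \
./run_train.sh --model.name torchcomms
```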

```python
        f"(warmup {job_config.lr_scheduler.warmup_steps})"
    )

    def create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
```
Contributor

Suggested change:
- def create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
+ def _create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:

```python
class CommsTrainer(Trainer):
    parallel_dims: ParallelDimsForComms

    def create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
```
Contributor

Suggested change:
- def create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
+ def _create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:

```python
from .parallel_dims import ParallelDimsForComms


class CommsTrainer(Trainer):
```
Contributor

Suggested change:
- class CommsTrainer(Trainer):
+ class TorchCommsTrainer(Trainer):



```python
@dataclass
class ParallelDimsForComms(ParallelDims):
```
Contributor

Suggested change:
- class ParallelDimsForComms(ParallelDims):
+ class TorchCommsParallelDims(ParallelDims):

@mori360 mori360 requested a review from tianyu-l October 14, 2025 01:43

meta-codesync bot commented Oct 14, 2025

@mori360 has imported this pull request. If you are a Meta employee, you can view this in D82171763.

Contributor

shall we remove this file?

Contributor

I think we can remove this file for now.

@@ -0,0 +1,22 @@
# TorchTitan & TorchComms Composability Testing

This folder provides a framework for composability testing with TorchComms and distributed training in TorchTitan. The goal is to enable flexible experimentation with distributed communication primitives and parallelism strategies in PyTorch.
Contributor

Seems the font hasn't been fixed.

torchcomms (Outdated)
Contributor

what's this module?

Contributor

For now, we cannot mention too many details. We will add more context when it goes public. We need to merge this PR first so that the Titan integration can go with the release of TorchComms.

@mori360 let's add a TODO here to add more explanation once TorchComms goes public.

Contributor

@fduwjj commented Oct 14, 2025

Looks like you have a lint error as well?

@mori360 mori360 requested review from fduwjj and tianyu-l October 14, 2025 02:39
Contributor

@fduwjj commented Oct 14, 2025

If you choose this:

[screenshot]

Would that help make CI happy?

Contributor

@fduwjj fduwjj left a comment


Thanks for doing this, looks good to me now.

- Integration and testing with `fully_shard` (FSDP)
---
### To Be Added
- Integration and testing with additional parallelism strategies (e.g., tensor, pipeline, model parallelism) other than fully_shard
Contributor

Can you remove model parallelism or replace it with context parallelism? Thanks.

Contributor

@tianyu-l tianyu-l left a comment


LGTM

@tianyu-l tianyu-l merged commit cd304c7 into pytorch:main Oct 14, 2025
10 of 11 checks passed