
feat: add FSDP2 multi-GPU training backend with multi-tenant LoRA support#100

Open
xuanrui-L wants to merge 3 commits into main from feature/fsdp_support

Conversation

@xuanrui-L
Collaborator

Implement a fully-sharded data-parallel (FSDP2) training backend that distributes base models across multiple GPUs while supporting dynamic multi-tenant LoRA adapter management.

Core components:

  • FSDPTrainingWorker: per-GPU worker handling model construction, FSDP2 wrapping (MixedPrecisionPolicy bfloat16/float32), and multi-tenant adapter lifecycle (create/switch/remove) with gradient isolation
  • FSDPWorkerGroup: Ray-based orchestrator that manages worker actors, distributes data shards, and aggregates results across ranks
  • FSDPTrainingBackend: async interface bridging the training controller to the FSDP worker group
  • fsdp_utils: version-aware FSDP2 imports (PyTorch 2.4+), ABC compatibility patching for HuggingFace models, device mesh creation
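To illustrate the version-aware import idea in fsdp_utils, here is a torch-free sketch of version gating. The helper names (`_parse_version`, `fsdp2_module_path`), the exact version cutoffs, and the module paths are illustrative assumptions, not the actual fsdp_utils code (which would use packaging.version.parse against torch.__version__):

```python
def _parse_version(v: str) -> tuple[int, ...]:
    # Simplified stand-in for packaging.version.parse: drop local/build
    # suffixes like "2.8.0+cu121" and compare numeric components.
    return tuple(int(p) for p in v.split("+")[0].split(".")[:3])


def fsdp2_module_path(torch_version: str) -> str:
    # Illustrative cutoffs: newer torch exposes fully_shard publicly,
    # while 2.4/2.5 shipped it under a private composable path.
    if _parse_version(torch_version) >= _parse_version("2.6"):
        return "torch.distributed.fsdp"
    elif _parse_version(torch_version) >= _parse_version("2.4"):
        return "torch.distributed._composable.fsdp"
    raise RuntimeError("FSDP2 requires PyTorch 2.4+")
```

The same comparison pattern drives both the import selection and the minimum-version check described above.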

Key design decisions:

  • Pre-wrap all decoder-layer linears with a default LoRA adapter before FSDP2 sharding so that subsequent add_adapter calls use PEFT's update_layer (in-place weight addition) instead of _replace_module (which would break FSDP2 parameter tracking)
  • LoRA parameters are regular tensors (not DTensors) and require explicit broadcast (init) and all_reduce (gradients) across ranks
  • Per-adapter optimizer instances and saved gradient buffers enable decoupled forward/backward/optim_step across tenant switches
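A plain-Python sketch of the per-adapter bookkeeping behind the last bullet: each tenant owns its own step counter and gradient buffer, so forward/backward and optim_step can be interleaved across adapter switches without leaking state. The names (`AdapterState`, `AdapterRegistry`, `saved_grads`) are illustrative, not the real worker code, and real gradients would be tensors rather than floats:

```python
from dataclasses import dataclass, field


@dataclass
class AdapterState:
    step: int = 0
    saved_grads: dict[str, float] = field(default_factory=dict)


class AdapterRegistry:
    def __init__(self) -> None:
        self._states: dict[str, AdapterState] = {}
        self.active: str | None = None

    def switch(self, lora_id: str) -> AdapterState:
        # Lazily create per-tenant state; switching never touches
        # another tenant's optimizer step or gradient buffer.
        state = self._states.setdefault(lora_id, AdapterState())
        self.active = lora_id
        return state

    def accumulate(self, lora_id: str, grads: dict[str, float]) -> None:
        # Backward passes add into the tenant's saved gradient buffer.
        buf = self.switch(lora_id).saved_grads
        for name, g in grads.items():
            buf[name] = buf.get(name, 0.0) + g

    def optim_step(self, lora_id: str) -> int:
        # A decoupled optimizer step consumes the buffered gradients.
        state = self.switch(lora_id)
        state.saved_grads.clear()
        state.step += 1
        return state.step
```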

Tests:

  • test_fsdp_training: basic loss decrease, multi-tenant isolation, checkpoint save/load, gradient isolation across adapters
  • test_fsdp_e2e: end-to-end server + backend tests covering dynamic adapter lifecycle, checkpoint resume, rapid switching, memory stability, decoupled operations, cross-rank parameter sync
  • test_fsdp_comprehensive: multi-rank adapters on one model, training state preservation across switches, sequential multi-model training, inference deployment via vLLM, GPU memory lifecycle monitoring

All test model paths are resolved from environment variables (TUFT_TEST_MODEL, FSDP_TEST_MODEL_A/B); nothing is hardcoded. Fixtures enforce that at least 2 GPUs are available and call clear_ray_state before and after each test.
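A minimal sketch of the environment-variable resolution described above. The helper name and error message are illustrative; the actual fixtures would use pytest.skip rather than raising:

```python
import os


def resolve_test_model(env_var: str = "TUFT_TEST_MODEL") -> str:
    # Resolve a test model path from the environment, with no
    # hardcoded fallback path anywhere in the test suite.
    path = os.environ.get(env_var, "").strip()
    if not path:
        # The real fixtures would pytest.skip(...) here instead.
        raise RuntimeError(f"set {env_var} to a local model path to run this test")
    return path
```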

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates a new FSDP2 multi-GPU training backend, significantly improving the system's capability to handle large language models by distributing them across multiple GPUs. It introduces sophisticated multi-tenant support for LoRA adapters, allowing for efficient and isolated training of various adapters on a shared base model. The implementation includes a layered architecture for distributed operations, careful handling of framework-specific interactions, and robust mechanisms for state management and parameter synchronization, all validated through comprehensive testing.

Highlights

  • FSDP2 Multi-GPU Training Backend: Introduced a new Fully Sharded Data Parallel (FSDP2) training backend to distribute base models across multiple GPUs, enhancing scalability for large language models.
  • Multi-Tenant LoRA Support: Implemented dynamic multi-tenant LoRA adapter management, allowing multiple LoRA adapters to be trained concurrently on a single FSDP2-sharded base model with gradient isolation.
  • Distributed Architecture: Designed a robust distributed architecture comprising FSDPTrainingWorker (per-GPU model and adapter management), FSDPWorkerGroup (Ray-based orchestration), and FSDPTrainingBackend (asynchronous interface to the training controller).
  • PEFT and FSDP2 Compatibility: Addressed compatibility challenges between PEFT and FSDP2 by pre-wrapping all decoder-layer linear modules with a default LoRA adapter before FSDP2 sharding, ensuring proper parameter tracking.
  • LoRA Parameter Synchronization: Ensured explicit broadcast and all-reduce operations for LoRA parameters across ranks, as FSDP2 does not dynamically manage these, to maintain synchronization during training.
  • Comprehensive Testing: Included extensive test suites (test_fsdp_training, test_fsdp_e2e, test_fsdp_comprehensive) covering basic loss decrease, multi-tenant isolation, checkpointing, dynamic adapter lifecycle, memory stability, and inference deployment.
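A torch-free sketch of the LoRA synchronization the highlights describe: rank 0's adapter initialization is broadcast to all ranks, and per-rank gradients are averaged via all-reduce. Real code would use torch.distributed collectives; the helper names and the dict-of-floats representation here are illustrative only:

```python
def broadcast(ranks: list[dict], src: int = 0) -> None:
    # Copy the source rank's LoRA parameters to every other rank,
    # mimicking torch.distributed.broadcast at adapter init time.
    for r, params in enumerate(ranks):
        if r != src:
            params.clear()
            params.update(ranks[src])


def all_reduce_mean(grads_per_rank: list[dict]) -> dict:
    # Average each gradient across ranks, mimicking an all_reduce
    # with ReduceOp.AVG on the (non-DTensor) LoRA gradients.
    world = len(grads_per_rank)
    keys = grads_per_rank[0].keys()
    return {k: sum(g[k] for g in grads_per_rank) / world for k in keys}
```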
Changelog
  • src/tuft/backends/__init__.py
    • Imported FSDPTrainingBackend.
    • Added FSDPTrainingBackend to the __all__ export list.
  • src/tuft/backends/base_backend.py
    • Updated backend creation logic to instantiate FSDPTrainingBackend when 'fsdp' is specified.
  • src/tuft/backends/fsdp_training_backend.py
    • Added the FSDPTrainingBackend class, implementing the BaseTrainingBackend interface for FSDP2.
  • src/tuft/backends/fsdp_training_worker.py
    • Added the FSDPTrainingWorker class, handling per-GPU FSDP2 model construction and multi-tenant adapter management.
  • src/tuft/backends/fsdp_utils.py
    • Added FSDP2 utility functions, including version-aware imports, ABC patching, and sharding logic.
  • src/tuft/backends/fsdp_worker_group.py
    • Added the FSDPWorkerGroup class, orchestrating FSDPTrainingWorker Ray actors for collective operations.
  • src/tuft/config.py
    • Extended ModelConfig with FSDP-specific training parameters (training_backend, num_gpus_per_node, num_nodes).
    • Added validation rules to prevent colocation with FSDP and ensure GPU availability for FSDP.
  • tests/test_fsdp_comprehensive.py
    • Added comprehensive tests for FSDP2 multi-GPU training scenarios, including multi-adapter training, state preservation, sequential multi-model training, vLLM inference deployment, and GPU memory monitoring.
  • tests/test_fsdp_e2e.py
    • Added end-to-end integration tests for the FSDP2 training backend, covering server and direct API interactions for various lifecycle and stress scenarios.
  • tests/test_fsdp_training.py
    • Added basic functional tests for the FSDP2 multi-GPU training backend, verifying loss decrease, multi-tenant concurrent training, checkpoint save/load, and gradient accumulation isolation.

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a comprehensive FSDP2 training backend, enabling multi-GPU and multi-node training with support for dynamic multi-tenant LoRA adapters. The implementation is robust, demonstrating a deep understanding of FSDP2 intricacies, such as the pre-wrapping of LoRA layers and manual gradient synchronization. The accompanying tests are extensive and cover a wide range of scenarios, from basic functionality to complex end-to-end and stress tests, which provides high confidence in the changes.

My review focuses on a few areas for improvement:

  • A critical bug in data handling that could cause a crash with empty input batches.
  • A couple of maintainability concerns related to the use of a private Ray API and a generic exception handler.

Overall, this is an excellent and well-engineered feature addition.

Comment on lines +76 to +91
    def forward_all(
        self,
        data: list[types.Datum],
        lora_id: str,
        loss_fn: types.LossFnType,
        loss_fn_config: dict[str, float] | None,
        backward: bool,
    ) -> types.ForwardBackwardOutput:
        actual_sizes = self._actual_shard_sizes(len(data), self.total_gpus)
        shards = self._split_data(data, self.total_gpus)
        futures = [
            w.forward.remote(shard, lora_id, loss_fn, loss_fn_config, backward)
            for w, shard in zip(self.workers, shards, strict=True)
        ]
        results: list[types.ForwardBackwardOutput] = ray.get(futures)
        return self._merge_forward_results(results, actual_sizes)
critical

The _split_data method, called by forward_all, does not handle cases where the input data list is empty. This will lead to an IndexError at line 151 when trying to access data[-1]. You should add a check at the beginning of forward_all to handle empty data gracefully.

    def forward_all(
        self,
        data: list[types.Datum],
        lora_id: str,
        loss_fn: types.LossFnType,
        loss_fn_config: dict[str, float] | None,
        backward: bool,
    ) -> types.ForwardBackwardOutput:
        if not data:
            return types.ForwardBackwardOutput(
                loss_fn_output_type=loss_fn,
                loss_fn_outputs=[],
                metrics={},
            )
        actual_sizes = self._actual_shard_sizes(len(data), self.total_gpus)
        shards = self._split_data(data, self.total_gpus)
        futures = [
            w.forward.remote(shard, lora_id, loss_fn, loss_fn_config, backward)
            for w, shard in zip(self.workers, shards, strict=True)
        ]
        results: list[types.ForwardBackwardOutput] = ray.get(futures)
        return self._merge_forward_results(results, actual_sizes)
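To see why the empty-batch guard matters, here is a plausible shard-splitting helper in the style the review describes: ceil-sized shards, with empty shards padded by repeating `data[-1]` so every worker receives input. This is an illustrative reconstruction, not the actual `_split_data`:

```python
def split_data(data: list, num_shards: int) -> list[list]:
    # Ceil division so early shards are at most one element larger.
    per_shard = max(1, -(-len(data) // num_shards))
    shards = [data[i * per_shard:(i + 1) * per_shard] for i in range(num_shards)]
    # Pad empty shards with the last element so no worker idles.
    # With data == [], data[-1] raises IndexError -- hence the guard.
    return [shard if shard else [data[-1]] for shard in shards]
```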

Comment on lines +569 to +586
try:
    stacked = torch.stack(tensors)
    loss_fn_input_dict[key] = stacked.to(device)
except Exception:
    max_shape = list(tensors[0].shape)
    for t in tensors:
        for i, s in enumerate(t.shape):
            if s > max_shape[i]:
                max_shape[i] = s
    padded_tensors = []
    for t in tensors:
        pad_width = [(0, m - s) for s, m in zip(t.shape, max_shape, strict=False)]
        pad_args: list[int] = []
        for p in reversed(pad_width):
            pad_args.extend(p)
        padded_tensors.append(torch.nn.functional.pad(t, pad_args, value=0))
    stacked = torch.stack(padded_tensors)
    loss_fn_input_dict[key] = stacked.to(device)
medium

This try...except block has two potential issues:

  1. It catches a generic Exception, which can hide unrelated bugs. It's better to catch the specific exception that torch.stack might raise, which is typically a RuntimeError for shape mismatches.
  2. The zip on line 580 uses strict=False (the default). If tensors in the batch have a different number of dimensions, this will silently truncate the shorter one, which could lead to incorrect padding. Using strict=True would be safer and would raise an error in such cases, making the issue easier to debug.
Suggested change

Before:

try:
    stacked = torch.stack(tensors)
    loss_fn_input_dict[key] = stacked.to(device)
except Exception:
    max_shape = list(tensors[0].shape)
    for t in tensors:
        for i, s in enumerate(t.shape):
            if s > max_shape[i]:
                max_shape[i] = s
    padded_tensors = []
    for t in tensors:
        pad_width = [(0, m - s) for s, m in zip(t.shape, max_shape, strict=False)]
        pad_args: list[int] = []
        for p in reversed(pad_width):
            pad_args.extend(p)
        padded_tensors.append(torch.nn.functional.pad(t, pad_args, value=0))
    stacked = torch.stack(padded_tensors)
    loss_fn_input_dict[key] = stacked.to(device)

After:

try:
    stacked = torch.stack(tensors)
    loss_fn_input_dict[key] = stacked.to(device)
except RuntimeError:  # torch.stack raises RuntimeError on shape mismatch
    max_shape = list(tensors[0].shape)
    for t in tensors:
        for i, s in enumerate(t.shape):
            if s > max_shape[i]:
                max_shape[i] = s
    padded_tensors = []
    for t in tensors:
        pad_width = [(0, m - s) for s, m in zip(t.shape, max_shape, strict=True)]
        pad_args: list[int] = []
        for p in reversed(pad_width):
            pad_args.extend(p)
        padded_tensors.append(torch.nn.functional.pad(t, pad_args, value=0))
    stacked = torch.stack(padded_tensors)
    loss_fn_input_dict[key] = stacked.to(device)

Comment on lines +248 to +257
def _get_pg_ip(pg):
    try:
        table = ray._private.state.state.placement_group_table(pg.id)
        bundles_to_node = table.get("bundles_to_node_id", {})
        if bundles_to_node:
            node_id = next(iter(bundles_to_node.values()))
            return node_info.get(node_id, "")
    except Exception:
        logger.debug("Could not resolve PG IP", exc_info=True)
    return ""
medium

This function uses a private Ray API (ray._private.state.state.placement_group_table) to get node information for sorting placement groups. While this is a clever way to achieve deterministic rank assignment, relying on private APIs is risky as they can change without notice in future Ray versions, potentially breaking this functionality. The try-except block is a good mitigation, but the code's long-term stability is still at risk. Consider adding a comment to highlight this risk and the reason for using it.

except ImportError:
    from torch.distributed._tensor import DTensor  # type: ignore[no-redef]

logger = logging.getLogger(__name__)
Collaborator:
FSDPTrainingWorker is a Ray actor; a global logger will not work inside it.


Args:
    world_size: total number of ranks.
    fsdp_size: FSDP group size. If -1 or >= world_size, use pure FSDP
Collaborator:
HSDP (fsdp_size != -1) is not used. Remove fsdp_size if it is not needed.

    )
    from torch.distributed.fsdp._fully_shard import _fully_shard as _fully_shard_module
    from torch.distributed.tensor import Shard
elif version.parse(torch.__version__) >= version.parse("2.4"):
Collaborator:
This branch can be removed, since the oldest vLLM we support (0.10.2) requires torch 2.8.0.
