feat: add FSDP2 multi-GPU training backend with multi-tenant LoRA support#100
Conversation
Implement a fully-sharded data-parallel (FSDP2) training backend that distributes base models across multiple GPUs while supporting dynamic multi-tenant LoRA adapter management.

Core components:
- FSDPTrainingWorker: per-GPU worker handling model construction, FSDP2 wrapping (MixedPrecisionPolicy bfloat16/float32), and multi-tenant adapter lifecycle (create/switch/remove) with gradient isolation
- FSDPWorkerGroup: Ray-based orchestrator that manages worker actors, distributes data shards, and aggregates results across ranks
- FSDPTrainingBackend: async interface bridging the training controller to the FSDP worker group
- fsdp_utils: version-aware FSDP2 imports (PyTorch 2.4+), ABC compatibility patching for HuggingFace models, device mesh creation

Key design decisions:
- Pre-wrap all decoder-layer linears with a default LoRA adapter before FSDP2 sharding so that subsequent add_adapter calls use PEFT's update_layer (in-place weight addition) instead of _replace_module (which would break FSDP2 parameter tracking)
- LoRA parameters are regular tensors (not DTensors) and require explicit broadcast (init) and all_reduce (gradients) across ranks
- Per-adapter optimizer instances and saved gradient buffers enable decoupled forward/backward/optim_step across tenant switches

Tests:
- test_fsdp_training: basic loss decrease, multi-tenant isolation, checkpoint save/load, gradient isolation across adapters
- test_fsdp_e2e: end-to-end server + backend tests covering dynamic adapter lifecycle, checkpoint resume, rapid switching, memory stability, decoupled operations, cross-rank parameter sync
- test_fsdp_comprehensive: multi-rank adapters on one model, training state preservation across switches, sequential multi-model training, inference deployment via vLLM, GPU memory lifecycle monitoring

All test model paths are resolved from environment variables (TUFT_TEST_MODEL, FSDP_TEST_MODEL_A/B) with no hardcoded paths. Fixtures enforce >= 2 GPU availability and perform clear_ray_state before and after each test.
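The worker group's shard distribution can be sketched as follows. This is a hypothetical scheme with illustrative names; the PR's actual `_actual_shard_sizes`/`_split_data` implementation may differ (per the review below, it reportedly indexes `data[-1]` and so does not handle empty batches):

```python
def shard_sizes(n_items: int, n_workers: int) -> list[int]:
    # Distribute n_items as evenly as possible: the first
    # (n_items % n_workers) workers each take one extra item.
    base, extra = divmod(n_items, n_workers)
    return [base + (1 if i < extra else 0) for i in range(n_workers)]

def split_data(data: list, n_workers: int) -> list[list]:
    # Slice the batch into contiguous shards matching shard_sizes.
    sizes = shard_sizes(len(data), n_workers)
    shards, start = [], 0
    for size in sizes:
        shards.append(data[start:start + size])
        start += size
    return shards
```

The per-rank sizes are also what the result-merging step needs in order to weight each rank's loss correctly when shards are uneven.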
Summary of Changes

This pull request integrates a new FSDP2 multi-GPU training backend, significantly improving the system's capability to handle large language models by distributing them across multiple GPUs. It introduces multi-tenant support for LoRA adapters, allowing efficient and isolated training of multiple adapters on a shared base model. The implementation includes a layered architecture for distributed operations, careful handling of framework-specific interactions, and robust mechanisms for state management and parameter synchronization, all validated through comprehensive testing.
Code Review
This pull request introduces a comprehensive FSDP2 training backend, enabling multi-GPU and multi-node training with support for dynamic multi-tenant LoRA adapters. The implementation is robust, demonstrating a deep understanding of FSDP2 intricacies, such as the pre-wrapping of LoRA layers and manual gradient synchronization. The accompanying tests are extensive and cover a wide range of scenarios, from basic functionality to complex end-to-end and stress tests, which provides high confidence in the changes.
My review focuses on a few areas for improvement:
- A critical bug in data handling that could cause a crash with empty input batches.
- A couple of maintainability concerns related to the use of a private Ray API and a generic exception handler.
Overall, this is an excellent and well-engineered feature addition.
```python
def forward_all(
    self,
    data: list[types.Datum],
    lora_id: str,
    loss_fn: types.LossFnType,
    loss_fn_config: dict[str, float] | None,
    backward: bool,
) -> types.ForwardBackwardOutput:
    actual_sizes = self._actual_shard_sizes(len(data), self.total_gpus)
    shards = self._split_data(data, self.total_gpus)
    futures = [
        w.forward.remote(shard, lora_id, loss_fn, loss_fn_config, backward)
        for w, shard in zip(self.workers, shards, strict=True)
    ]
    results: list[types.ForwardBackwardOutput] = ray.get(futures)
    return self._merge_forward_results(results, actual_sizes)
```
The `_split_data` method, called by `forward_all`, does not handle the case where the input `data` list is empty. This will lead to an `IndexError` at line 151 when trying to access `data[-1]`. You should add a check at the beginning of `forward_all` to handle empty data gracefully.
```python
def forward_all(
    self,
    data: list[types.Datum],
    lora_id: str,
    loss_fn: types.LossFnType,
    loss_fn_config: dict[str, float] | None,
    backward: bool,
) -> types.ForwardBackwardOutput:
    if not data:
        return types.ForwardBackwardOutput(
            loss_fn_output_type=loss_fn,
            loss_fn_outputs=[],
            metrics={},
        )
    actual_sizes = self._actual_shard_sizes(len(data), self.total_gpus)
    shards = self._split_data(data, self.total_gpus)
    futures = [
        w.forward.remote(shard, lora_id, loss_fn, loss_fn_config, backward)
        for w, shard in zip(self.workers, shards, strict=True)
    ]
    results: list[types.ForwardBackwardOutput] = ray.get(futures)
    return self._merge_forward_results(results, actual_sizes)
```

```python
try:
    stacked = torch.stack(tensors)
    loss_fn_input_dict[key] = stacked.to(device)
except Exception:
    max_shape = list(tensors[0].shape)
    for t in tensors:
        for i, s in enumerate(t.shape):
            if s > max_shape[i]:
                max_shape[i] = s
    padded_tensors = []
    for t in tensors:
        pad_width = [(0, m - s) for s, m in zip(t.shape, max_shape, strict=False)]
        pad_args: list[int] = []
        for p in reversed(pad_width):
            pad_args.extend(p)
        padded_tensors.append(torch.nn.functional.pad(t, pad_args, value=0))
    stacked = torch.stack(padded_tensors)
    loss_fn_input_dict[key] = stacked.to(device)
```
This try...except block has two potential issues:

1. It catches a generic `Exception`, which can hide unrelated bugs. It's better to catch the specific exception that `torch.stack` might raise, which is typically a `RuntimeError` for shape mismatches.
2. The `zip` on line 580 uses `strict=False` (the default). If tensors in the batch have a different number of dimensions, this will silently truncate the shorter one, which could lead to incorrect padding. Using `strict=True` would be safer and would raise an error in such cases, making the issue easier to debug.
```diff
 try:
     stacked = torch.stack(tensors)
     loss_fn_input_dict[key] = stacked.to(device)
-except Exception:
+except RuntimeError:  # torch.stack raises RuntimeError on shape mismatch
     max_shape = list(tensors[0].shape)
     for t in tensors:
         for i, s in enumerate(t.shape):
             if s > max_shape[i]:
                 max_shape[i] = s
     padded_tensors = []
     for t in tensors:
-        pad_width = [(0, m - s) for s, m in zip(t.shape, max_shape, strict=False)]
+        pad_width = [(0, m - s) for s, m in zip(t.shape, max_shape, strict=True)]
         pad_args: list[int] = []
         for p in reversed(pad_width):
             pad_args.extend(p)
         padded_tensors.append(torch.nn.functional.pad(t, pad_args, value=0))
     stacked = torch.stack(padded_tensors)
     loss_fn_input_dict[key] = stacked.to(device)
```
```python
def _get_pg_ip(pg):
    try:
        table = ray._private.state.state.placement_group_table(pg.id)
        bundles_to_node = table.get("bundles_to_node_id", {})
        if bundles_to_node:
            node_id = next(iter(bundles_to_node.values()))
            return node_info.get(node_id, "")
    except Exception:
        logger.debug("Could not resolve PG IP", exc_info=True)
    return ""
```
This function uses a private Ray API (ray._private.state.state.placement_group_table) to get node information for sorting placement groups. While this is a clever way to achieve deterministic rank assignment, relying on private APIs is risky as they can change without notice in future Ray versions, potentially breaking this functionality. The try-except block is a good mitigation, but the code's long-term stability is still at risk. Consider adding a comment to highlight this risk and the reason for using it.
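As a general defensive pattern (a sketch, not specific Ray API knowledge), the private attribute chain can be resolved with `getattr` so that a future Ray release removing or renaming it degrades to the fallback path instead of raising `AttributeError`:

```python
def resolve_attr_path(obj, path: str):
    # Walk a dotted attribute path defensively; return None if any
    # link is missing, so callers can fall back instead of crashing.
    for name in path.split("."):
        obj = getattr(obj, name, None)
        if obj is None:
            return None
    return obj
```

Applied here, one might resolve `"_private.state.state"` on the `ray` module and log a warning when it returns `None`, making the private-API dependency explicit and fail-soft.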
```python
except ImportError:
    from torch.distributed._tensor import DTensor  # type: ignore[no-redef]


logger = logging.getLogger(__name__)
```
FSDPTrainingWorker is a Ray actor; a global (module-level) logger will not work, because logging configuration from the driver does not propagate into actor processes.
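One common pattern (a hedged sketch; the function name and format are illustrative) is to build the logger inside the actor process, e.g. in the worker's `__init__`:

```python
import logging

def make_worker_logger(rank: int) -> logging.Logger:
    # Configure handlers inside the actor process itself, since
    # driver-side logging setup does not reach Ray actor processes.
    logger = logging.getLogger(f"fsdp_worker.rank{rank}")
    if not logger.handlers:  # avoid duplicate handlers on re-init
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("[%(name)s] %(levelname)s %(message)s")
        )
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger
```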
```python
Args:
    world_size: total number of ranks.
    fsdp_size: FSDP group size. If -1 or >= world_size, use pure FSDP
```
HSDP (fsdp_size != -1) is not used anywhere. Remove the fsdp_size parameter if it is not necessary.
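For context, a minimal sketch of the mesh-shape logic the `fsdp_size` parameter implies (hypothetical; the PR's actual device-mesh construction may differ):

```python
def mesh_shape(world_size: int, fsdp_size: int) -> tuple[int, ...]:
    # Pure FSDP: a single 1-D mesh over all ranks.
    if fsdp_size == -1 or fsdp_size >= world_size:
        return (world_size,)
    # HSDP: replicate across world_size // fsdp_size groups,
    # and shard within each group of fsdp_size ranks.
    if world_size % fsdp_size != 0:
        raise ValueError("world_size must be divisible by fsdp_size")
    return (world_size // fsdp_size, fsdp_size)
```

Dropping `fsdp_size` would collapse this to the 1-D pure-FSDP case only.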
```python
    )
    from torch.distributed.fsdp._fully_shard import _fully_shard as _fully_shard_module
    from torch.distributed.tensor import Shard
elif version.parse(torch.__version__) >= version.parse("2.4"):
```
This branch can be removed, since the oldest vLLM we support (0.10.2) requires torch 2.8.0
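With the floor raised to torch 2.8, the guard reduces to a single minimum-version check. A dependency-free sketch of the comparison (illustrative only; the actual code uses `packaging.version.parse`, which also handles pre-release and local-version suffixes properly):

```python
def meets_min_torch(ver: str, minimum: str = "2.8") -> bool:
    # Crude numeric-prefix comparison on dotted version strings;
    # ignores anything after a "+" (e.g. "+cu121" local suffixes).
    def key(v: str) -> tuple[int, ...]:
        parts: list[int] = []
        for p in v.split("+")[0].split("."):
            digits = "".join(ch for ch in p if ch.isdigit())
            if not digits:
                break
            parts.append(int(digits))
        return tuple(parts)
    return key(ver) >= key(minimum)
```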