[RFC] Integration of Distributed Inference into TorchChat #1376
🚀 The feature, motivation and pitch
Overview
The goal of this RFC is to discuss the integration of distributed inference into TorchChat. Distributed inference uses tensor parallelism, pipeline parallelism, or a combination of both to support larger models that do not fit on a single accelerator. With parallelization, each model shard runs in its own worker process. The processes can be spawned either at the script level (e.g. via torchrun) or from within the main script. For online use cases like chat and server, the processes need to coordinate how the user input is fetched and shared, and how this is done depends on where the processes get spawned. Synchronization points between the processes should be minimized for optimal performance.
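A minimal sketch of the input-sharing pattern, assuming that only rank 0 reads user input (from the terminal or an HTTP request) and then broadcasts it to the other ranks once per user turn. Threads stand in for worker processes so the sketch runs without torch or a GPU; Broadcaster, worker and run are hypothetical names, not TorchChat APIs. In an actual torchrun launch, torch.distributed.broadcast_object_list could take the place of Broadcaster.broadcast.

```python
import threading
from typing import Callable, List, Optional


class Broadcaster:
    """Toy stand-in for a broadcast collective among `world_size` ranks.

    Threads play the role of worker processes so the sketch runs anywhere.
    """

    def __init__(self, world_size: int) -> None:
        self._barrier = threading.Barrier(world_size)
        self._slot: List[Optional[str]] = [None]

    def broadcast(self, rank: int, obj: Optional[str], src: int = 0) -> Optional[str]:
        if rank == src:
            self._slot[0] = obj   # only the source rank writes
        self._barrier.wait()      # every rank waits until the value is published
        value = self._slot[0]
        self._barrier.wait()      # every rank has read it; safe to reuse the slot
        return value


def worker(rank: int, bus: Broadcaster, read_input: Callable[[], Optional[str]],
           received: List[List[str]]) -> None:
    while True:
        # Only rank 0 touches the user-facing input (terminal or HTTP request).
        prompt = read_input() if rank == 0 else None
        prompt = bus.broadcast(rank, prompt)  # the single sync point per user turn
        if prompt is None:                    # rank 0 ran out of input: shut down
            return
        received[rank].append(prompt)         # here each rank would run its shard


def run(prompts: List[str], world_size: int = 4) -> List[List[str]]:
    bus = Broadcaster(world_size)
    pending = iter(prompts)
    received: List[List[str]] = [[] for _ in range(world_size)]
    threads = [
        threading.Thread(target=worker,
                         args=(r, bus, lambda: next(pending, None), received))
        for r in range(world_size)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return received
```

Every rank ends up with the same sequence of prompts, while only rank 0 ever calls read_input. If the processes were instead spawned from within the main script, the main process could play the role of rank 0 in the same way.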
The design goals of the integration are:
- Support all CLI features of TorchChat (generate, chat, server)
- Minimize code duplication
- Maintain TorchChat's copy/pastebility
Alternatives
Option 1: Integrate at Model Level
While using a tensor parallel model in PyTorch is largely transparent, the current pipeline parallel API differs significantly from using a local model. This option hides distributed inference from the Generator class by implementing it inside a torchchat.model.Model derivative. The DistributedModel(torchchat.model.Model) class would implement methods like __call__() and forward() and handle the distribution to the worker processes internally.
- Pros:
  - High code reuse
  - Transparent use of the distributed model
  - Virtually no changes necessary in the main Generator and OpenAiApiGenerator
- Cons:
  - In this scenario, sampling happens in the main script, so the model's return value (the logits) needs to be transferred between processes (i.e. moved to shared GPU memory)
  - As the Generator is unaware of the parallelism, the subprocesses would need to be spawned inside the model itself, which is an awkward design
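A minimal sketch of Option 1, assuming the model is split into pipeline stages that pass activations to each other through queues. Threads stand in for worker processes and plain lists stand in for tensors; DistributedModel, Stage and greedy_next_token are hypothetical names, not TorchChat APIs. It illustrates both the appeal (the caller only sees model(x)) and the two cons listed above.

```python
import queue
import threading
from typing import Callable, List, Optional, Sequence

Activations = List[float]
Stage = Callable[[Activations], Activations]


class DistributedModel:
    """Hypothetical Option 1 wrapper that looks like a local model to its caller."""

    def __init__(self, stages: Sequence[Stage]) -> None:
        # One queue feeding each stage, plus one for the final output.
        self._queues: List["queue.Queue[Optional[Activations]]"] = [
            queue.Queue() for _ in range(len(stages) + 1)
        ]
        # The wrapper has to spawn its own workers because the caller is
        # unaware of the parallelism (the second con above).
        self._workers = [
            threading.Thread(target=self._run_stage,
                             args=(stage, self._queues[i], self._queues[i + 1]),
                             daemon=True)
            for i, stage in enumerate(stages)
        ]
        for t in self._workers:
            t.start()

    @staticmethod
    def _run_stage(stage: Stage, inbox: "queue.Queue", outbox: "queue.Queue") -> None:
        while True:
            x = inbox.get()
            if x is None:          # shutdown sentinel, forwarded to the next stage
                outbox.put(None)
                return
            outbox.put(stage(x))

    def forward(self, x: Activations) -> Activations:
        self._queues[0].put(x)
        # The last stage's output (the logits) travels back to the caller,
        # which is the cross-process transfer listed as the first con above.
        return self._queues[-1].get()

    def __call__(self, x: Activations) -> Activations:
        return self.forward(x)

    def close(self) -> None:
        self._queues[0].put(None)
        for t in self._workers:
            t.join()


def greedy_next_token(model: Callable[[Activations], Activations], x: Activations) -> int:
    # Generator-side code: the same call works for a local or a distributed
    # model, and sampling happens here, in the main script.
    logits = model(x)
    return max(range(len(logits)), key=logits.__getitem__)
```

For example, two toy stages that double and then increment their input turn [0.5, 3.0, 1.0] into [2.0, 7.0, 3.0], and greedy_next_token picks index 1 without knowing the model is split.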
Option 2: Abstract Base Class for Generator
Introduce a base class Generator that contains the common portions of the generation process, like getting and preparing the user input. LocalGenerator and DistributedGenerator are introduced to handle the specifics. The split between the base class and its derivatives can be made at multiple levels, specifically: high (Generator.generate), mid (Generator.decode_n_tokens/prefill), or low (Generator.decode_one_token/prefill).
- Pros:
  - Introduces abstraction in the generation process
  - High code reuse
  - Subprocesses for parallel workers can be created at the main script level
  - Added complexity stays mostly separate from local generation
- Cons:
  - Splitting the Generator out of the main generate.py file will hurt copy/pastebility
  - OpenAiApiGenerator (currently inherits from Generator) will require additional changes to work with distributed inference
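A minimal sketch of Option 2 with the split at the low level, assuming a toy model that maps the last token to a list of logits. The generation loop lives once in an abstract Generator; LocalGenerator and DistributedGenerator only supply prefill and decode_one_token. The class bodies, the TokenModel type and the vocabulary-sharded gather are hypothetical illustrations, not TorchChat's actual code.

```python
from abc import ABC, abstractmethod
from typing import Callable, List, Sequence

Logits = List[float]
TokenModel = Callable[[int], Logits]


class Generator(ABC):
    """Hypothetical base class split at the low level: shared loop, abstract steps."""

    def generate(self, prompt: List[int], max_new_tokens: int) -> List[int]:
        # Shared by every derivative: prompt handling, sampling and the decode loop.
        tokens = list(prompt)
        logits = self.prefill(tokens)
        for step in range(max_new_tokens):
            tokens.append(self.sample(logits))
            if step + 1 < max_new_tokens:  # skip the unused final forward pass
                logits = self.decode_one_token(tokens[-1])
        return tokens

    @staticmethod
    def sample(logits: Logits) -> int:
        return max(range(len(logits)), key=logits.__getitem__)  # greedy, for brevity

    @abstractmethod
    def prefill(self, tokens: List[int]) -> Logits: ...

    @abstractmethod
    def decode_one_token(self, token: int) -> Logits: ...


class LocalGenerator(Generator):
    def __init__(self, model: TokenModel) -> None:
        self.model = model

    def prefill(self, tokens: List[int]) -> Logits:
        return self.model(tokens[-1])  # toy model only conditions on the last token

    def decode_one_token(self, token: int) -> Logits:
        return self.model(token)


class DistributedGenerator(Generator):
    def __init__(self, shards: Sequence[TokenModel]) -> None:
        self.shards = list(shards)

    def _gather(self, token: int) -> Logits:
        # Each shard owns a slice of the vocabulary (tensor-parallel style);
        # concatenating the slices in order stands in for an all-gather.
        return [x for shard in self.shards for x in shard(token)]

    def prefill(self, tokens: List[int]) -> Logits:
        return self._gather(tokens[-1])

    def decode_one_token(self, token: int) -> Logits:
        return self._gather(token)
```

Because generate() is defined once, a local model and the same model sharded across two workers produce identical token sequences, and the abstract base cannot be instantiated on its own.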
Option 2b: Integrate at Low Level of Generator Without a Base Class
This approach skips the creation of a base class, directly derives DistributedGenerator(Generator) from the existing Generator, and adds the functionality for distributed inference in the main generate.py file.
- Pros:
  - Fully reuses the functionality of the existing Generator
  - Subprocesses for parallel workers can be created at the main script level
  - Maintains copy/pastebility
- Cons:
  - Requires some changes in generate.py
  - OpenAiApiGenerator (inherits from Generator) will require additional changes to work with distributed inference
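A minimal sketch of Option 2b, assuming a simplified concrete Generator stands in for the one in generate.py. DistributedGenerator overrides only prefill and decode_one_token and inherits generate() unchanged. The last class shows one possible, hypothetical way to address the OpenAiApiGenerator con: since OpenAiApiGenerator inherits from the local Generator, it does not pick up the distributed steps by itself, but a class combining both gets DistributedGenerator's overrides through Python's method resolution order. All classes here are illustrative, not TorchChat's actual code.

```python
from typing import Callable, Dict, List, Optional, Sequence

Logits = List[float]
Forward = Callable[[List[int]], Logits]


def argmax(logits: Logits) -> int:
    return max(range(len(logits)), key=logits.__getitem__)


class Generator:
    """Simplified stand-in for the existing concrete Generator."""

    def __init__(self, model: Optional[Forward]) -> None:
        self.model = model

    def prefill(self, tokens: List[int]) -> Logits:
        return self.model(tokens)

    def decode_one_token(self, tokens: List[int]) -> Logits:
        return self.model(tokens[-1:])  # toy model: no KV cache, last token only

    def generate(self, prompt: List[int], max_new_tokens: int) -> List[int]:
        tokens = list(prompt)
        logits = self.prefill(tokens)
        for step in range(max_new_tokens):
            tokens.append(argmax(logits))
            if step + 1 < max_new_tokens:
                logits = self.decode_one_token(tokens)
        return tokens


class DistributedGenerator(Generator):
    """Option 2b: inherit generate() as-is, override only the low-level steps."""

    def __init__(self, shards: Sequence[Forward]) -> None:
        super().__init__(model=None)  # no single local model on this path
        self.shards = list(shards)

    def _sharded(self, tokens: List[int]) -> Logits:
        # Each shard would run on its own rank; concatenating the per-shard
        # vocabulary slices stands in for an all-gather.
        return [x for shard in self.shards for x in shard(tokens)]

    def prefill(self, tokens: List[int]) -> Logits:
        return self._sharded(tokens)

    def decode_one_token(self, tokens: List[int]) -> Logits:
        return self._sharded(tokens[-1:])


class OpenAiApiGenerator(Generator):
    """Stand-in: adds an API-shaped wrapper on top of the local Generator."""

    def completion(self, prompt: List[int], max_tokens: int) -> Dict[str, List[int]]:
        return {"tokens": self.generate(prompt, max_tokens)[len(prompt):]}


class DistributedOpenAiApiGenerator(OpenAiApiGenerator, DistributedGenerator):
    """MRO: completion() from OpenAiApiGenerator, steps from DistributedGenerator."""
```

The resulting MRO is DistributedOpenAiApiGenerator, OpenAiApiGenerator, DistributedGenerator, Generator, so the combined class serves the API wrapper while decoding through the sharded path.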
cc @Jack-Khuu @byjlw @lessw2020