Support TP when using vLLM engine to run inference w/ torchtitan model definition #2165
Conversation
return parallel_dims

def create_job_config_from_vllm_config(
wait, it should be the other direction? we should construct a vllm config from torchtitan inference config
I guess the possible concern may be: to init a vLLM model one has to pass in a vLLM config -- but in RL you'd specify that from the RL orchestrator, in this case, also torchtitan. So it's even fine if we pass in a fake config and not use it (here you are doing translation back to torchtitan config, which you could've passed in from caller anyways)? What does the vLLM config do, actually?
> to init a vLLM model one has to pass in a vLLM config

Yes, this is the problem for inference right now - the vLLM Engine spins up the model, and the only config the vLLM Engine can pass into the model is VllmConfig. Currently the inference process is vLLM Engine -> our Model -> vLLM Engine -> our Model -> etc.

> So it's even fine if we pass in a fake config and not use it (here you are doing translation back to torchtitan config, which you could've passed in from caller anyways)?

Do you mean that during current inference we pass in a fake torchtitan config but don't use it, and instead use the vllm config fields directly to create the mesh, apply parallelism, etc.? I convert it back here because the interface of parallelize_fn in torchtitan requires a JobConfig object.

> What does the vLLM config do, actually?

The vLLM Engine only passes in VllmConfig for model initialization. The vLLM config is mainly used by the vLLM Engine / vLLM model runner. In our model code, we mainly need to know tp_size to apply parallelism.
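For concreteness, a minimal sketch of the kind of translation being discussed, assuming a default-constructible JobConfig exposing parallelism.tensor_parallel_degree and a VllmConfig exposing parallel_config.tensor_parallel_size (these names and import paths are assumptions, not the exact PR code):

```python
# Hypothetical sketch of the vLLM-config -> torchtitan JobConfig translation
# discussed above; field names and import path are assumptions.
from torchtitan.config_manager import JobConfig  # assumed import path

def create_job_config_from_vllm_config(vllm_config) -> JobConfig:
    job_config = JobConfig()  # assumes JobConfig is default-constructible
    # The piece parallelize_fn really needs is the TP degree, which vLLM
    # keeps on its parallel config.
    job_config.parallelism.tensor_parallel_degree = (
        vllm_config.parallel_config.tensor_parallel_size
    )
    return job_config
```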
> In our model code, we mainly need to know tp_size to apply parallelism.

If you are the one who constructs the model, you don't have to go through the vLLM config to know tp_size.

> The vLLM Engine only passes in VllmConfig for model initialization. The vLLM config is mainly used by the vLLM Engine / vLLM model runner.

That's my question -- how is the vLLM engine using this config? Since we already construct the model and apply parallelism ourselves, can we assume the vLLM config is useless, so we can pass in random garbage and it will get ignored?
If it's for constructing processes, then we only need to set tp_size correctly.
> Currently the inference process is vLLM Engine -> our Model -> vLLM Engine -> our Model -> etc.
this sounds unnecessary
> this sounds unnecessary

This interleaved control flow exists because vLLM uses continuous batching: it calls our vllm_wrapper.forward() first, then slices out only the last token of each request to compute logprobs. I don't think this behavior can be fully avoided unless we want to unwrap the current vLLM engine code and write our own.

> how is the vLLM engine using this config? Since we already construct the model and apply parallelism ourselves, can we assume the vLLM config is useless, so we can pass in random garbage and it will get ignored?

The vLLM engine internally uses VllmConfig a lot, e.g. link. But I agree that in the model/parallelism part, the only important information from vllm_config is the parallelism degrees, which we should also be able to get from the RL orchestrator. With the RL orchestrator, we can safely assume the vllm_config in TorchTitanVLLMWrapper.__init__(*, vllm_config) (passed from the vLLM engine) can be ignored.
A side note: with the RL orchestrator in torchtitan, I think we only need a one-directional conversion: torchtitan JobConfig -> VllmConfig.
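A rough sketch of that idea, treating the engine-supplied vllm_config as opaque. The job_config keyword argument here is purely illustrative of how the orchestrator-owned config would become the source of truth; it is not the actual PR constructor:

```python
import torch.nn as nn

class TorchTitanVLLMWrapper(nn.Module):
    """Sketch only: constructor shape, not the actual PR implementation."""

    def __init__(self, *, vllm_config, job_config):
        super().__init__()
        # Accepted to satisfy the engine's constructor contract, but not
        # consulted when building or parallelizing the model.
        self.vllm_config = vllm_config
        # Supplied by the RL orchestrator (torchtitan) and treated as the
        # single source of truth for model construction and parallelism.
        self.job_config = job_config
```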
In any case, we should let the torchtitan config be the single source of truth. We should probably do:
- torchtitan config -> vllm version of torchtitan model config -> calling titan parallelize functions
- torchtitan config -> vllm config (a sketch of this conversion follows below)
We should avoid:
- torchtitan config -> vllm config -> vllm version of torchtitan model config -> calling titan parallelize functions
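A minimal sketch of the one-directional torchtitan config -> vllm config mapping, assuming vllm.config.ParallelConfig accepts tensor_parallel_size / pipeline_parallel_size and that the torchtitan JobConfig exposes parallelism.tensor_parallel_degree (both are assumptions to verify against the installed versions):

```python
from vllm.config import ParallelConfig  # assumed to expose these fields

def parallel_config_from_titan(job_config) -> ParallelConfig:
    # Only the parallelism degrees need to flow from the torchtitan side;
    # everything else can keep vLLM's defaults.
    return ParallelConfig(
        tensor_parallel_size=job_config.parallelism.tensor_parallel_degree,
        pipeline_parallel_size=1,
    )
```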
) -> torch.Tensor | None:
    """Compute logits from hidden states."""
    if self.parallel_dims.tp_enabled:
        # Turn hidden_states back to DTensor
Is it possible to make this conversion to DTensor happen in PrepareModuleInput for the norm layer? This avoids mixing model parallelization with the model code itself.
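For reference, a sketch of how PrepareModuleInput could express that conversion declaratively. The module name "norm" and the exact layouts are assumptions based on the discussion below, not the PR's plan:

```python
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import PrepareModuleInput, parallelize_module

# Illustrative plan entry: annotate the plain tensor arriving at the norm
# layer as Replicate() (it is identical on every TP rank after the
# all-gather) and redistribute it to whatever layout the TP'd norm expects.
norm_input_plan = {
    "norm": PrepareModuleInput(
        input_layouts=(Replicate(),),
        desired_input_layouts=(Shard(1),),  # or Replicate(), depending on the TP plan
    ),
}
# parallelize_module(model, tp_mesh, norm_input_plan)  # applied alongside the TP plan
```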
That's a great question - technically we could do that using PrepareModuleInput. The reason I do this manually here is that I want to keep vLLM-related changes within the wrapper, and I want to keep the TP plan the same as the training TP plan.
There are 2 places in the wrapper where I do a manual conversion between DTensor and plain tensor:
- At the end of forward(): the original output of the transformer layers is Shard(1), and I manually all-gather it into a full tensor. This result is passed to the vLLM engine.
- At the beginning of compute_logits(): the input is a plain tensor passed from the vLLM engine, which is a slice of the output from 1). Since we already did the all-gather in 1), the tensor here is Replicate().
Without these 2 conversions, the output of the transformer layers would be Shard(1) and get passed into self.norms (which expects a Shard(1) DTensor as input). If we put 1) and 2) into the TP plan, it would cause an unnecessary all-gather if we apply the same plan in training. Both conversion points are sketched below.
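To make the two boundaries concrete, here is a condensed sketch; the method names and the tp_mesh argument are illustrative, not the exact wrapper code:

```python
import torch
from torch.distributed.tensor import DTensor, Replicate

class _WrapperSketch:
    """Condensed illustration of the two DTensor <-> plain-tensor boundaries."""

    def forward_tail(self, h: torch.Tensor) -> torch.Tensor:
        # 1) End of forward(): h is a DTensor with Shard(1) placement coming
        #    out of the TP'd transformer layers; the vLLM engine only handles
        #    plain tensors, so all-gather into a full local tensor.
        if isinstance(h, DTensor):
            h = h.full_tensor()
        return h  # handed back to the vLLM engine (continuous batching, slicing)

    def compute_logits_head(self, hidden_states: torch.Tensor, tp_mesh) -> DTensor:
        # 2) Start of compute_logits(): the engine passes back a plain tensor
        #    sliced from 1). Because 1) already all-gathered, every TP rank
        #    holds the same values, so the correct placement is Replicate().
        return DTensor.from_local(
            hidden_states, device_mesh=tp_mesh, placements=[Replicate()]
        )
```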
So the reason we manually all-gather to get the full tensor and later create a DTensor is that vLLM requires the function to return, and take in, plain tensors?
Yes, that's what I learned from the Llama code and the vLLM engine code.
| """ | ||
| assert hasattr( | ||
| self.model, "layers" | ||
| ), f"Model {type(self.model).__name__} must have .layers attribute" | ||
|
|
||
| tp_size = self.parallel_dims.tp |
I feel like we should check this when creating the model, not here?
Accessing tp_size here is not for checking; we pass it into the VLLMAttention initialization because we need the local number of attention heads / KV heads.
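For illustration, the head-count arithmetic that needs tp_size at construction time (a sketch, not the actual VLLMAttention code):

```python
def local_head_counts(num_heads: int, num_kv_heads: int, tp_size: int) -> tuple[int, int]:
    """Sketch: with TP, each rank owns num_heads / tp_size query heads and
    num_kv_heads / tp_size KV heads; the attention layer needs the local
    counts at construction time."""
    assert num_heads % tp_size == 0, "query heads must divide evenly across TP ranks"
    return num_heads // tp_size, max(num_kv_heads // tp_size, 1)
```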
Sorry, I meant the check below.
Oh sure! Let me move it to init.
for layer in self.model.layers.values():
    h = layer(h, rope_cache, attention_masks=None, positions=positions)

# When parallelism is applied, get full tensor before return to vLLM Engine
Comment on:
- what the placement of h is before converting
- why vLLM needs a (Replicate) full tensor
seems you didn't answer the questions.
hidden_states,
device_mesh=self.parallel_dims.get_mesh("tp"),
placements=[
    Replicate(),
comment on what happens in vllm engine, and why it should be Replicate here
same -- didn't answer the questions
As titled.
To support TP, we now keep a separate TP plan for the Qwen3 model. There are 2 main differences from the torchtitan core Qwen3 TP plan:
TODO: Add numerics check
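The two specific differences are not listed in this excerpt. For context only, a generic sketch of what a torchtitan-style per-block TP plan looks like; module names follow the usual torchtitan Llama/Qwen3 convention and the inference-specific tweaks from this PR are not reproduced:

```python
from torch.distributed.tensor import Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Generic per-block TP plan sketch in the torchtitan style; the actual
# Qwen3 inference plan in this PR differs in two places not shown here.
layer_tp_plan = {
    "attention.wq": ColwiseParallel(),
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w3": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
}
# for block in model.layers.values():
#     parallelize_module(block, tp_mesh, layer_tp_plan)
```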