
Conversation

@wwwjn (Contributor) commented Dec 19, 2025:

As titled.

To support TP, we now keep a separate TP plan for the Qwen3 model. There are 2 main differences from the torchtitan core Qwen3 TP plan:

  1. Use DTensors everywhere in the TP region
  2. Add a PrepareModuleInputOutput annotation for inner_attention (vllm.Attention()); a sketch is below

TODO: Add numerics check
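For context, a minimal sketch of what such a wrapper-side TP plan could look like. This is not the plan in this PR: the module paths, the layouts, and the exact PrepareModuleInputOutput keyword arguments below are illustrative assumptions.

```python
# Hedged sketch of a vLLM-side TP plan for the Qwen3 wrapper (not the PR's actual plan).
# Module paths, layouts, and PrepareModuleInputOutput kwargs are assumptions.
from torch.distributed.tensor import Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInputOutput,
    RowwiseParallel,
    parallelize_module,
)


def apply_wrapper_tp(transformer_block, tp_mesh):
    layer_plan = {
        # Difference 1: use_local_output=False keeps DTensors throughout the TP region.
        "attention.wq": ColwiseParallel(use_local_output=False),
        "attention.wk": ColwiseParallel(use_local_output=False),
        "attention.wv": ColwiseParallel(use_local_output=False),
        # Difference 2: vllm.Attention() works on plain tensors, so convert at its boundary.
        "attention.inner_attention": PrepareModuleInputOutput(
            input_layouts=(Shard(1), Shard(1), Shard(1)),
            desired_input_layouts=(Shard(1), Shard(1), Shard(1)),
            use_local_input=True,  # hand vllm.Attention() local torch.Tensors
            output_layouts=(Shard(1),),
            desired_output_layouts=(Shard(1),),
        ),
        "attention.wo": RowwiseParallel(output_layouts=Shard(1), use_local_output=False),
    }
    parallelize_module(transformer_block, tp_mesh, layer_plan)
```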

The meta-cla bot added the CLA Signed label on Dec 19, 2025.
    return parallel_dims


def create_job_config_from_vllm_config(
@tianyu-l (Contributor) commented Dec 19, 2025:

Wait, shouldn't it be the other direction? We should construct a vLLM config from the torchtitan inference config.

I guess the possible concern is: to init a vLLM model one has to pass in a vLLM config. But in RL you'd specify that from the RL orchestrator, which in this case is also torchtitan. So it's even fine if we pass in a fake config and not use it (here you are doing the translation back to a torchtitan config, which you could've passed in from the caller anyway)? What does the vLLM config do, actually?

@wwwjn (Contributor Author):

to init a vLLM model one has to pass in a vLLM config

Yes, this is the problem for inference now: the vLLM Engine spins up the model, and the only config the vLLM Engine can pass into the model is VllmConfig. Currently the inference process is vLLM Engine -> our Model -> vLLM Engine -> our Model -> etc.

So it's even fine if we pass in a fake config and not use it (here you are doing the translation back to a torchtitan config, which you could've passed in from the caller anyway)?

Do you mean that during current inference we pass a fake torchtitan config but don't use it, and instead use the vLLM config fields directly to create the mesh, apply parallelism, etc.? Here I convert it back because the interface of parallelize_fn in torchtitan requires a JobConfig object.

What does the vLLM config do, actually?

The vLLM Engine only passes in VllmConfig for model initialization. The vLLM config is mainly used by the vLLM Engine / vLLM model runner. In our model code, we mainly need to know tp_size to apply parallelism.
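A minimal sketch of the translation being discussed, under the assumption that only the TP degree matters on the model side. The import path for JobConfig, its default constructibility, and the field names (parallelism.tensor_parallel_degree, parallel_config.tensor_parallel_size) are assumptions, not the PR's actual code.

```python
# Hedged sketch of the VllmConfig -> JobConfig direction discussed here.
from torchtitan.config_manager import JobConfig  # assumed import path
from vllm.config import VllmConfig


def create_job_config_from_vllm_config(vllm_config: VllmConfig) -> JobConfig:
    job_config = JobConfig()  # assumes JobConfig has usable defaults
    # The model/parallelism code mainly needs the TP degree, so copy that over.
    job_config.parallelism.tensor_parallel_degree = (
        vllm_config.parallel_config.tensor_parallel_size
    )
    return job_config
```

parallelize_fn could then be called with this JobConfig, matching the torchtitan interface mentioned above.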

@tianyu-l (Contributor) commented Dec 19, 2025:

In our model code, we mainly need to know tp_size to apply parallelism.

If you are the one who constructs the model, you don't have to translate into vllm config to know tp_size.

The vLLM Engine only passes in VllmConfig for model initialization. The vLLM config is mainly used by the vLLM Engine / vLLM model runner.

That's my question: how is the vLLM engine using this config? Since we already construct the model and apply parallelism ourselves, can we assume the vLLM config is useless, and we can pass in random garbage and it will get ignored?

If it's only for constructing processes, then we just need to set tp_size correctly.

Contributor:

Currently the inference process is vLLM Engine -> our Model -> vLLM Engine -> our Model -> etc.

this sounds unnecessary

@wwwjn (Contributor Author):

this sounds unnecessary

This interleaved control flow is because vLLM uses continuous batching: it calls our vllm_wrapper.forward() first, then slices out only the last token of each request to compute logprobs. I think this behavior cannot be fully avoided unless we want to unwrap the current vLLM engine code and write our own.
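To make the interleaving concrete, a toy sketch of the control flow. ToyWrapper is a stand-in for the actual wrapper, and the slicing indices are made up; this only illustrates the engine -> model -> engine -> model pattern described above.

```python
import torch


# Toy stand-in for the wrapper, only to illustrate the engine <-> model interleaving.
class ToyWrapper:
    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Hidden states for every scheduled token in the continuous batch.
        return torch.randn(input_ids.shape[0], 8)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Logits over a tiny made-up vocabulary.
        return hidden_states @ torch.randn(8, 16)


wrapper = ToyWrapper()
input_ids = torch.arange(10)                  # 10 scheduled tokens across 3 requests
hidden_states = wrapper.forward(input_ids)    # engine call #1: full forward
last_token_indices = torch.tensor([3, 7, 9])  # assumed last-token positions per request
sliced = hidden_states[last_token_indices]    # engine slices before calling back
logits = wrapper.compute_logits(sliced)       # engine call #2: logits for last tokens only
```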

how is the vLLM engine using this config? Since we already construct the model and apply parallelism ourselves, can we assume the vLLM config is useless, and we can pass in random garbage and it will get ignored?

The vLLM engine internally uses VllmConfig a lot, e.g. link. But I agree that in the model/parallelism part, the only important information from vllm_config is the parallelism degrees, which we should also be able to get from the RL orchestrator. With the RL orchestrator, we can safely assume the vllm_config in TorchTitanVLLMWrapper.__init__(*, vllm_config) (passed from the vLLM engine) can be ignored.

A side note: with the RL orchestrator in torchtitan, I think we only need a one-directional conversion: torchtitan JobConfig -> VllmConfig.

Contributor:

In any case, we should let the torchtitan config be the single source of truth. We probably should do (a sketch of the second flow is below):

  • torchtitan config -> vllm version of torchtitan model config -> calling titan parallelize functions
  • torchtitan config -> vllm config

We should avoid

  • torchtitan config -> vllm config -> vllm version of torchtitan model config -> calling titan parallelize functions
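A hedged sketch of the "torchtitan config -> vllm config" flow from the list above. EngineArgs is vLLM's public way to build an engine config, but the specific torchtitan field names used here are assumptions.

```python
# Hedged sketch: torchtitan JobConfig as the single source of truth.
from vllm.engine.arg_utils import EngineArgs


def vllm_config_from_job_config(job_config):
    """torchtitan config -> vllm config (the only conversion direction we keep)."""
    engine_args = EngineArgs(
        model=job_config.model.hf_assets_path,  # assumed field name
        tensor_parallel_size=job_config.parallelism.tensor_parallel_degree,
    )
    return engine_args.create_engine_config()
```

The titan parallelize functions would then be called with the torchtitan config directly, never with a config reconstructed from the vLLM side.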

    ) -> torch.Tensor | None:
        """Compute logits from hidden states."""
        if self.parallel_dims.tp_enabled:
            # Turn hidden_states back to DTensor
@acisseJZhong (Contributor) commented Dec 19, 2025:

Is it possible to make this conversion to DTensor happen in prepare_module_input for the norm layer? This avoids mixing model parallelization with the model code itself.

@wwwjn (Contributor Author):

That's a great question. Technically we could do that using PrepareModuleInput. The reason I do this manually here is that I want to keep the vLLM-related changes within the wrapper, and keep the TP plan the same as the training TP plan.

There are 2 places in the wrapper where I do a manual conversion between DTensor and plain tensor:

  1. At the end of forward(): the original output of the transformer layers is Shard(1), and I manually all-gather it into a full tensor. This result is passed to the vLLM engine.
  2. At the beginning of compute_logits(): the input is a plain tensor passed from the vLLM engine, which is a slice of the output from 1). Since we already all-gathered in 1), the tensor here is Replicate().

Without these 2 conversions, the output of the transformer layers would be Shard(1) and be passed into self.norms (which expects a Shard(1) DTensor as input). If we put 1) and 2) into the TP plan, it would cause an unnecessary all-gather if we apply the same plan in training.
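For illustration, a minimal sketch of the two conversion points just described, using the standard DTensor APIs (full_tensor and from_local) in recent PyTorch. The helper functions are stand-ins for the corresponding spots in the wrapper, not the actual code.

```python
# Hedged sketch of the two DTensor <-> plain-tensor boundaries described above.
import torch
from torch.distributed.tensor import DTensor, Replicate


# 1) End of forward(): h is a Shard(1) DTensor coming out of the transformer layers;
#    all-gather it into a plain full tensor before handing it back to the vLLM engine.
def finish_forward(h: DTensor) -> torch.Tensor:
    return h.full_tensor()  # implicit all-gather over the TP mesh


# 2) Start of compute_logits(): the engine passes back a plain-tensor slice of 1),
#    which is already replicated across TP ranks, so wrap it as Replicate().
def start_compute_logits(hidden_states: torch.Tensor, tp_mesh) -> DTensor:
    return DTensor.from_local(
        hidden_states,
        device_mesh=tp_mesh,
        placements=[Replicate()],
    )
```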

Contributor:

So the reason we manually all_gather to get the full tensor and later create a DTensor is that vLLM requires the function to return, and take in, plain tensors?

@wwwjn (Contributor Author):

Yes, that's what I learned from the llama model code and the vLLM engine code.

"""
assert hasattr(
self.model, "layers"
), f"Model {type(self.model).__name__} must have .layers attribute"

tp_size = self.parallel_dims.tp
Contributor:

I feel like we should check this when creating the model, not here?

@wwwjn (Contributor Author):

Accessing tp_size here is not for checking; we pass it into the VLLMAttention initialization because we need the local number of attention heads / KV heads.
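For concreteness, a toy sketch of the head-count math being described. The concrete numbers are made up, and the claim that these local counts feed VLLMAttention comes from the comment above, not from the code.

```python
# Toy sketch of the per-rank head-count math (numbers are made up for illustration).
num_heads, num_kv_heads, tp_size = 32, 8, 4

assert num_heads % tp_size == 0 and num_kv_heads % tp_size == 0
local_num_heads = num_heads // tp_size        # 8 attention heads on this TP rank
local_num_kv_heads = num_kv_heads // tp_size  # 2 KV heads on this TP rank

# These local counts are what get passed into the VLLMAttention (vllm.Attention)
# initialization, since each rank only owns its shard of the heads.
print(local_num_heads, local_num_kv_heads)
```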

Contributor:

Sorry, I meant the check below.

@wwwjn (Contributor Author):

Oh sure! Let me move it into init.

        for layer in self.model.layers.values():
            h = layer(h, rope_cache, attention_masks=None, positions=positions)

        # When parallelism is applied, get full tensor before return to vLLM Engine
Contributor:

Please comment on (a possible sketch is below):

  • what the placement of h is before converting
  • why vLLM needs a (Replicate) full tensor
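A hedged sketch of comments that would answer these two points, based on wwwjn's explanation in the compute_logits thread above; the stated placements are taken from that explanation rather than verified against the code, and the snippet is meant to sit inside the wrapper's forward(), not to run standalone.

```python
# Sketch of the requested comments (assumptions drawn from the discussion above):
# - Before the conversion, h is a DTensor with Shard(1) placement on the TP mesh,
#   because the TP plan keeps DTensors throughout the TP region.
# - The vLLM engine operates on plain tensors and later slices out the last token
#   of each request, so it needs the full (effectively Replicate) tensor here.
if self.parallel_dims.tp_enabled:
    h = h.full_tensor()  # all-gather Shard(1) -> plain full tensor for the engine
```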

Contributor:

seems you didn't answer the questions.

                hidden_states,
                device_mesh=self.parallel_dims.get_mesh("tp"),
                placements=[
                    Replicate(),
Contributor:

Please comment on what happens in the vLLM engine, and why it should be Replicate here.

Contributor:

same -- didn't answer the questions

@wwwjn requested a review from tianyu-l on December 24, 2025 at 00:29.