Describe the bug
I'm trying to do some simple categorization with TextGeneration, but I've noticed that it always processes each batch at a fixed interval (roughly 10 minutes). This is very strange: I'm running a relatively small model on 8xH100, where throughput should peak at 10,000+ tokens/s, so a single batch (512 short requests) should take almost no time to process. Another piece of evidence is that GPU utilization sits at 0% almost all of the time, which means the GPUs are not being used properly.
So I'd like to know whether this is a bug, or whether I'm misconfiguring TextGeneration. Any help would be much appreciated.
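For reference, this is roughly how I'm watching the GPUs (a small polling sketch of mine around nvidia-smi; only the query flags are standard nvidia-smi options):

import subprocess
import time

def poll_gpu_utilization(interval_s: float = 5.0) -> None:
    # Print per-GPU utilization every few seconds while the pipeline runs.
    while True:
        out = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=index,utilization.gpu", "--format=csv,noheader"],
            text=True,
        )
        print(out.strip())
        time.sleep(interval_s)

Running this alongside the pipeline shows all eight GPUs idle at 0% for almost the entire run.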
To reproduce
Some unrelated code omitted:
def generate_category(
    model: str,
    *,
    input_mappings: dict[str, str] = {},
    llm_config: LLMConfig = SyntheticLLMConfig(),
    sampling_params: SamplingParams = SamplingParams(),
    **kwargs,
):
    template = load_template_text(main_module, "categorize")
    llm = load_distilabel_vllm()
    return TextGeneration(
        llm=llm,
        template=template,
        columns=["spec", "code"],
        input_mappings=input_mappings,
        resources=StepResources(gpus=llm_config.tensor_parallel_size),
        **kwargs,
    )

def run(
    model: str,
    input_path: Path,
    output_path: Path,
    *,
    input_mappings: dict[str, str] = {},
    split: str = "train",
    llm_config: LLMConfig = SyntheticLLMConfig(),
    sampling_params: SamplingParams = SamplingParams(),
    input_batch_size: int = 50,
):
    with Pipeline("categorize") as pipeline:
        load = LoadDataFromHub(
            repo_id=str(input_path),
            split=split,
            batch_size=input_batch_size,
        )
        generate = generate_category(
            model=model,
            input_mappings=input_mappings,
            llm_config=llm_config,
            sampling_params=sampling_params,
            input_batch_size=input_batch_size,
        )
        parse = parse_category(input_batch_size=input_batch_size)
        _ = load >> generate >> parse

    distiset = pipeline.run(use_cache=False)
    distiset.save_to_disk(output_path)

run(
    model="models/Qwen/Qwen2.5-32B-Instruct",
    split="train",
    input_mappings={"spec": "detailed_global_summary", "code": "code"},
    input_batch_size=512,
)
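To help rule out vLLM itself, this is the kind of standalone timing check I would run against the same model (a sketch; the prompt text and max_tokens are placeholders, not taken from my actual pipeline):

import time

from vllm import LLM, SamplingParams

# Load the same model directly through vLLM, bypassing distilabel.
llm = LLM(model="models/Qwen/Qwen2.5-32B-Instruct", tensor_parallel_size=8)

# 512 short prompts, mirroring the batch size that stalls in the pipeline.
prompts = ["Categorize the following code snippet: ..."] * 512
params = SamplingParams(max_tokens=128)

start = time.perf_counter()
outputs = llm.generate(prompts, params)
elapsed = time.perf_counter() - start

total_tokens = sum(len(o.outputs[0].token_ids) for o in outputs)
print(f"{total_tokens / elapsed:.0f} tokens/s over {elapsed:.1f}s")

A direct run like this is how I'd confirm the 10,000+ tokens/s figure, which makes the fixed 10-minute gap per batch inside the pipeline all the more confusing.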
Expected behavior
The processing speed of TextGeneration should match the GPUs' throughput.
Screenshots
Environment
Version info:
requires-python = ">=3.12"
dependencies = [
"distilabel[outlines,sglang,vllm]>=1.5.3",
]
distilabel = { git = "https://github.com/argilla-io/distilabel.git", branch = "develop" }
Additional context
No response
