Skip to content

Commit 907e0ff

Browse files
feat: load HF datasets in inference
1 parent a2a8f82 commit 907e0ff

File tree

5 files changed

+132
-18
lines changed

5 files changed

+132
-18
lines changed

README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,14 @@ matrix deploy_applications --applications "[{'model_name': 'meta-llama/Llama-4-M
163163
# download math-500 dataset
164164
python -m matrix.scripts.hf_dataset_to_jsonl HuggingFaceH4/MATH-500 test test.jsonl
165165

166-
# query math-500
166+
# query math-500 from local jsonl
167167
matrix inference --app_name maverick-fp8 --input_jsonls test.jsonl --output_jsonl response.jsonl --batch_size=64 \
168168
--system_prompt "Please reason step by step, and put your final answer within \boxed{}." --max_tokens 30000 --text_key problem --timeout_secs 1800
169+
170+
# or query directly from the Hugging Face dataset
171+
matrix inference --app_name maverick-fp8 --input_hf_dataset HuggingFaceH4/MATH-500 --hf_dataset_split test \
172+
--output_jsonl response.jsonl --batch_size=64 \
173+
--system_prompt "Please reason step by step, and put your final answer within \boxed{}." --max_tokens 30000 --text_key problem --timeout_secs 1800
169174
```
170175

171176
#### Input Format

matrix/app_server/app_api.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,9 @@ def inference(
406406
self,
407407
app_name: str,
408408
output_jsonl: str,
409-
input_jsonls: str,
409+
input_jsonls: str | None = None,
410+
input_hf_dataset: str | None = None,
411+
hf_dataset_split: str = "train",
410412
load_balance: bool = True,
411413
**kwargs,
412414
):
@@ -448,6 +450,8 @@ async def get_one_endpoint() -> str:
448450
input_jsonls,
449451
model=metadata["model_name"],
450452
app_name=metadata["name"],
453+
input_hf_dataset=input_hf_dataset,
454+
hf_dataset_split=hf_dataset_split,
451455
**kwargs,
452456
)
453457
)

matrix/cli.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,15 @@ def deploy_applications(
196196
yaml_config,
197197
)
198198

199-
def inference(self, app_name: str, output_jsonl: str, input_jsonls: str, **kwargs):
199+
def inference(
200+
self,
201+
app_name: str,
202+
output_jsonl: str,
203+
input_jsonls: str | None = None,
204+
input_hf_dataset: str | None = None,
205+
hf_dataset_split: str = "train",
206+
**kwargs,
207+
):
200208
"""
201209
Run batch inference using a deployed application.
202210
@@ -206,7 +214,9 @@ def inference(self, app_name: str, output_jsonl: str, input_jsonls: str, **kwarg
206214
Args:
207215
app_name (str): The name of the deployed application to use.
208216
output_jsonl (str): Path to save inference results in JSONL format.
209-
input_jsonls (str): Path to input data in JSONL format.
217+
input_jsonls (str | None): Path to input data in JSONL format.
218+
input_hf_dataset (str | None): Hugging Face dataset name to load directly.
219+
hf_dataset_split (str): Dataset split to load when using a Hugging Face dataset.
210220
**kwargs: Additional parameters for inference (e.g., temperature, max_tokens).
211221
212222
Returns:
@@ -216,6 +226,8 @@ def inference(self, app_name: str, output_jsonl: str, input_jsonls: str, **kwarg
216226
app_name,
217227
output_jsonl,
218228
input_jsonls,
229+
input_hf_dataset=input_hf_dataset,
230+
hf_dataset_split=hf_dataset_split,
219231
**kwargs,
220232
)
221233

matrix/client/query_llm.py

Lines changed: 75 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,55 @@ def get_text_length(messages: tp.List[tp.Dict[str, str]]) -> int:
136136
return data
137137

138138

139+
def load_from_hf_dataset(
    dataset_name: str,
    split: str,
    text_key: str,
    messages_key: str,
    system_prompt: str,
) -> tp.List[tp.Dict[str, tp.Any]]:
    """Load a Hugging Face dataset and convert each sample into a chat request.

    Each sample must provide either ``text_key`` (raw prompt text, converted to
    llama-instruct messages) or ``messages_key`` (a pre-built message list).
    Both keys may be dotted paths into nested dicts. A sibling "metadata" field
    (the key path with its last segment replaced by "metadata") is propagated
    when present; otherwise ``{"index": idx}`` is used as a fallback.

    Args:
        dataset_name: Hugging Face dataset name passed to ``load_dataset``.
        split: dataset split to load (e.g. "train", "test").
        text_key: dotted key path to a raw-text prompt in each sample.
        messages_key: dotted key path to a pre-built message list.
        system_prompt: when non-empty, overrides an existing system message or
            is prepended as a new one.

    Returns:
        List of ``{"metadata": ..., "messages": ...}`` dicts.

    Raises:
        ValueError: if a sample has neither ``text_key`` nor ``messages_key``.
    """
    # Imported lazily so `datasets` is only required when this code path runs.
    from datasets import load_dataset

    def get_request(key: str, data: tp.Dict[str, tp.Any]) -> tp.Optional[tp.Any]:
        # Walk a dotted key path ("a.b.c") through nested dicts; None if absent.
        current_data: tp.Any = data
        for k in key.split("."):
            if isinstance(current_data, dict) and k in current_data:
                current_data = current_data[k]
            else:
                return None
        return current_data

    def get_metadata_key(key: str) -> str:
        # Sibling key of `key`: same parent path, last segment -> "metadata".
        # (Renamed from `text_key` to avoid shadowing the outer parameter; it
        # is also called with `messages_key`.)
        parts = key.split(".")
        parts[-1] = "metadata"
        return ".".join(parts)

    dataset = load_dataset(dataset_name, split=split)
    data = []
    for idx, sample in enumerate(dataset):
        text = get_request(text_key, sample)
        if text:
            messages = convert_llama_instruct_text(text)
            metadata = get_request(get_metadata_key(text_key), sample)
        else:
            messages = get_request(messages_key, sample)
            # Explicit raise instead of `assert`: asserts are stripped under -O,
            # which would let malformed samples through silently.
            if not messages:
                raise ValueError(f"either {text_key} or {messages_key} should exist")
            metadata = get_request(get_metadata_key(messages_key), sample)

        if system_prompt:
            # Override an existing system message, otherwise prepend one.
            if messages[0]["role"] == "system":
                messages[0]["content"] = system_prompt
            else:
                messages.insert(0, {"role": "system", "content": system_prompt})

        if metadata is None:
            metadata = {"index": idx}
        data.append({"metadata": metadata, "messages": messages})
    logger.info(f"Loaded {len(data)} samples from {dataset_name} split {split}")
    return data
186+
187+
139188
def _convert_token_log_probs(token_log_probs):
140189
if not token_log_probs.token_map:
141190
return None
@@ -617,9 +666,9 @@ def batch_requests(
617666
async def main(
618667
url: tp.Union[str, tp.Callable[[], tp.Awaitable[str]]],
619668
output_file: str,
620-
input_jsonls: str,
621-
app_name: str,
622-
model: str,
669+
input_jsonls: str | None = None,
670+
app_name: str = "",
671+
model: str = "",
623672
batch_size=32,
624673
seed=42,
625674
temperature=0.7,
@@ -632,6 +681,8 @@ async def main(
632681
system_prompt="",
633682
timeout_secs=600,
634683
batch_mode=False,
684+
input_hf_dataset: str | None = None,
685+
hf_dataset_split: str = "train",
635686
) -> tp.Dict[str, int]:
636687
"""Send jsonl llama3 instruct prompt for inference and save both the request and response as jsonl.
637688
params:
@@ -640,6 +691,8 @@ async def main(
640691
input_jsonls: variable num of input jsonl files, each line is a json with two formats
641692
1. {text_key: prompt} if text_key is found, prompt is raw text
642693
2. {messages_key: Iterable[ChatCompletionMessageParam]} if messages_key is found.
694+
input_hf_dataset: name of a Hugging Face dataset to load directly.
695+
hf_dataset_split: dataset split to use when loading from Hugging Face.
643696
model: the huggingface model name or a directory.
644697
batch_size: max number of concurrent requests.
645698
seed: seed.
@@ -661,17 +714,25 @@ async def main(
661714
os.makedirs(save_dir, exist_ok=True)
662715
if os.path.exists(output_file):
663716
logger.warning(f"Output file '{output_file}' already exists, overwriting...")
664-
input_files = glob.glob(input_jsonls)
665-
if not input_files:
666-
logger.error(f"No input files found matching pattern: {input_jsonls}")
667-
return {}
668-
669-
lines = load_from_jsonl(
670-
tuple(input_files),
671-
text_key,
672-
messages_key,
673-
system_prompt=system_prompt,
674-
)
717+
if input_hf_dataset:
718+
lines = load_from_hf_dataset(
719+
input_hf_dataset,
720+
hf_dataset_split,
721+
text_key,
722+
messages_key,
723+
system_prompt=system_prompt,
724+
)
725+
else:
726+
input_files = glob.glob(input_jsonls or "")
727+
if not input_files:
728+
logger.error(f"No input files found matching pattern: {input_jsonls}")
729+
return {}
730+
lines = load_from_jsonl(
731+
tuple(input_files),
732+
text_key,
733+
messages_key,
734+
system_prompt=system_prompt,
735+
)
675736
stats = {"success": 0, "total": 0, "sum_latency": 0}
676737
if batch_mode:
677738
outputs = await batch_requests_async(
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from datasets import Dataset
8+
9+
10+
def test_load_from_hf_dataset(monkeypatch):
    """load_from_hf_dataset turns raw-text samples into chat messages,
    injecting the system prompt and an {"index": i} metadata fallback."""
    from matrix.client import query_llm

    fake_dataset = Dataset.from_dict({"problem": ["1+1", "2+2"]})

    # Stub out the network-backed loader; the patched attribute is picked up
    # because load_from_hf_dataset imports it lazily at call time.
    monkeypatch.setattr(
        "datasets.load_dataset", lambda *args, **kwargs: fake_dataset
    )

    lines = query_llm.load_from_hf_dataset(
        "dummy",
        "train",
        text_key="problem",
        messages_key="request.messages",
        system_prompt="sys",
    )

    assert len(lines) == 2
    first = lines[0]
    system_msg = first["messages"][0]
    user_msg = first["messages"][1]
    assert system_msg["role"] == "system"
    assert system_msg["content"] == "sys"
    assert user_msg["content"] == "1+1"
    assert first["metadata"]["index"] == 0

0 commit comments

Comments
 (0)