Skip to content

Commit dfde144

Browse files
feat: load HF datasets in inference (#80)
* feat: load HF datasets in inference * feat: load HF datasets in inference * fix lints * refactor: share dataset parsing
1 parent 8c43159 commit dfde144

File tree

6 files changed

+172
-58
lines changed

README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,14 @@ matrix deploy_applications --applications "[{'model_name': 'meta-llama/Llama-4-M
163163
# download math-500 dataset
164164
python -m matrix.scripts.hf_dataset_to_jsonl HuggingFaceH4/MATH-500 test test.jsonl
165165

166-
# query math-500
166+
# query math-500 from local jsonl
167167
matrix inference --app_name maverick-fp8 --input_jsonls test.jsonl --output_jsonl response.jsonl --batch_size=64 \
168168
--system_prompt "Please reason step by step, and put your final answer within \boxed{}." --max_tokens 30000 --text_key problem --timeout_secs 1800
169+
170+
# or query directly from the Hugging Face dataset
171+
matrix inference --app_name maverick-fp8 --input_hf_dataset HuggingFaceH4/MATH-500 --hf_dataset_split test \
172+
--output_jsonl response.jsonl --batch_size=64 \
173+
--system_prompt "Please reason step by step, and put your final answer within \boxed{}." --max_tokens 30000 --text_key problem --timeout_secs 1800
169174
```
170175

171176
#### Input Format

matrix/app_server/app_api.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -406,13 +406,18 @@ def inference(
406406
self,
407407
app_name: str,
408408
output_jsonl: str,
409-
input_jsonls: str,
409+
input_jsonls: str | None = None,
410+
input_hf_dataset: str | None = None,
411+
hf_dataset_split: str = "train",
410412
load_balance: bool = True,
411413
**kwargs,
412414
):
413-
"""Run LLM inference."""
415+
"""Run LLM inference.
414416
415-
from matrix.client.query_llm import main as query
417+
The input can be provided either as JSONL files via ``input_jsonls`` or
418+
fetched directly from a Hugging Face dataset using ``input_hf_dataset``
419+
and ``hf_dataset_split``.
420+
"""
416421

417422
metadata = self.get_app_metadata(app_name)
418423
assert self._cluster_info.hostname
@@ -448,13 +453,16 @@ async def get_one_endpoint() -> str:
448453
input_jsonls,
449454
model=metadata["model_name"],
450455
app_name=metadata["name"],
456+
input_hf_dataset=input_hf_dataset,
457+
hf_dataset_split=hf_dataset_split,
451458
**kwargs,
452459
)
453460
)
454461
elif app_type == "code":
455462
from matrix.client.execute_code import CodeExcutionClient
456463

457464
client = CodeExcutionClient(get_one_endpoint)
465+
assert input_jsonls is not None, "input_jsonls is required for code apps"
458466
return asyncio.run(
459467
client.execute_code(
460468
output_jsonl,
@@ -466,6 +474,7 @@ async def get_one_endpoint() -> str:
466474
from matrix.client.process_vision_data import VisionClient
467475

468476
vision_client = VisionClient(get_one_endpoint)
477+
assert input_jsonls is not None, "input_jsonls is required for vision apps"
469478
return asyncio.run(
470479
vision_client.inference(
471480
output_jsonl,

matrix/cli.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,15 @@ def deploy_applications(
196196
yaml_config,
197197
)
198198

199-
def inference(self, app_name: str, output_jsonl: str, input_jsonls: str, **kwargs):
199+
def inference(
200+
self,
201+
app_name: str,
202+
output_jsonl: str,
203+
input_jsonls: str | None = None,
204+
input_hf_dataset: str | None = None,
205+
hf_dataset_split: str = "train",
206+
**kwargs,
207+
):
200208
"""
201209
Run batch inference using a deployed application.
202210
@@ -206,7 +214,9 @@ def inference(self, app_name: str, output_jsonl: str, input_jsonls: str, **kwarg
206214
Args:
207215
app_name (str): The name of the deployed application to use.
208216
output_jsonl (str): Path to save inference results in JSONL format.
209-
input_jsonls (str): Path to input data in JSONL format.
217+
input_jsonls (str | None): Path to input data in JSONL format.
218+
input_hf_dataset (str | None): Hugging Face dataset name to load directly.
219+
hf_dataset_split (str): Dataset split to load when using a Hugging Face dataset.
210220
**kwargs: Additional parameters for inference (e.g., temperature, max_tokens).
211221
212222
Returns:
@@ -216,6 +226,8 @@ def inference(self, app_name: str, output_jsonl: str, input_jsonls: str, **kwarg
216226
app_name,
217227
output_jsonl,
218228
input_jsonls,
229+
input_hf_dataset=input_hf_dataset,
230+
hf_dataset_split=hf_dataset_split,
219231
**kwargs,
220232
)
221233

matrix/client/query_llm.py

Lines changed: 104 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -65,54 +65,69 @@ def convert_llama_instruct_text(
6565
return messages
6666

6767

68+
def _get_request(key: str, data: tp.Dict[str, tp.Any]) -> tp.Optional[tp.Any]:
69+
keys = key.split(".")
70+
current_data = data
71+
for k in keys:
72+
if isinstance(current_data, dict) and k in current_data:
73+
current_data = current_data[k]
74+
else:
75+
return None
76+
return current_data
77+
78+
79+
def _get_metadata_key(text_key: str) -> str:
80+
parts = text_key.split(".")
81+
parts[-1] = "metadata"
82+
return ".".join(parts)
83+
84+
85+
def _prepare_request(
86+
sample: tp.Dict[str, tp.Any],
87+
text_key: str,
88+
messages_key: str,
89+
system_prompt: str,
90+
default_metadata: tp.Dict[str, tp.Any],
91+
) -> tp.Dict[str, tp.Any]:
92+
text = _get_request(text_key, sample)
93+
if text:
94+
messages = convert_llama_instruct_text(text)
95+
metadata = _get_request(_get_metadata_key(text_key), sample)
96+
else:
97+
messages = _get_request(messages_key, sample) # type: ignore
98+
assert messages, f"either {text_key} or {messages_key} should exist"
99+
metadata = _get_request(_get_metadata_key(messages_key), sample)
100+
101+
if system_prompt:
102+
if messages[0]["role"] == "system":
103+
messages[0]["content"] = system_prompt
104+
else:
105+
messages.insert(0, {"role": "system", "content": system_prompt})
106+
107+
if metadata is None:
108+
metadata = default_metadata
109+
return {"metadata": metadata, "messages": messages}
110+
111+
68112
def load_from_jsonl(
69113
input_files: tp.Tuple[str, ...],
70114
text_key: str,
71115
messages_key: str,
72116
system_prompt: str,
73117
) -> tp.List[tp.Dict[str, tp.Any]]:
74118

75-
def get_request(key: str, data: tp.Dict[str, tp.Any]) -> tp.Optional[tp.Any]:
76-
keys = key.split(".")
77-
current_data = data
78-
for k in keys:
79-
if isinstance(current_data, dict) and k in current_data:
80-
current_data = current_data[k]
81-
else:
82-
return None
83-
return current_data
84-
85-
def get_metadata_key(text_key: str) -> str:
86-
parts = text_key.split(".")
87-
parts[-1] = "metadata"
88-
return ".".join(parts)
89-
90119
def load_json_line(
91-
file_name: str, line: str, line_number: int, system_prompt: str
120+
file_name: str, line: str, line_number: int
92121
) -> tp.Dict[str, tp.Any]:
93122
try:
94123
data = json.loads(line)
95-
text = get_request(text_key, data)
96-
if text:
97-
messages = convert_llama_instruct_text(text)
98-
metadata = get_request(get_metadata_key(text_key), data)
99-
else:
100-
messages = get_request(messages_key, data) # type: ignore
101-
assert messages, f"either {text_key} or {messages_key} should exist"
102-
metadata = get_request(get_metadata_key(messages_key), data)
103-
104-
if system_prompt:
105-
if messages[0]["role"] == "system":
106-
messages[0]["content"] = system_prompt
107-
else:
108-
messages.insert(0, {"role": "system", "content": system_prompt})
109-
110-
if metadata is None:
111-
metadata = {"filename": file_name, "line": line_number}
112-
return {
113-
"metadata": metadata,
114-
"messages": messages,
115-
}
124+
return _prepare_request(
125+
data,
126+
text_key,
127+
messages_key,
128+
system_prompt,
129+
{"filename": file_name, "line": line_number},
130+
)
116131
except Exception as e:
117132
raise ValueError(f"Error in line {line_number}\n{line} of {file_name}: {e}")
118133

@@ -126,7 +141,7 @@ def get_text_length(messages: tp.List[tp.Dict[str, str]]) -> int:
126141
max_length = 0
127142
num_lines = 0
128143
for num_lines, line in enumerate(f, start=1):
129-
item = load_json_line(file_name, line, num_lines, system_prompt)
144+
item = load_json_line(file_name, line, num_lines)
130145
max_length = max(get_text_length(item["messages"]), max_length)
131146
# Add metadata to the dictionary
132147
data.append(item)
@@ -136,6 +151,31 @@ def get_text_length(messages: tp.List[tp.Dict[str, str]]) -> int:
136151
return data
137152

138153

154+
def load_from_hf_dataset(
155+
dataset_name: str,
156+
split: str,
157+
text_key: str,
158+
messages_key: str,
159+
system_prompt: str,
160+
) -> tp.List[tp.Dict[str, tp.Any]]:
161+
from datasets import load_dataset
162+
163+
dataset = load_dataset(dataset_name, split=split)
164+
data = []
165+
for idx, sample in enumerate(dataset):
166+
data.append(
167+
_prepare_request(
168+
sample,
169+
text_key,
170+
messages_key,
171+
system_prompt,
172+
{"index": idx},
173+
)
174+
)
175+
logger.info(f"Loaded {len(data)} samples from {dataset_name} split {split}")
176+
return data
177+
178+
139179
def _convert_token_log_probs(token_log_probs):
140180
if not token_log_probs.token_map:
141181
return None
@@ -617,9 +657,9 @@ def batch_requests(
617657
async def main(
618658
url: tp.Union[str, tp.Callable[[], tp.Awaitable[str]]],
619659
output_file: str,
620-
input_jsonls: str,
621-
app_name: str,
622-
model: str,
660+
input_jsonls: str | None = None,
661+
app_name: str = "",
662+
model: str = "",
623663
batch_size=32,
624664
seed=42,
625665
temperature=0.7,
@@ -632,6 +672,8 @@ async def main(
632672
system_prompt="",
633673
timeout_secs=600,
634674
batch_mode=False,
675+
input_hf_dataset: str | None = None,
676+
hf_dataset_split: str = "train",
635677
) -> tp.Dict[str, int]:
636678
"""Send jsonl llama3 instruct prompt for inference and save both the request and response as jsonl.
637679
params:
@@ -640,6 +682,8 @@ async def main(
640682
input_jsonls: variable num of input jsonl files, each line is a json with two formats
641683
1. {text_key: prompt} if text_key is found, prompt is raw text
642684
2. {messages_key: Iterable[ChatCompletionMessageParam]} if messages_key is found.
685+
input_hf_dataset: name of a Hugging Face dataset to load directly.
686+
hf_dataset_split: dataset split to use when loading from Hugging Face.
643687
model: the huggingface model name or a directory.
644688
batch_size: max number of concurrent requests.
645689
seed: seed.
@@ -661,17 +705,25 @@ async def main(
661705
os.makedirs(save_dir, exist_ok=True)
662706
if os.path.exists(output_file):
663707
logger.warning(f"Output file '{output_file}' already exists, overwriting...")
664-
input_files = glob.glob(input_jsonls)
665-
if not input_files:
666-
logger.error(f"No input files found matching pattern: {input_jsonls}")
667-
return {}
668-
669-
lines = load_from_jsonl(
670-
tuple(input_files),
671-
text_key,
672-
messages_key,
673-
system_prompt=system_prompt,
674-
)
708+
if input_hf_dataset:
709+
lines = load_from_hf_dataset(
710+
input_hf_dataset,
711+
hf_dataset_split,
712+
text_key,
713+
messages_key,
714+
system_prompt=system_prompt,
715+
)
716+
else:
717+
input_files = glob.glob(input_jsonls or "")
718+
if not input_files:
719+
logger.error(f"No input files found matching pattern: {input_jsonls}")
720+
return {}
721+
lines = load_from_jsonl(
722+
tuple(input_files),
723+
text_key,
724+
messages_key,
725+
system_prompt=system_prompt,
726+
)
675727
stats = {"success": 0, "total": 0, "sum_latency": 0}
676728
if batch_mode:
677729
outputs = await batch_requests_async(

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dynamic = ["version", "description"]
1111

1212
dependencies = [
1313
"psutil",
14+
"datasets",
1415
"grpcio==1.70.0",
1516
"grpcio-tools==1.70.0",
1617
"fire",
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import pytest
8+
9+
pytest.importorskip("datasets")
10+
from datasets import Dataset
11+
12+
13+
def test_load_from_hf_dataset(monkeypatch):
14+
from matrix.client import query_llm
15+
16+
dataset = Dataset.from_dict({"problem": ["1+1", "2+2"]})
17+
18+
def mock_load_dataset(*args, **kwargs):
19+
return dataset
20+
21+
monkeypatch.setattr("datasets.load_dataset", mock_load_dataset)
22+
23+
lines = query_llm.load_from_hf_dataset(
24+
"dummy",
25+
"train",
26+
text_key="problem",
27+
messages_key="request.messages",
28+
system_prompt="sys",
29+
)
30+
31+
assert len(lines) == 2
32+
assert lines[0]["messages"][0]["role"] == "system"
33+
assert lines[0]["messages"][0]["content"] == "sys"
34+
assert lines[0]["messages"][1]["content"] == "1+1"
35+
assert lines[0]["metadata"]["index"] == 0

0 commit comments

Comments (0)