Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,14 @@ matrix deploy_applications --applications "[{'model_name': 'meta-llama/Llama-4-M
# download math-500 dataset
python -m matrix.scripts.hf_dataset_to_jsonl HuggingFaceH4/MATH-500 test test.jsonl

# query math-500
# query math-500 from local jsonl
matrix inference --app_name maverick-fp8 --input_jsonls test.jsonl --output_jsonl response.jsonl --batch_size=64 \
--system_prompt "Please reason step by step, and put your final answer within \boxed{}." --max_tokens 30000 --text_key problem --timeout_secs 1800

# or query directly from the Hugging Face dataset
matrix inference --app_name maverick-fp8 --input_hf_dataset HuggingFaceH4/MATH-500 --hf_dataset_split test \
--output_jsonl response.jsonl --batch_size=64 \
--system_prompt "Please reason step by step, and put your final answer within \boxed{}." --max_tokens 30000 --text_key problem --timeout_secs 1800
```

#### Input Format
Expand Down
15 changes: 12 additions & 3 deletions matrix/app_server/app_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,13 +406,18 @@ def inference(
self,
app_name: str,
output_jsonl: str,
input_jsonls: str,
input_jsonls: str | None = None,
input_hf_dataset: str | None = None,
hf_dataset_split: str = "train",
load_balance: bool = True,
**kwargs,
):
"""Run LLM inference."""
"""Run LLM inference.

from matrix.client.query_llm import main as query
The input can be provided either as JSONL files via ``input_jsonls`` or
fetched directly from a Hugging Face dataset using ``input_hf_dataset``
and ``hf_dataset_split``.
"""

metadata = self.get_app_metadata(app_name)
assert self._cluster_info.hostname
Expand Down Expand Up @@ -448,13 +453,16 @@ async def get_one_endpoint() -> str:
input_jsonls,
model=metadata["model_name"],
app_name=metadata["name"],
input_hf_dataset=input_hf_dataset,
hf_dataset_split=hf_dataset_split,
**kwargs,
)
)
elif app_type == "code":
from matrix.client.execute_code import CodeExcutionClient

client = CodeExcutionClient(get_one_endpoint)
assert input_jsonls is not None, "input_jsonls is required for code apps"
return asyncio.run(
client.execute_code(
output_jsonl,
Expand All @@ -466,6 +474,7 @@ async def get_one_endpoint() -> str:
from matrix.client.process_vision_data import VisionClient

vision_client = VisionClient(get_one_endpoint)
assert input_jsonls is not None, "input_jsonls is required for vision apps"
return asyncio.run(
vision_client.inference(
output_jsonl,
Expand Down
16 changes: 14 additions & 2 deletions matrix/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,15 @@ def deploy_applications(
yaml_config,
)

def inference(self, app_name: str, output_jsonl: str, input_jsonls: str, **kwargs):
def inference(
self,
app_name: str,
output_jsonl: str,
input_jsonls: str | None = None,
input_hf_dataset: str | None = None,
hf_dataset_split: str = "train",
**kwargs,
):
"""
Run batch inference using a deployed application.

Expand All @@ -206,7 +214,9 @@ def inference(self, app_name: str, output_jsonl: str, input_jsonls: str, **kwarg
Args:
app_name (str): The name of the deployed application to use.
output_jsonl (str): Path to save inference results in JSONL format.
input_jsonls (str): Path to input data in JSONL format.
input_jsonls (str | None): Path to input data in JSONL format.
input_hf_dataset (str | None): Hugging Face dataset name to load directly.
hf_dataset_split (str): Dataset split to load when using a Hugging Face dataset.
**kwargs: Additional parameters for inference (e.g., temperature, max_tokens).

Returns:
Expand All @@ -216,6 +226,8 @@ def inference(self, app_name: str, output_jsonl: str, input_jsonls: str, **kwarg
app_name,
output_jsonl,
input_jsonls,
input_hf_dataset=input_hf_dataset,
hf_dataset_split=hf_dataset_split,
**kwargs,
)

Expand Down
156 changes: 104 additions & 52 deletions matrix/client/query_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,54 +65,69 @@ def convert_llama_instruct_text(
return messages


def _get_request(key: str, data: tp.Dict[str, tp.Any]) -> tp.Optional[tp.Any]:
keys = key.split(".")
current_data = data
for k in keys:
if isinstance(current_data, dict) and k in current_data:
current_data = current_data[k]
else:
return None
return current_data


def _get_metadata_key(text_key: str) -> str:
parts = text_key.split(".")
parts[-1] = "metadata"
return ".".join(parts)


def _prepare_request(
    sample: tp.Dict[str, tp.Any],
    text_key: str,
    messages_key: str,
    system_prompt: str,
    default_metadata: tp.Dict[str, tp.Any],
) -> tp.Dict[str, tp.Any]:
    """Turn one raw sample into a ``{"metadata": ..., "messages": ...}`` request.

    The sample may carry either raw text under *text_key* (converted to chat
    messages) or a ready-made message list under *messages_key*.  Per-sample
    metadata is looked up next to whichever key was used; *default_metadata*
    is used when none is present.  A non-empty *system_prompt* replaces an
    existing leading system message or is prepended as a new one.
    """
    text = _get_request(text_key, sample)
    if text:
        source_key = text_key
        messages = convert_llama_instruct_text(text)
    else:
        source_key = messages_key
        messages = _get_request(messages_key, sample)  # type: ignore
        assert messages, f"either {text_key} or {messages_key} should exist"
    metadata = _get_request(_get_metadata_key(source_key), sample)

    if system_prompt:
        first = messages[0]
        if first["role"] == "system":
            first["content"] = system_prompt
        else:
            messages.insert(0, {"role": "system", "content": system_prompt})

    return {
        "metadata": default_metadata if metadata is None else metadata,
        "messages": messages,
    }


def load_from_jsonl(
input_files: tp.Tuple[str, ...],
text_key: str,
messages_key: str,
system_prompt: str,
) -> tp.List[tp.Dict[str, tp.Any]]:

def get_request(key: str, data: tp.Dict[str, tp.Any]) -> tp.Optional[tp.Any]:
keys = key.split(".")
current_data = data
for k in keys:
if isinstance(current_data, dict) and k in current_data:
current_data = current_data[k]
else:
return None
return current_data

def get_metadata_key(text_key: str) -> str:
parts = text_key.split(".")
parts[-1] = "metadata"
return ".".join(parts)

def load_json_line(
file_name: str, line: str, line_number: int, system_prompt: str
file_name: str, line: str, line_number: int
) -> tp.Dict[str, tp.Any]:
try:
data = json.loads(line)
text = get_request(text_key, data)
if text:
messages = convert_llama_instruct_text(text)
metadata = get_request(get_metadata_key(text_key), data)
else:
messages = get_request(messages_key, data) # type: ignore
assert messages, f"either {text_key} or {messages_key} should exist"
metadata = get_request(get_metadata_key(messages_key), data)

if system_prompt:
if messages[0]["role"] == "system":
messages[0]["content"] = system_prompt
else:
messages.insert(0, {"role": "system", "content": system_prompt})

if metadata is None:
metadata = {"filename": file_name, "line": line_number}
return {
"metadata": metadata,
"messages": messages,
}
return _prepare_request(
data,
text_key,
messages_key,
system_prompt,
{"filename": file_name, "line": line_number},
)
except Exception as e:
raise ValueError(f"Error in line {line_number}\n{line} of {file_name}: {e}")

Expand All @@ -126,7 +141,7 @@ def get_text_length(messages: tp.List[tp.Dict[str, str]]) -> int:
max_length = 0
num_lines = 0
for num_lines, line in enumerate(f, start=1):
item = load_json_line(file_name, line, num_lines, system_prompt)
item = load_json_line(file_name, line, num_lines)
max_length = max(get_text_length(item["messages"]), max_length)
# Add metadata to the dictionary
data.append(item)
Expand All @@ -136,6 +151,31 @@ def get_text_length(messages: tp.List[tp.Dict[str, str]]) -> int:
return data


def load_from_hf_dataset(
    dataset_name: str,
    split: str,
    text_key: str,
    messages_key: str,
    system_prompt: str,
) -> tp.List[tp.Dict[str, tp.Any]]:
    """Load *split* of a Hugging Face dataset and convert rows into request dicts.

    Each row becomes a ``{"metadata": ..., "messages": ...}`` entry; rows
    without their own metadata fall back to ``{"index": row_index}``.
    """
    # Imported lazily so `datasets` is only required when this path is used.
    from datasets import load_dataset

    dataset = load_dataset(dataset_name, split=split)
    data = [
        _prepare_request(
            sample,
            text_key,
            messages_key,
            system_prompt,
            {"index": idx},
        )
        for idx, sample in enumerate(dataset)
    ]
    logger.info(f"Loaded {len(data)} samples from {dataset_name} split {split}")
    return data


def _convert_token_log_probs(token_log_probs):
if not token_log_probs.token_map:
return None
Expand Down Expand Up @@ -617,9 +657,9 @@ def batch_requests(
async def main(
url: tp.Union[str, tp.Callable[[], tp.Awaitable[str]]],
output_file: str,
input_jsonls: str,
app_name: str,
model: str,
input_jsonls: str | None = None,
app_name: str = "",
model: str = "",
batch_size=32,
seed=42,
temperature=0.7,
Expand All @@ -632,6 +672,8 @@ async def main(
system_prompt="",
timeout_secs=600,
batch_mode=False,
input_hf_dataset: str | None = None,
hf_dataset_split: str = "train",
) -> tp.Dict[str, int]:
"""Send jsonl llama3 instruct prompt for inference and save both the request and response as jsonl.
params:
Expand All @@ -640,6 +682,8 @@ async def main(
input_jsonls: variable num of input jsonl files, each line is a json with two formats
1. {text_key: prompt} if text_key is found, prompt is raw text
2. {messages_key: Iterable[ChatCompletionMessageParam]} if messages_key is found.
input_hf_dataset: name of a Hugging Face dataset to load directly.
hf_dataset_split: dataset split to use when loading from Hugging Face.
model: the huggingface model name or a directory.
batch_size: max number of concurrent requests.
seed: seed.
Expand All @@ -661,17 +705,25 @@ async def main(
os.makedirs(save_dir, exist_ok=True)
if os.path.exists(output_file):
logger.warning(f"Output file '{output_file}' already exists, overwriting...")
input_files = glob.glob(input_jsonls)
if not input_files:
logger.error(f"No input files found matching pattern: {input_jsonls}")
return {}

lines = load_from_jsonl(
tuple(input_files),
text_key,
messages_key,
system_prompt=system_prompt,
)
if input_hf_dataset:
lines = load_from_hf_dataset(
input_hf_dataset,
hf_dataset_split,
text_key,
messages_key,
system_prompt=system_prompt,
)
else:
input_files = glob.glob(input_jsonls or "")
if not input_files:
logger.error(f"No input files found matching pattern: {input_jsonls}")
return {}
lines = load_from_jsonl(
tuple(input_files),
text_key,
messages_key,
system_prompt=system_prompt,
)
stats = {"success": 0, "total": 0, "sum_latency": 0}
if batch_mode:
outputs = await batch_requests_async(
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dynamic = ["version", "description"]

dependencies = [
"psutil",
"datasets",
"grpcio==1.70.0",
"grpcio-tools==1.70.0",
"fire",
Expand Down
35 changes: 35 additions & 0 deletions tests/unit/query/test_load_from_hf_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import pytest

pytest.importorskip("datasets")
from datasets import Dataset


def test_load_from_hf_dataset(monkeypatch):
    """load_from_hf_dataset should apply the system prompt and index metadata."""
    from matrix.client import query_llm

    fake_dataset = Dataset.from_dict({"problem": ["1+1", "2+2"]})

    # Intercept the Hugging Face loader so no network access happens.
    monkeypatch.setattr(
        "datasets.load_dataset", lambda *args, **kwargs: fake_dataset
    )

    lines = query_llm.load_from_hf_dataset(
        "dummy",
        "train",
        text_key="problem",
        messages_key="request.messages",
        system_prompt="sys",
    )

    assert len(lines) == 2
    first = lines[0]
    assert first["messages"][0]["role"] == "system"
    assert first["messages"][0]["content"] == "sys"
    assert first["messages"][1]["content"] == "1+1"
    assert first["metadata"]["index"] == 0