diff --git a/README.md b/README.md
index f91a1e7..0b5a9a6 100644
--- a/README.md
+++ b/README.md
@@ -163,9 +163,14 @@ matrix deploy_applications --applications "[{'model_name': 'meta-llama/Llama-4-M
 # download math-500 dataset
 python -m matrix.scripts.hf_dataset_to_jsonl HuggingFaceH4/MATH-500 test test.jsonl
 
-# query math-500
+# query math-500 from local jsonl
 matrix inference --app_name maverick-fp8 --input_jsonls test.jsonl --output_jsonl response.jsonl --batch_size=64 \
   --system_prompt "Please reason step by step, and put your final answer within \boxed{}." --max_tokens 30000 --text_key problem --timeout_secs 1800
+
+# or query directly from the Hugging Face dataset
+matrix inference --app_name maverick-fp8 --input_hf_dataset HuggingFaceH4/MATH-500 --hf_dataset_split test \
+  --output_jsonl response.jsonl --batch_size=64 \
+  --system_prompt "Please reason step by step, and put your final answer within \boxed{}." --max_tokens 30000 --text_key problem --timeout_secs 1800
 ```
 
 #### Input Format
diff --git a/matrix/app_server/app_api.py b/matrix/app_server/app_api.py
index f2e80c3..7c75e9e 100644
--- a/matrix/app_server/app_api.py
+++ b/matrix/app_server/app_api.py
@@ -406,13 +413,20 @@ def inference(
         self,
         app_name: str,
         output_jsonl: str,
-        input_jsonls: str,
+        input_jsonls: str | None = None,
+        input_hf_dataset: str | None = None,
+        hf_dataset_split: str = "train",
         load_balance: bool = True,
         **kwargs,
     ):
-        """Run LLM inference."""
-
-        from matrix.client.query_llm import main as query
+        """Run LLM inference.
+
+        The input can be provided either as JSONL files via ``input_jsonls`` or
+        fetched directly from a Hugging Face dataset using ``input_hf_dataset``
+        and ``hf_dataset_split``.
+        """
+
+        from matrix.client.query_llm import main as query
 
         metadata = self.get_app_metadata(app_name)
         assert self._cluster_info.hostname
@@ -448,6 +455,8 @@ async def get_one_endpoint() -> str:
                 input_jsonls,
                 model=metadata["model_name"],
                 app_name=metadata["name"],
+                input_hf_dataset=input_hf_dataset,
+                hf_dataset_split=hf_dataset_split,
                 **kwargs,
             )
         )
@@ -455,6 +464,7 @@ async def get_one_endpoint() -> str:
         from matrix.client.execute_code import CodeExcutionClient
 
         client = CodeExcutionClient(get_one_endpoint)
+        assert input_jsonls is not None, "input_jsonls is required for code apps"
         return asyncio.run(
             client.execute_code(
                 output_jsonl,
@@ -466,6 +476,7 @@ async def get_one_endpoint() -> str:
         from matrix.client.process_vision_data import VisionClient
 
         vision_client = VisionClient(get_one_endpoint)
+        assert input_jsonls is not None, "input_jsonls is required for vision apps"
         return asyncio.run(
             vision_client.inference(
                 output_jsonl,
diff --git a/matrix/cli.py b/matrix/cli.py
index 1c8128a..001c3e5 100644
--- a/matrix/cli.py
+++ b/matrix/cli.py
@@ -196,7 +196,15 @@ def deploy_applications(
             yaml_config,
         )
 
-    def inference(self, app_name: str, output_jsonl: str, input_jsonls: str, **kwargs):
+    def inference(
+        self,
+        app_name: str,
+        output_jsonl: str,
+        input_jsonls: str | None = None,
+        input_hf_dataset: str | None = None,
+        hf_dataset_split: str = "train",
+        **kwargs,
+    ):
         """
         Run batch inference using a deployed application.
 
@@ -206,7 +214,9 @@ def inference(self, app_name: str, output_jsonl: str, input_jsonls: str, **kwarg
         Args:
             app_name (str): The name of the deployed application to use.
             output_jsonl (str): Path to save inference results in JSONL format.
-            input_jsonls (str): Path to input data in JSONL format.
+            input_jsonls (str | None): Path to input data in JSONL format.
+            input_hf_dataset (str | None): Hugging Face dataset name to load directly.
+            hf_dataset_split (str): Dataset split to load when using a Hugging Face dataset.
             **kwargs: Additional parameters for inference (e.g., temperature, max_tokens).
 
         Returns:
@@ -216,6 +226,8 @@ def inference(self, app_name: str, output_jsonl: str, input_jsonls: str, **kwarg
             app_name,
             output_jsonl,
             input_jsonls,
+            input_hf_dataset=input_hf_dataset,
+            hf_dataset_split=hf_dataset_split,
             **kwargs,
         )
 
diff --git a/matrix/client/query_llm.py b/matrix/client/query_llm.py
index 0d8dc88..66da9c4 100644
--- a/matrix/client/query_llm.py
+++ b/matrix/client/query_llm.py
@@ -65,6 +65,50 @@ def convert_llama_instruct_text(
     return messages
 
 
+def _get_request(key: str, data: tp.Dict[str, tp.Any]) -> tp.Optional[tp.Any]:
+    keys = key.split(".")
+    current_data = data
+    for k in keys:
+        if isinstance(current_data, dict) and k in current_data:
+            current_data = current_data[k]
+        else:
+            return None
+    return current_data
+
+
+def _get_metadata_key(text_key: str) -> str:
+    parts = text_key.split(".")
+    parts[-1] = "metadata"
+    return ".".join(parts)
+
+
+def _prepare_request(
+    sample: tp.Dict[str, tp.Any],
+    text_key: str,
+    messages_key: str,
+    system_prompt: str,
+    default_metadata: tp.Dict[str, tp.Any],
+) -> tp.Dict[str, tp.Any]:
+    text = _get_request(text_key, sample)
+    if text:
+        messages = convert_llama_instruct_text(text)
+        metadata = _get_request(_get_metadata_key(text_key), sample)
+    else:
+        messages = _get_request(messages_key, sample)  # type: ignore
+        assert messages, f"either {text_key} or {messages_key} should exist"
+        metadata = _get_request(_get_metadata_key(messages_key), sample)
+
+    if system_prompt:
+        if messages[0]["role"] == "system":
+            messages[0]["content"] = system_prompt
+        else:
+            messages.insert(0, {"role": "system", "content": system_prompt})
+
+    if metadata is None:
+        metadata = default_metadata
+    return {"metadata": metadata, "messages": messages}
+
+
 def load_from_jsonl(
     input_files: tp.Tuple[str, ...],
     text_key: str,
@@ -72,47 +116,18 @@
     messages_key: str,
     system_prompt: str,
 ) -> tp.List[tp.Dict[str, tp.Any]]:
 
-    def get_request(key: str, data: tp.Dict[str, tp.Any]) -> tp.Optional[tp.Any]:
-        keys = key.split(".")
-        current_data = data
-        for k in keys:
-            if isinstance(current_data, dict) and k in current_data:
-                current_data = current_data[k]
-            else:
-                return None
-        return current_data
-
-    def get_metadata_key(text_key: str) -> str:
-        parts = text_key.split(".")
-        parts[-1] = "metadata"
-        return ".".join(parts)
-
     def load_json_line(
-        file_name: str, line: str, line_number: int, system_prompt: str
+        file_name: str, line: str, line_number: int
     ) -> tp.Dict[str, tp.Any]:
         try:
             data = json.loads(line)
-            text = get_request(text_key, data)
-            if text:
-                messages = convert_llama_instruct_text(text)
-                metadata = get_request(get_metadata_key(text_key), data)
-            else:
-                messages = get_request(messages_key, data)  # type: ignore
-                assert messages, f"either {text_key} or {messages_key} should exist"
-                metadata = get_request(get_metadata_key(messages_key), data)
-
-            if system_prompt:
-                if messages[0]["role"] == "system":
-                    messages[0]["content"] = system_prompt
-                else:
-                    messages.insert(0, {"role": "system", "content": system_prompt})
-
-            if metadata is None:
-                metadata = {"filename": file_name, "line": line_number}
-            return {
-                "metadata": metadata,
-                "messages": messages,
-            }
+            return _prepare_request(
+                data,
+                text_key,
+                messages_key,
+                system_prompt,
+                {"filename": file_name, "line": line_number},
+            )
         except Exception as e:
             raise ValueError(f"Error in line {line_number}\n{line} of {file_name}: {e}")
@@ -126,7 +141,7 @@ def get_text_length(messages: tp.List[tp.Dict[str, str]]) -> int:
         max_length = 0
         num_lines = 0
         for num_lines, line in enumerate(f, start=1):
-            item = load_json_line(file_name, line, num_lines, system_prompt)
+            item = load_json_line(file_name, line, num_lines)
             max_length = max(get_text_length(item["messages"]), max_length)
             # Add metadata to the dictionary
             data.append(item)
@@ -136,6 +151,31 @@ def get_text_length(messages: tp.List[tp.Dict[str, str]]) -> int:
     return data
 
 
+def load_from_hf_dataset(
+    dataset_name: str,
+    split: str,
+    text_key: str,
+    messages_key: str,
+    system_prompt: str,
+) -> tp.List[tp.Dict[str, tp.Any]]:
+    from datasets import load_dataset
+
+    dataset = load_dataset(dataset_name, split=split)
+    data = []
+    for idx, sample in enumerate(dataset):
+        data.append(
+            _prepare_request(
+                sample,
+                text_key,
+                messages_key,
+                system_prompt,
+                {"index": idx},
+            )
+        )
+    logger.info(f"Loaded {len(data)} samples from {dataset_name} split {split}")
+    return data
+
+
 def _convert_token_log_probs(token_log_probs):
     if not token_log_probs.token_map:
         return None
@@ -617,9 +657,9 @@
 async def main(
     url: tp.Union[str, tp.Callable[[], tp.Awaitable[str]]],
     output_file: str,
-    input_jsonls: str,
-    app_name: str,
-    model: str,
+    input_jsonls: str | None = None,
+    app_name: str = "",
+    model: str = "",
     batch_size=32,
     seed=42,
     temperature=0.7,
@@ -632,6 +672,8 @@
     system_prompt="",
     timeout_secs=600,
     batch_mode=False,
+    input_hf_dataset: str | None = None,
+    hf_dataset_split: str = "train",
 ) -> tp.Dict[str, int]:
     """Send jsonl llama3 instruct prompt for inference and save both the request and response as jsonl.
     params:
@@ -640,6 +682,8 @@
         input_jsonls: variable num of input jsonl files, each line is a json with two formats
             1. {text_key: prompt} if text_key is found, prompt is raw text
             2. {messages_key: Iterable[ChatCompletionMessageParam]} if messages_key is found.
+        input_hf_dataset: name of a Hugging Face dataset to load directly.
+        hf_dataset_split: dataset split to use when loading from Hugging Face.
         model: the huggingface model name or a directory.
         batch_size: max number of concurrent requests.
         seed: seed.
@@ -661,17 +705,25 @@
     os.makedirs(save_dir, exist_ok=True)
     if os.path.exists(output_file):
         logger.warning(f"Output file '{output_file}' already exists, overwriting...")
-    input_files = glob.glob(input_jsonls)
-    if not input_files:
-        logger.error(f"No input files found matching pattern: {input_jsonls}")
-        return {}
-
-    lines = load_from_jsonl(
-        tuple(input_files),
-        text_key,
-        messages_key,
-        system_prompt=system_prompt,
-    )
+    if input_hf_dataset:
+        lines = load_from_hf_dataset(
+            input_hf_dataset,
+            hf_dataset_split,
+            text_key,
+            messages_key,
+            system_prompt=system_prompt,
+        )
+    else:
+        input_files = glob.glob(input_jsonls or "")
+        if not input_files:
+            logger.error(f"No input files found matching pattern: {input_jsonls}")
+            return {}
+        lines = load_from_jsonl(
+            tuple(input_files),
+            text_key,
+            messages_key,
+            system_prompt=system_prompt,
+        )
     stats = {"success": 0, "total": 0, "sum_latency": 0}
     if batch_mode:
         outputs = await batch_requests_async(
diff --git a/pyproject.toml b/pyproject.toml
index 3e2457e..79727c9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,6 +11,7 @@ dynamic = ["version", "description"]
 
 dependencies = [
     "psutil",
+    "datasets",
     "grpcio==1.70.0",
     "grpcio-tools==1.70.0",
     "fire",
diff --git a/tests/unit/query/test_load_from_hf_dataset.py b/tests/unit/query/test_load_from_hf_dataset.py
new file mode 100644
index 0000000..1a8623f
--- /dev/null
+++ b/tests/unit/query/test_load_from_hf_dataset.py
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+
+pytest.importorskip("datasets")
+from datasets import Dataset
+
+
+def test_load_from_hf_dataset(monkeypatch):
+    from matrix.client import query_llm
+
+    dataset = Dataset.from_dict({"problem": ["1+1", "2+2"]})
+
+    def mock_load_dataset(*args, **kwargs):
+        return dataset
+
+    monkeypatch.setattr("datasets.load_dataset", mock_load_dataset)
+
+    lines = query_llm.load_from_hf_dataset(
+        "dummy",
+        "train",
+        text_key="problem",
+        messages_key="request.messages",
+        system_prompt="sys",
+    )
+
+    assert len(lines) == 2
+    assert lines[0]["messages"][0]["role"] == "system"
+    assert lines[0]["messages"][0]["content"] == "sys"
+    assert lines[0]["messages"][1]["content"] == "1+1"
+    assert lines[0]["metadata"]["index"] == 0