@@ -136,6 +136,55 @@ def get_text_length(messages: tp.List[tp.Dict[str, str]]) -> int:
     return data


+def load_from_hf_dataset(
+    dataset_name: str,
+    split: str,
+    text_key: str,
+    messages_key: str,
+    system_prompt: str,
+) -> tp.List[tp.Dict[str, tp.Any]]:
+    from datasets import load_dataset
+
+    def get_request(key: str, data: tp.Dict[str, tp.Any]) -> tp.Optional[tp.Any]:
+        # Resolve a dot-separated key path (e.g. "a.b") against nested dicts.
+        keys = key.split(".")
+        current_data = data
+        for k in keys:
+            if isinstance(current_data, dict) and k in current_data:
+                current_data = current_data[k]
+            else:
+                return None
+        return current_data
+
+    def get_metadata_key(text_key: str) -> str:
+        # Derive the sibling metadata key by swapping the last path component.
+        parts = text_key.split(".")
+        parts[-1] = "metadata"
+        return ".".join(parts)
+
+    dataset = load_dataset(dataset_name, split=split)
+    data = []
+    for idx, sample in enumerate(dataset):
+        text = get_request(text_key, sample)
+        if text:
+            # Raw llama3 instruct text: parse it into chat messages.
+            messages = convert_llama_instruct_text(text)
+            metadata = get_request(get_metadata_key(text_key), sample)
+        else:
+            # Fall back to pre-formatted chat messages.
+            messages = get_request(messages_key, sample)  # type: ignore
+            assert messages, f"either {text_key} or {messages_key} should exist"
+            metadata = get_request(get_metadata_key(messages_key), sample)
+
+        if system_prompt:
+            # Override an existing system message, or prepend one.
+            if messages[0]["role"] == "system":
+                messages[0]["content"] = system_prompt
+            else:
+                messages.insert(0, {"role": "system", "content": system_prompt})
+
+        if metadata is None:
+            metadata = {"index": idx}
+        data.append({"metadata": metadata, "messages": messages})
+    logger.info(f"Loaded {len(data)} samples from {dataset_name} split {split}")
+    return data
+
+
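+# A minimal usage sketch (hypothetical dataset name and keys), illustrating the
+# dot-path lookup: text_key="question.text" reads sample["question"]["text"],
+# and metadata comes from the sibling path "question.metadata":
+#
+#   lines = load_from_hf_dataset(
+#       "my-org/my-prompts",  # hypothetical HF dataset
+#       "train",
+#       text_key="question.text",
+#       messages_key="messages",
+#       system_prompt="You are a helpful assistant.",
+#   )
+#   # -> [{"metadata": {...}, "messages": [{"role": "system", ...}, ...]}, ...]
+
+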
 def _convert_token_log_probs(token_log_probs):
     if not token_log_probs.token_map:
         return None
@@ -617,9 +666,9 @@ def batch_requests(
 async def main(
     url: tp.Union[str, tp.Callable[[], tp.Awaitable[str]]],
     output_file: str,
-    input_jsonls: str,
-    app_name: str,
-    model: str,
+    input_jsonls: str | None = None,
+    app_name: str = "",
+    model: str = "",
     batch_size=32,
     seed=42,
     temperature=0.7,
@@ -632,6 +681,8 @@ async def main(
     system_prompt="",
     timeout_secs=600,
     batch_mode=False,
+    input_hf_dataset: str | None = None,
+    hf_dataset_split: str = "train",
 ) -> tp.Dict[str, int]:
     """Send jsonl llama3 instruct prompt for inference and save both the request and response as jsonl.
     params:
@@ -640,6 +691,8 @@ async def main(
         input_jsonls: variable num of input jsonl files, each line is a json with two formats
             1. {text_key: prompt} if text_key is found, prompt is raw text
             2. {messages_key: Iterable[ChatCompletionMessageParam]} if messages_key is found.
+        input_hf_dataset: name of a Hugging Face dataset to load directly; takes
+            precedence over input_jsonls when set.
+        hf_dataset_split: dataset split to use when loading from Hugging Face.
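+            e.g., with a hypothetical dataset name:
+                await main(url, "out.jsonl", input_hf_dataset="my-org/my-prompts",
+                           hf_dataset_split="test")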
         model: the huggingface model name or a directory.
         batch_size: max number of concurrent requests.
         seed: seed.
@@ -661,17 +714,25 @@ async def main(
     os.makedirs(save_dir, exist_ok=True)
     if os.path.exists(output_file):
         logger.warning(f"Output file '{output_file}' already exists, overwriting...")
-    input_files = glob.glob(input_jsonls)
-    if not input_files:
-        logger.error(f"No input files found matching pattern: {input_jsonls}")
-        return {}
-
-    lines = load_from_jsonl(
-        tuple(input_files),
-        text_key,
-        messages_key,
-        system_prompt=system_prompt,
-    )
+    if input_hf_dataset:
+        lines = load_from_hf_dataset(
+            input_hf_dataset,
+            hf_dataset_split,
+            text_key,
+            messages_key,
+            system_prompt=system_prompt,
+        )
+    else:
+        input_files = glob.glob(input_jsonls or "")
+        if not input_files:
+            logger.error(f"No input files found matching pattern: {input_jsonls}")
+            return {}
+        lines = load_from_jsonl(
+            tuple(input_files),
+            text_key,
+            messages_key,
+            system_prompt=system_prompt,
+        )
     stats = {"success": 0, "total": 0, "sum_latency": 0}
     if batch_mode:
         outputs = await batch_requests_async(