@@ -65,54 +65,69 @@ def convert_llama_instruct_text(
6565 return messages
6666
6767
68+ def _get_request (key : str , data : tp .Dict [str , tp .Any ]) -> tp .Optional [tp .Any ]:
69+ keys = key .split ("." )
70+ current_data = data
71+ for k in keys :
72+ if isinstance (current_data , dict ) and k in current_data :
73+ current_data = current_data [k ]
74+ else :
75+ return None
76+ return current_data
77+
78+
79+ def _get_metadata_key (text_key : str ) -> str :
80+ parts = text_key .split ("." )
81+ parts [- 1 ] = "metadata"
82+ return "." .join (parts )
83+
84+
def _prepare_request(
    sample: tp.Dict[str, tp.Any],
    text_key: str,
    messages_key: str,
    system_prompt: str,
    default_metadata: tp.Dict[str, tp.Any],
) -> tp.Dict[str, tp.Any]:
    """Build a ``{"metadata": ..., "messages": ...}`` request from one sample.

    The value at ``text_key`` is preferred: when it is present and truthy it
    is converted with ``convert_llama_instruct_text``. Otherwise the value at
    ``messages_key`` is used as the message list. Metadata is read from the
    ``metadata`` entry that sits next to whichever key was used.

    Args:
        sample: one input record (possibly nested dictionaries).
        text_key: dot-separated path to a raw prompt text.
        messages_key: dot-separated path to a list of chat messages.
        system_prompt: when non-empty, replaces the content of a leading
            "system" message or is inserted as a new first message.
        default_metadata: used when the sample carries no metadata.

    Returns:
        A dict with the keys "metadata" and "messages".

    Raises:
        ValueError: if neither ``text_key`` nor ``messages_key`` yields a
            non-empty value.
    """
    text = _get_request(text_key, sample)
    if text:
        messages = convert_llama_instruct_text(text)
        metadata = _get_request(_get_metadata_key(text_key), sample)
    else:
        messages = _get_request(messages_key, sample)  # type: ignore
        # Explicit raise instead of `assert`: asserts are removed under
        # `python -O`, which would let a missing key slip through.
        if not messages:
            raise ValueError(f"either {text_key} or {messages_key} should exist")
        metadata = _get_request(_get_metadata_key(messages_key), sample)

    if system_prompt:
        # Work on a shallow copy so the caller's sample is left untouched.
        messages = list(messages)
        if messages[0]["role"] == "system":
            # Fresh dict for the first message; its other keys are preserved.
            messages[0] = {**messages[0], "content": system_prompt}
        else:
            messages.insert(0, {"role": "system", "content": system_prompt})

    if metadata is None:
        metadata = default_metadata
    return {"metadata": metadata, "messages": messages}
110+
111+
68112def load_from_jsonl (
69113 input_files : tp .Tuple [str , ...],
70114 text_key : str ,
71115 messages_key : str ,
72116 system_prompt : str ,
73117) -> tp .List [tp .Dict [str , tp .Any ]]:
74118
75- def get_request (key : str , data : tp .Dict [str , tp .Any ]) -> tp .Optional [tp .Any ]:
76- keys = key .split ("." )
77- current_data = data
78- for k in keys :
79- if isinstance (current_data , dict ) and k in current_data :
80- current_data = current_data [k ]
81- else :
82- return None
83- return current_data
84-
85- def get_metadata_key (text_key : str ) -> str :
86- parts = text_key .split ("." )
87- parts [- 1 ] = "metadata"
88- return "." .join (parts )
89-
90119 def load_json_line (
91- file_name : str , line : str , line_number : int , system_prompt : str
120+ file_name : str , line : str , line_number : int
92121 ) -> tp .Dict [str , tp .Any ]:
93122 try :
94123 data = json .loads (line )
95- text = get_request (text_key , data )
96- if text :
97- messages = convert_llama_instruct_text (text )
98- metadata = get_request (get_metadata_key (text_key ), data )
99- else :
100- messages = get_request (messages_key , data ) # type: ignore
101- assert messages , f"either { text_key } or { messages_key } should exist"
102- metadata = get_request (get_metadata_key (messages_key ), data )
103-
104- if system_prompt :
105- if messages [0 ]["role" ] == "system" :
106- messages [0 ]["content" ] = system_prompt
107- else :
108- messages .insert (0 , {"role" : "system" , "content" : system_prompt })
109-
110- if metadata is None :
111- metadata = {"filename" : file_name , "line" : line_number }
112- return {
113- "metadata" : metadata ,
114- "messages" : messages ,
115- }
124+ return _prepare_request (
125+ data ,
126+ text_key ,
127+ messages_key ,
128+ system_prompt ,
129+ {"filename" : file_name , "line" : line_number },
130+ )
116131 except Exception as e :
117132 raise ValueError (f"Error in line { line_number } \n { line } of { file_name } : { e } " )
118133
@@ -126,7 +141,7 @@ def get_text_length(messages: tp.List[tp.Dict[str, str]]) -> int:
126141 max_length = 0
127142 num_lines = 0
128143 for num_lines , line in enumerate (f , start = 1 ):
129- item = load_json_line (file_name , line , num_lines , system_prompt )
144+ item = load_json_line (file_name , line , num_lines )
130145 max_length = max (get_text_length (item ["messages" ]), max_length )
131146 # Add metadata to the dictionary
132147 data .append (item )
@@ -136,6 +151,31 @@ def get_text_length(messages: tp.List[tp.Dict[str, str]]) -> int:
136151 return data
137152
138153
def load_from_hf_dataset(
    dataset_name: str,
    split: str,
    text_key: str,
    messages_key: str,
    system_prompt: str,
) -> tp.List[tp.Dict[str, tp.Any]]:
    """Load requests from a Hugging Face dataset split.

    Every sample of the split is turned into a request by
    ``_prepare_request``; samples without their own metadata get
    ``{"index": <position in the split>}``.

    Args:
        dataset_name: name passed to ``datasets.load_dataset``.
        split: dataset split to load.
        text_key: dot-separated path to a raw prompt text in each sample.
        messages_key: dot-separated path to a message list in each sample.
        system_prompt: optional system prompt applied to every request.

    Returns:
        One ``{"metadata": ..., "messages": ...}`` dict per sample.

    Raises:
        ValueError: if a sample cannot be converted; the message names the
            sample index, the dataset and the split, and the original
            exception is chained as the cause.
    """
    # Function-scope import: `datasets` is only needed on this code path.
    from datasets import load_dataset

    dataset = load_dataset(dataset_name, split=split)
    data = []
    for idx, sample in enumerate(dataset):
        try:
            request = _prepare_request(
                sample,
                text_key,
                messages_key,
                system_prompt,
                {"index": idx},
            )
        except Exception as e:
            # Mirror the JSONL loader: report which record failed rather
            # than surfacing a bare error from deep inside the helper.
            raise ValueError(
                f"Error in sample {idx} of {dataset_name} split {split}: {e}"
            ) from e
        data.append(request)
    logger.info(f"Loaded {len(data)} samples from {dataset_name} split {split}")
    return data
177+
178+
139179def _convert_token_log_probs (token_log_probs ):
140180 if not token_log_probs .token_map :
141181 return None
@@ -617,9 +657,9 @@ def batch_requests(
617657async def main (
618658 url : tp .Union [str , tp .Callable [[], tp .Awaitable [str ]]],
619659 output_file : str ,
620- input_jsonls : str ,
621- app_name : str ,
622- model : str ,
660+ input_jsonls : str | None = None ,
661+ app_name : str = "" ,
662+ model : str = "" ,
623663 batch_size = 32 ,
624664 seed = 42 ,
625665 temperature = 0.7 ,
@@ -632,6 +672,8 @@ async def main(
632672 system_prompt = "" ,
633673 timeout_secs = 600 ,
634674 batch_mode = False ,
675+ input_hf_dataset : str | None = None ,
676+ hf_dataset_split : str = "train" ,
635677) -> tp .Dict [str , int ]:
636678 """Send jsonl llama3 instruct prompt for inference and save both the request and response as jsonl.
637679 params:
@@ -640,6 +682,8 @@ async def main(
640682 input_jsonls: variable num of input jsonl files, each line is a json with two formats
641683 1. {text_key: prompt} if text_key is found, prompt is raw text
642684 2. {messages_key: Iterable[ChatCompletionMessageParam]} if messages_key is found.
685+ input_hf_dataset: name of a Hugging Face dataset to load directly.
686+ hf_dataset_split: dataset split to use when loading from Hugging Face.
643687 model: the huggingface model name or a directory.
644688 batch_size: max number of concurrent requests.
645689 seed: seed.
@@ -661,17 +705,25 @@ async def main(
661705 os .makedirs (save_dir , exist_ok = True )
662706 if os .path .exists (output_file ):
663707 logger .warning (f"Output file '{ output_file } ' already exists, overwriting..." )
664- input_files = glob .glob (input_jsonls )
665- if not input_files :
666- logger .error (f"No input files found matching pattern: { input_jsonls } " )
667- return {}
668-
669- lines = load_from_jsonl (
670- tuple (input_files ),
671- text_key ,
672- messages_key ,
673- system_prompt = system_prompt ,
674- )
708+ if input_hf_dataset :
709+ lines = load_from_hf_dataset (
710+ input_hf_dataset ,
711+ hf_dataset_split ,
712+ text_key ,
713+ messages_key ,
714+ system_prompt = system_prompt ,
715+ )
716+ else :
717+ input_files = glob .glob (input_jsonls or "" )
718+ if not input_files :
719+ logger .error (f"No input files found matching pattern: { input_jsonls } " )
720+ return {}
721+ lines = load_from_jsonl (
722+ tuple (input_files ),
723+ text_key ,
724+ messages_key ,
725+ system_prompt = system_prompt ,
726+ )
675727 stats = {"success" : 0 , "total" : 0 , "sum_latency" : 0 }
676728 if batch_mode :
677729 outputs = await batch_requests_async (
0 commit comments