@@ -61,7 +61,7 @@ def _is_generation_cancelled(self, gen_id):
6161
6262 def _generate_single (
6363 self ,
64- prompt : str ,
64+ prompt : str | list ,
6565 code_begin : str ,
6666 code_end : str ,
6767 code_output_begin : str ,
@@ -81,8 +81,9 @@ def _generate_single(
8181 max_code_executions : int | None = None , # if not None, will override self.config.max_code_executions
8282 stream : bool = False ,
8383 ):
84- if not isinstance (prompt , str ):
85- raise NotImplementedError ("OpenAI API is not supported yet." )
84+ # Handle OpenAI-style prompts (a list of message dicts) in addition to plain strings
85+ is_openai_format = not isinstance (prompt , str )
86+
8687 if top_logprobs is not None : # TODO: add this
8788 raise NotImplementedError ("top_logprobs is not supported yet." )
8889
@@ -106,18 +107,20 @@ def _generate_single(
106107 max_code_executions = max_code_executions ,
107108 )
108109
109- if stop_phrases is None :
110- stop_phrases = []
111-
112110 effective_max_code_executions = self .config .max_code_executions
113111 if max_code_executions is not None :
114112 effective_max_code_executions = max_code_executions
115113
116114 # making a copy of prompts to not corrupt original data
117- new_prompt = copy .deepcopy (prompt )
115+ if is_openai_format :
116+ new_prompt = copy .deepcopy (prompt )
117+ else :
118+ new_prompt = copy .deepcopy (prompt )
118119
119120 start_time = int (time .time ())
120121
122+ stop_phrases = stop_phrases or []
123+
121124 request = {
122125 "prompt" : new_prompt ,
123126 "tokens_to_generate" : tokens_to_generate ,
@@ -176,7 +179,19 @@ def _generate_single(
176179 output , num_generated_tokens = output_dict ['generation' ], output_dict .get ('num_generated_tokens' , 0 )
177180 # no need to do anything with this as the code below should just exit, so that's only for logging
178181 stopped_on_repetition = output_dict .get ('stopped_on_repetition' , False )
179- request ['prompt' ] += output
182+
183+ # OpenAI does not report which stop phrase was triggered, so we assume it was `code_end`
184+ # if there's an unfinished code block
185+ if is_openai_format and output_dict .get ('finish_reason' ) == 'stop' :
186+ if output .count (code_end ) + 1 == output .count (code_begin ):
187+ output += code_end
188+ # Update the prompt based on format
189+ if is_openai_format :
190+ request ['prompt' ].append ({'role' : 'assistant' , 'content' : output })
191+ request ['prompt' ].append ({'role' : 'user' , 'content' : "continue" })
192+ else :
193+ request ['prompt' ] += output
194+
180195 # if it's the extra iteration, we don't execute the code block and just finish
181196
182197 if generation_index == effective_max_code_executions :
@@ -204,17 +219,28 @@ def _generate_single(
204219 if self .config .add_remaining_code_executions :
205220 remaining_code_executions = effective_max_code_executions - generation_index - 1
206221 # adding code output to the prompt
207- request [ 'prompt' ] + = format_code_output (
222+ code_output = format_code_output (
208223 execution_dict , code_output_begin , code_output_end , code_output_format , remaining_code_executions
209224 )
225+
226+ if is_openai_format :
227+ request ['prompt' ][- 2 ]['content' ] += code_output
228+ else :
229+ request ['prompt' ] += code_output
230+
210231 code_execution_time += int (time .time () - code_execution_time_start )
211232 code_rounds_executed += 1
212233 else : # if no code was generated, we need to finish
213234 break
214235
215- # removing original prompt
236+ # removing original prompt and returning the generation
237+ if is_openai_format :
238+ generation = "\n " .join (msg ['content' ] for msg in request ['prompt' ] if msg ['role' ] == 'assistant' )
239+ else :
240+ generation = request ['prompt' ][len (prompt ):]
241+
216242 return {
217- 'generation' : request [ 'prompt' ][ len ( prompt ) :] ,
243+ 'generation' : generation ,
218244 'code_rounds_executed' : code_rounds_executed ,
219245 'num_generated_tokens' : total_num_generated_tokens ,
220246 'generation_time' : generation_time ,
@@ -433,6 +459,9 @@ def _stream_single(
433459 """
434460 Helper method, that implements streaming generation.
435461 """
462+ # Handle OpenAI-style prompts (a list of message dicts) in addition to plain strings
463+ is_openai_format = not isinstance (prompt , str )
464+
436465 effective_max_code_executions = self .config .max_code_executions
437466 if max_code_executions is not None :
438467 effective_max_code_executions = max_code_executions
@@ -452,7 +481,7 @@ def _stream_single(
452481 'stream' : True ,
453482 }
454483
455- current_full_prompt = prompt
484+ current_full_prompt = copy . deepcopy ( prompt )
456485 session_id = None # For sandbox state continuity
457486 for generation_index in range (effective_max_code_executions + 1 ):
458487 model_token_iterator = self .model ._generate_single (prompt = current_full_prompt , ** request )
@@ -470,7 +499,18 @@ def _stream_single(
470499 if not current_output_segment :
471500 break
472501
473- current_full_prompt += current_output_segment
502+ # OpenAI does not report which stop phrase was triggered, so we assume it was `code_end`
503+ # if there's an unfinished code block
504+ if is_openai_format and chunk .get ('finish_reason' ) == 'stop' :
505+ if current_output_segment .count (code_end ) + 1 == current_output_segment .count (code_begin ):
506+ current_output_segment += code_end
507+
508+ # Update the prompt based on format
509+ if is_openai_format :
510+ current_full_prompt .append ({'role' : 'assistant' , 'content' : current_output_segment })
511+ current_full_prompt .append ({'role' : 'user' , 'content' : "continue" })
512+ else :
513+ current_full_prompt += current_output_segment
474514
475515 if generation_index == effective_max_code_executions :
476516 # This was the last iteration, intended for final text generation after all code executions.
@@ -496,7 +536,12 @@ def _stream_single(
496536 )
497537
498538 yield {'generation' : formatted_code_output } # Yield the entire formatted code output as one chunk
499- current_full_prompt += formatted_code_output # Append executed code's output to the prompt
539+
540+ # Append executed code's output to the prompt
541+ if is_openai_format :
542+ current_full_prompt [- 2 ]['content' ] += formatted_code_output
543+ else :
544+ current_full_prompt += formatted_code_output
500545 else :
501546 break
502547
0 commit comments