@@ -148,8 +148,142 @@ def extract_cls_task_type(
148148 def extract_json (
149149 completion : Iterable [Any ], mode : Mode
150150 ) -> Generator [str , None , None ]:
151+ json_started = False
151152 for chunk in completion :
152153 try :
154+ if mode in {Mode .COHERE_TOOLS , Mode .COHERE_JSON_SCHEMA }:
155+ event_type = getattr (chunk , "event_type" , None )
156+ if event_type == "text-generation" :
157+ if text := getattr (chunk , "text" , None ):
158+ if not json_started :
159+ json_start = min (
160+ (
161+ pos
162+ for pos in (text .find ("{" ), text .find ("[" ))
163+ if pos != - 1
164+ ),
165+ default = - 1 ,
166+ )
167+ if json_start == - 1 :
168+ continue
169+ json_started = True
170+ text = text [json_start :]
171+ yield text
172+ elif event_type == "tool-calls-chunk" :
173+ delta = getattr (chunk , "tool_call_delta" , None )
174+ args = getattr (delta , "parameters" , None ) or getattr (
175+ delta , "text" , None
176+ )
177+ if args :
178+ if not json_started :
179+ json_start = min (
180+ (
181+ pos
182+ for pos in (args .find ("{" ), args .find ("[" ))
183+ if pos != - 1
184+ ),
185+ default = - 1 ,
186+ )
187+ if json_start == - 1 :
188+ continue
189+ json_started = True
190+ args = args [json_start :]
191+ yield args
192+ elif text := getattr (chunk , "text" , None ):
193+ if not json_started :
194+ json_start = min (
195+ (
196+ pos
197+ for pos in (text .find ("{" ), text .find ("[" ))
198+ if pos != - 1
199+ ),
200+ default = - 1 ,
201+ )
202+ if json_start == - 1 :
203+ continue
204+ json_started = True
205+ text = text [json_start :]
206+ yield text
207+ elif event_type == "tool-calls-generation" :
208+ tool_calls = getattr (chunk , "tool_calls" , None )
209+ if tool_calls :
210+ args = json .dumps (tool_calls [0 ].parameters )
211+ if not json_started :
212+ json_start = min (
213+ (
214+ pos
215+ for pos in (args .find ("{" ), args .find ("[" ))
216+ if pos != - 1
217+ ),
218+ default = - 1 ,
219+ )
220+ if json_start == - 1 :
221+ continue
222+ json_started = True
223+ args = args [json_start :]
224+ yield args
225+ elif text := getattr (chunk , "text" , None ):
226+ if not json_started :
227+ json_start = min (
228+ (
229+ pos
230+ for pos in (text .find ("{" ), text .find ("[" ))
231+ if pos != - 1
232+ ),
233+ default = - 1 ,
234+ )
235+ if json_start == - 1 :
236+ continue
237+ json_started = True
238+ text = text [json_start :]
239+ yield text
240+ else :
241+ chunk_type = getattr (chunk , "type" , None )
242+ if chunk_type == "content-delta" :
243+ delta = getattr (chunk , "delta" , None )
244+ message = getattr (delta , "message" , None )
245+ content = getattr (message , "content" , None )
246+ if text := getattr (content , "text" , None ):
247+ if not json_started :
248+ json_start = min (
249+ (
250+ pos
251+ for pos in (
252+ text .find ("{" ),
253+ text .find ("[" ),
254+ )
255+ if pos != - 1
256+ ),
257+ default = - 1 ,
258+ )
259+ if json_start == - 1 :
260+ continue
261+ json_started = True
262+ text = text [json_start :]
263+ yield text
264+ elif chunk_type == "tool-call-delta" :
265+ delta = getattr (chunk , "delta" , None )
266+ message = getattr (delta , "message" , None )
267+ tool_calls = getattr (message , "tool_calls" , None )
268+ function = getattr (tool_calls , "function" , None )
269+ if args := getattr (function , "arguments" , None ):
270+ if not json_started :
271+ json_start = min (
272+ (
273+ pos
274+ for pos in (
275+ args .find ("{" ),
276+ args .find ("[" ),
277+ )
278+ if pos != - 1
279+ ),
280+ default = - 1 ,
281+ )
282+ if json_start == - 1 :
283+ continue
284+ json_started = True
285+ args = args [json_start :]
286+ yield args
153287 if mode == Mode .ANTHROPIC_JSON :
154288 if json_chunk := chunk .delta .text :
155289 yield json_chunk
@@ -230,8 +364,142 @@ def extract_json(
230364 async def extract_json_async (
231365 completion : AsyncGenerator [Any , None ], mode : Mode
232366 ) -> AsyncGenerator [str , None ]:
367+ json_started = False
233368 async for chunk in completion :
234369 try :
370+ if mode in {Mode .COHERE_TOOLS , Mode .COHERE_JSON_SCHEMA }:
371+ event_type = getattr (chunk , "event_type" , None )
372+ if event_type == "text-generation" :
373+ if text := getattr (chunk , "text" , None ):
374+ if not json_started :
375+ json_start = min (
376+ (
377+ pos
378+ for pos in (text .find ("{" ), text .find ("[" ))
379+ if pos != - 1
380+ ),
381+ default = - 1 ,
382+ )
383+ if json_start == - 1 :
384+ continue
385+ json_started = True
386+ text = text [json_start :]
387+ yield text
388+ elif event_type == "tool-calls-chunk" :
389+ delta = getattr (chunk , "tool_call_delta" , None )
390+ args = getattr (delta , "parameters" , None ) or getattr (
391+ delta , "text" , None
392+ )
393+ if args :
394+ if not json_started :
395+ json_start = min (
396+ (
397+ pos
398+ for pos in (args .find ("{" ), args .find ("[" ))
399+ if pos != - 1
400+ ),
401+ default = - 1 ,
402+ )
403+ if json_start == - 1 :
404+ continue
405+ json_started = True
406+ args = args [json_start :]
407+ yield args
408+ elif text := getattr (chunk , "text" , None ):
409+ if not json_started :
410+ json_start = min (
411+ (
412+ pos
413+ for pos in (text .find ("{" ), text .find ("[" ))
414+ if pos != - 1
415+ ),
416+ default = - 1 ,
417+ )
418+ if json_start == - 1 :
419+ continue
420+ json_started = True
421+ text = text [json_start :]
422+ yield text
423+ elif event_type == "tool-calls-generation" :
424+ tool_calls = getattr (chunk , "tool_calls" , None )
425+ if tool_calls :
426+ args = json .dumps (tool_calls [0 ].parameters )
427+ if not json_started :
428+ json_start = min (
429+ (
430+ pos
431+ for pos in (args .find ("{" ), args .find ("[" ))
432+ if pos != - 1
433+ ),
434+ default = - 1 ,
435+ )
436+ if json_start == - 1 :
437+ continue
438+ json_started = True
439+ args = args [json_start :]
440+ yield args
441+ elif text := getattr (chunk , "text" , None ):
442+ if not json_started :
443+ json_start = min (
444+ (
445+ pos
446+ for pos in (text .find ("{" ), text .find ("[" ))
447+ if pos != - 1
448+ ),
449+ default = - 1 ,
450+ )
451+ if json_start == - 1 :
452+ continue
453+ json_started = True
454+ text = text [json_start :]
455+ yield text
456+ else :
457+ chunk_type = getattr (chunk , "type" , None )
458+ if chunk_type == "content-delta" :
459+ delta = getattr (chunk , "delta" , None )
460+ message = getattr (delta , "message" , None )
461+ content = getattr (message , "content" , None )
462+ if text := getattr (content , "text" , None ):
463+ if not json_started :
464+ json_start = min (
465+ (
466+ pos
467+ for pos in (
468+ text .find ("{" ),
469+ text .find ("[" ),
470+ )
471+ if pos != - 1
472+ ),
473+ default = - 1 ,
474+ )
475+ if json_start == - 1 :
476+ continue
477+ json_started = True
478+ text = text [json_start :]
479+ yield text
480+ elif chunk_type == "tool-call-delta" :
481+ delta = getattr (chunk , "delta" , None )
482+ message = getattr (delta , "message" , None )
483+ tool_calls = getattr (message , "tool_calls" , None )
484+ function = getattr (tool_calls , "function" , None )
485+ if args := getattr (function , "arguments" , None ):
486+ if not json_started :
487+ json_start = min (
488+ (
489+ pos
490+ for pos in (
491+ args .find ("{" ),
492+ args .find ("[" ),
493+ )
494+ if pos != - 1
495+ ),
496+ default = - 1 ,
497+ )
498+ if json_start == - 1 :
499+ continue
500+ json_started = True
501+ args = args [json_start :]
502+ yield args
235503 if mode == Mode .ANTHROPIC_JSON :
236504 if json_chunk := chunk .delta .text :
237505 yield json_chunk
0 commit comments