@@ -11,6 +11,7 @@
 from litellm.llms.base_llm.audio_transcription.transformation import (
     BaseAudioTranscriptionConfig,
 )
+from litellm.llms.base_llm.base_model_iterator import MockResponseIterator
 from litellm.llms.base_llm.chat.transformation import BaseConfig
 from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
 from litellm.llms.base_llm.files.transformation import BaseFilesConfig
@@ -231,6 +232,7 @@ def completion(
     ):
         json_mode: bool = optional_params.pop("json_mode", False)
         extra_body: Optional[dict] = optional_params.pop("extra_body", None)
+        fake_stream = fake_stream or optional_params.pop("fake_stream", False)

         provider_config = ProviderConfigManager.get_provider_chat_config(
             model=model, provider=litellm.LlmProviders(custom_llm_provider)
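Note on the new `fake_stream` opt-in above: callers (or provider configs) can now request fake streaming through `optional_params`, and the handler pops the key so it is never forwarded in the provider request body. A minimal, self-contained illustration of the pop semantics (the surrounding values are invented for the example):

    # Illustrative only: how the handler consumes `fake_stream` from
    # optional_params without forwarding it to the provider.
    optional_params = {"temperature": 0.2, "fake_stream": True}

    fake_stream = False  # value passed into completion()
    fake_stream = fake_stream or optional_params.pop("fake_stream", False)

    assert fake_stream is True
    assert "fake_stream" not in optional_params  # consumed, not forwarded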
@@ -317,6 +319,7 @@ def completion(
                 ),
                 litellm_params=litellm_params,
                 json_mode=json_mode,
+                optional_params=optional_params,
             )

         else:
@@ -378,6 +381,7 @@ def completion(
                 ),
                 litellm_params=litellm_params,
                 json_mode=json_mode,
+                optional_params=optional_params,
             )
             return CustomStreamWrapper(
                 completion_stream=completion_stream,
@@ -426,6 +430,7 @@ def make_sync_call(
         model: str,
         messages: list,
         logging_obj,
+        optional_params: dict,
         litellm_params: dict,
         timeout: Union[float, httpx.Timeout],
         fake_stream: bool = False,
@@ -457,11 +462,22 @@ def make_sync_call(
         )

         if fake_stream is True:
-            completion_stream = provider_config.get_model_response_iterator(
-                streaming_response=response.json(),
-                sync_stream=True,
+            model_response: (ModelResponse) = provider_config.transform_response(
+                model=model,
+                raw_response=response,
+                model_response=litellm.ModelResponse(),
+                logging_obj=logging_obj,
+                request_data=data,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                encoding=None,
                 json_mode=json_mode,
             )
+
+            completion_stream: Any = MockResponseIterator(
+                model_response=model_response, json_mode=json_mode
+            )
         else:
             completion_stream = provider_config.get_model_response_iterator(
                 streaming_response=response.iter_lines(),
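With this change the fake-stream path no longer builds an iterator from `response.json()`; it runs the ordinary non-streaming `transform_response` once and hands the finished `ModelResponse` to `MockResponseIterator`. A minimal sketch of that wrapper pattern, assuming a single-chunk replay (an illustrative stand-in, not litellm's actual `MockResponseIterator`):

    from typing import Any

    class SingleChunkIterator:
        # Illustrative: replays one complete response object as a
        # one-chunk stream, then stops, so stream consumers can treat
        # a non-streaming response like a real stream.
        def __init__(self, model_response: Any):
            self._response = model_response
            self._done = False

        def __iter__(self) -> "SingleChunkIterator":
            return self

        def __next__(self) -> Any:
            if self._done:
                raise StopIteration
            self._done = True
            return self._response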
@@ -491,6 +507,7 @@ async def acompletion_stream_function(
         logging_obj: LiteLLMLoggingObj,
         data: dict,
         litellm_params: dict,
+        optional_params: dict,
         fake_stream: bool = False,
         client: Optional[AsyncHTTPHandler] = None,
         json_mode: Optional[bool] = None,
@@ -509,6 +526,7 @@ async def acompletion_stream_function(
         )

         completion_stream, _response_headers = await self.make_async_call_stream_helper(
+            model=model,
             custom_llm_provider=custom_llm_provider,
             provider_config=provider_config,
             api_base=api_base,
@@ -520,6 +538,8 @@ async def acompletion_stream_function(
             fake_stream=fake_stream,
             client=client,
             litellm_params=litellm_params,
+            optional_params=optional_params,
+            json_mode=json_mode,
         )
         streamwrapper = CustomStreamWrapper(
             completion_stream=completion_stream,
@@ -531,6 +551,7 @@ async def acompletion_stream_function(

     async def make_async_call_stream_helper(
         self,
+        model: str,
         custom_llm_provider: str,
         provider_config: BaseConfig,
         api_base: str,
@@ -540,8 +561,10 @@ async def make_async_call_stream_helper(
         logging_obj: LiteLLMLoggingObj,
         timeout: Union[float, httpx.Timeout],
         litellm_params: dict,
+        optional_params: dict,
         fake_stream: bool = False,
         client: Optional[AsyncHTTPHandler] = None,
+        json_mode: Optional[bool] = None,
     ) -> Tuple[Any, httpx.Headers]:
         """
         Helper function for making an async call with stream.
@@ -572,8 +595,21 @@ async def make_async_call_stream_helper(
         )

         if fake_stream is True:
-            completion_stream = provider_config.get_model_response_iterator(
-                streaming_response=response.json(), sync_stream=False
+            model_response: (ModelResponse) = provider_config.transform_response(
+                model=model,
+                raw_response=response,
+                model_response=litellm.ModelResponse(),
+                logging_obj=logging_obj,
+                request_data=data,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                encoding=None,
+                json_mode=json_mode,
+            )
+
+            completion_stream: Any = MockResponseIterator(
+                model_response=model_response, json_mode=json_mode
             )
         else:
             completion_stream = provider_config.get_model_response_iterator(
@@ -598,8 +634,12 @@ def _add_stream_param_to_request_body(
         """
         Some providers, like Bedrock invoke, do not support the stream parameter in the request body; we only pass `stream` in the request body if the provider supports it.
         """
+
         if fake_stream is True:
-            return data
+            # remove 'stream' from data
+            new_data = data.copy()
+            new_data.pop("stream", None)
+            return new_data
         if provider_config.supports_stream_param_in_request_body is True:
             data["stream"] = True
         return data
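The fake-stream branch previously returned `data` unchanged, so a `stream` key set elsewhere could still reach providers that reject it; it now returns a copy with `stream` removed and leaves the caller's dict untouched. A small demonstration of the copy-and-pop behavior (the request body is invented for the example):

    # Demonstrates the new fake-stream behavior of
    # _add_stream_param_to_request_body: strip 'stream' from a copy.
    data = {"model": "my-model", "stream": True}

    new_data = data.copy()
    new_data.pop("stream", None)

    assert new_data == {"model": "my-model"}  # provider never sees 'stream'
    assert data["stream"] is True  # original request dict is not mutated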