44from base64 import b64decode
55
66try :
7- from litellm import (
8- CustomLLM ,
9- ModelResponse ,
10- Choices ,
11- Message ,
12- )
13-
7+ from litellm import CustomLLM
8+ from litellm .llms .custom_httpx .http_handler import HTTPHandler
9+ from litellm .utils import ModelResponse , Choices , Message
1410except ImportError :
1511 raise ImportError (
1612 "The 'litellm' library is required to use this functionality. "
1713 "Install it with: pip install nuclia[litellm]"
1814 )
1915
20- # Nuclia (sync & async)
21- from nuclia .lib .nua import NuaClient , AsyncNuaClient
22- from nuclia .sdk .predict import NucliaPredict , AsyncNucliaPredict
16+ from nuclia .lib .nua import NuaClient
17+ from nuclia .sdk .predict import NucliaPredict
2318from nuclia .lib .nua_responses import ChatModel , UserPrompt
2419from nuclia_models .predict .generative_responses import (
2520 GenerativeFullResponse ,
2621)
22+ from typing import Callable , Optional , Union
23+
24+ import httpx
2725
2826
2927class NucliaNuaChat (CustomLLM ):
@@ -47,13 +45,6 @@ def __init__(self, token: str):
4745 )
4846 self .predict_sync = NucliaPredict ()
4947
50- self .nc_async = AsyncNuaClient (
51- region = self .region_base_url ,
52- token = self .token ,
53- account = "" , # Not needed for current implementation, required by the client
54- )
55- self .predict_async = AsyncNucliaPredict ()
56-
5748 @staticmethod
5849 def _parse_token (token : str ):
5950 parts = token .split ("." )
@@ -85,7 +76,23 @@ def _process_messages(self, messages: list[dict[str, str]]) -> tuple[str, str]:
8576 return formatted_system , formatted_user
8677
8778 def completion (
88- self , * args , model : str , messages : list [dict [str , str ]], ** kwargs
79+ self ,
80+ model : str ,
81+ messages : list ,
82+ api_base : str ,
83+ custom_prompt_dict : dict ,
84+ model_response : ModelResponse ,
85+ print_verbose : Callable ,
86+ encoding ,
87+ api_key ,
88+ logging_obj ,
89+ optional_params : dict ,
90+ acompletion = None ,
91+ litellm_params = None ,
92+ logger_fn = None ,
93+ headers = {},
94+ timeout : Optional [Union [float , httpx .Timeout ]] = None ,
95+ client : Optional [HTTPHandler ] = None ,
8996 ) -> ModelResponse :
9097 if not self .predict_sync or not self .nc_sync :
9198 raise RuntimeError ("Sync clients not initialized." )
0 commit comments