Skip to content

Commit b6fe2cb

Browse files
committed
more small changes
1 parent 1355db2 commit b6fe2cb

File tree

1 file changed

+25
-18
lines changed

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,24 @@
44
from base64 import b64decode
55

66
try:
7-
from litellm import (
8-
CustomLLM,
9-
ModelResponse,
10-
Choices,
11-
Message,
12-
)
13-
7+
from litellm import CustomLLM
8+
from litellm.llms.custom_httpx.http_handler import HTTPHandler
9+
from litellm.utils import ModelResponse, Choices, Message
1410
except ImportError:
1511
raise ImportError(
1612
"The 'litellm' library is required to use this functionality. "
1713
"Install it with: pip install nuclia[litellm]"
1814
)
1915

20-
# Nuclia (sync & async)
21-
from nuclia.lib.nua import NuaClient, AsyncNuaClient
22-
from nuclia.sdk.predict import NucliaPredict, AsyncNucliaPredict
16+
from nuclia.lib.nua import NuaClient
17+
from nuclia.sdk.predict import NucliaPredict
2318
from nuclia.lib.nua_responses import ChatModel, UserPrompt
2419
from nuclia_models.predict.generative_responses import (
2520
GenerativeFullResponse,
2621
)
22+
from typing import Callable, Optional, Union
23+
24+
import httpx
2725

2826

2927
class NucliaNuaChat(CustomLLM):
@@ -47,13 +45,6 @@ def __init__(self, token: str):
4745
)
4846
self.predict_sync = NucliaPredict()
4947

50-
self.nc_async = AsyncNuaClient(
51-
region=self.region_base_url,
52-
token=self.token,
53-
account="", # Not needed for current implementation, required by the client
54-
)
55-
self.predict_async = AsyncNucliaPredict()
56-
5748
@staticmethod
5849
def _parse_token(token: str):
5950
parts = token.split(".")
@@ -85,7 +76,23 @@ def _process_messages(self, messages: list[dict[str, str]]) -> tuple[str, str]:
8576
return formatted_system, formatted_user
8677

8778
def completion(
88-
self, *args, model: str, messages: list[dict[str, str]], **kwargs
79+
self,
80+
model: str,
81+
messages: list,
82+
api_base: str,
83+
custom_prompt_dict: dict,
84+
model_response: ModelResponse,
85+
print_verbose: Callable,
86+
encoding,
87+
api_key,
88+
logging_obj,
89+
optional_params: dict,
90+
acompletion=None,
91+
litellm_params=None,
92+
logger_fn=None,
93+
headers={},
94+
timeout: Optional[Union[float, httpx.Timeout]] = None,
95+
client: Optional[HTTPHandler] = None,
8996
) -> ModelResponse:
9097
if not self.predict_sync or not self.nc_sync:
9198
raise RuntimeError("Sync clients not initialized.")

0 commit comments

Comments
 (0)