
Commit 5c790b1

Author: Safoora Yousefi
Commit message: bug fixes
Parent: 03c848c

File tree: 1 file changed, 40 insertions(+), 23 deletions(-)


eureka_ml_insights/models/models.py

@@ -9,8 +9,8 @@
 
 import anthropic
 import tiktoken
-from azure.identity import get_bearer_token_provider
-from azure.identity import DefaultAzureCredential
+from azure.identity import DefaultAzureCredential, get_bearer_token_provider
+
 from eureka_ml_insights.secret_management import get_secret
 
 
@@ -288,7 +288,7 @@ def get_response(self, request):
             "response_time": response_time,
         }
         if "usage" in res:
-            return response_dict.update({"usage": res["usage"]})
+            response_dict.update({"usage": res["usage"]})
         return response_dict
 
     def handle_request_error(self, e):
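
Note: this hunk (and its twins at old lines 473 and 620 below) fixes a classic Python pitfall: dict.update() mutates the dict in place and returns None, so "return response_dict.update(...)" returned None and dropped the whole response payload whenever usage data was present. A minimal sketch of the pitfall and the fixed pattern (example values made up):

    # dict.update() mutates in place and returns None, so returning its
    # result discards the dict that was just built.
    response_dict = {"model_output": "hi", "response_time": 0.42}
    usage = {"prompt_tokens": 3, "completion_tokens": 1}

    broken = response_dict.update({"usage": usage})
    print(broken)  # None -- the caller would have received nothing

    # Fixed pattern, as in this commit: update first, then return the dict.
    print(response_dict["usage"])  # {'prompt_tokens': 3, 'completion_tokens': 1}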
@@ -395,6 +395,7 @@ def create_request(self, text_prompt, query_images=None, system_message=None, pr
         body = str.encode(json.dumps(data))
         return urllib.request.Request(self.url, body, self.headers)
 
+
 @dataclass
 class DeepseekR1ServerlessAzureRestEndpointModel(ServerlessAzureRestEndpointModel):
     # setting temperature to 0.6 as suggested in https://huggingface.co/deepseek-ai/DeepSeek-R1
@@ -410,7 +411,9 @@ def create_request(self, text_prompt, query_images=None, system_message=None, pr
         if previous_messages:
             messages.extend(previous_messages)
         if query_images:
-            raise NotImplementedError("Images are not supported for DeepseekR1ServerlessAzureRestEndpointModel endpoints.")
+            raise NotImplementedError(
+                "Images are not supported for DeepseekR1ServerlessAzureRestEndpointModel endpoints."
+            )
         messages.append({"role": "user", "content": text_prompt})
         data = {
             "messages": messages,
@@ -422,6 +425,7 @@ def create_request(self, text_prompt, query_images=None, system_message=None, pr
         body = str.encode(json.dumps(data))
         return urllib.request.Request(self.url, body, self.headers)
 
+
 @dataclass
 class OpenAICommonRequestResponseMixIn:
     """
@@ -470,7 +474,7 @@ def get_response(self, request):
             "response_time": response_time,
         }
         if "usage" in openai_response:
-            return response_dict.update({"usage": openai_response["usage"]})
+            response_dict.update({"usage": openai_response["usage"]})
         return response_dict
 
 
@@ -489,7 +493,7 @@ def get_client(self):
 
     def handle_request_error(self, e):
         # if the error is due to a content filter, there is no need to retry
-        if hasattr(e, 'code') and e.code == "content_filter":
+        if hasattr(e, "code") and e.code == "content_filter":
             logging.warning("Content filtered.")
             response = None
             return response, False, True
@@ -617,7 +621,7 @@ def get_response(self, request):
             "response_time": response_time,
         }
         if "usage" in openai_response:
-            return response_dict.update({"usage": openai_response["usage"]})
+            response_dict.update({"usage": openai_response["usage"]})
         return response_dict
 
 
@@ -706,6 +710,7 @@ def create_request(self, text_prompt, query_images=None, system_message=None, pr
 
     def get_response(self, request):
         start_time = time.time()
+        gemini_response = None
         try:
             gemini_response = self.model.generate_content(
                 request,
@@ -717,9 +722,7 @@ def get_response(self, request):
             model_output = gemini_response.parts[0].text
             response_time = end_time - start_time
         except Exception as e:
-            is_non_transient_issue = self.handle_gemini_error(e, gemini_response)
-            if not is_non_transient_issue:
-                raise e
+            self.handle_gemini_error(e, gemini_response)
 
         response_dict = {
             "model_output": model_output,
@@ -755,7 +758,7 @@ def handle_gemini_error(self, e, gemini_response):
             logging.warning(
                 f"Attempt failed due to explicitly blocked input prompt: {e} Block Reason {gemini_response.prompt_feedback.block_reason}"
             )
-            return True
+
         # Handling cases where the model implicitly blocks prompts and does not provide an explicit block reason for it but rather an empty content.
         # In these cases, there is no need to make a new attempt as the model will continue to implicitly block the request, do_return = True.
         # Note that, in some cases, the model may still provide a finish reason as shown here https://ai.google.dev/api/generate-content?authuser=2#FinishReason
@@ -771,11 +774,11 @@ def handle_gemini_error(self, e, gemini_response):
             logging.warning(
                 f"Safety Ratings for the first answer candidate are: {gemini_response.candidates[0].safety_ratings}"
             )
-            return True
-        # Any other case will be re attempted again, do_return = False.
-        return False
+
+        raise e
 
     def handle_request_error(self, e):
+        # Any error case not handled in handle_gemini_error will be attempted again, do_return = False.
         return False
 
 
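Note: taken together, the Gemini hunks change the error contract. handle_gemini_error used to return a do_return flag, and on the non-raising path the old code could fall through to a response_dict built from locals that were never assigned; it now logs whatever block or safety details are available and always re-raises, while handle_request_error returns False so the surrounding retry machinery attempts the request again. A sketch of that division of labor, with a hypothetical retry loop standing in for the framework's:

    import logging

    def handle_gemini_error(e):
        # Log diagnostics, then propagate; the retry loop decides what happens next.
        logging.warning(f"Attempt failed: {e}")
        raise e

    def handle_request_error(e):
        # do_return = False: any error not resolved above is attempted again.
        return False

    def call_with_retries(fn, attempts=3):  # hypothetical stand-in retry loop
        for _ in range(attempts):
            try:
                return fn()
            except Exception as e:
                if handle_request_error(e):  # True would mean "give up"
                    return None
        return None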
@@ -1326,19 +1329,25 @@ def get_response(self, request):
     def handle_request_error(self, e):
         return False
 
+
 @dataclass
 class ClaudeReasoningModel(ClaudeModel):
     """This class is used to interact with Claude reasoning models through the python api."""
 
     model_name: str = None
-    temperature: float = 1.
+    temperature: float = 1.0
     max_tokens: int = 20000
     timeout: int = 600
     thinking_enabled: bool = True
     thinking_budget: int = 16000
     top_p: float = None
 
     def get_response(self, request):
+        model_output = None
+        response_time = None
+        thinking_output = None
+        redacted_thinking_output = None
+        response_dict = {}
         if self.top_p is not None:
             logging.warning("top_p is not supported for claude reasoning models as of 03/08/2025. It will be ignored.")
 
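Note: moving model_output, thinking_output, redacted_thinking_output, and response_time from attributes on self to locals initialized up front fixes two problems at once: get_response no longer leaks state between calls, and (together with the next hunk) it actually returns the outputs instead of storing them on the instance. A hypothetical minimal reproduction of the stale-state half:

    # Attributes set on self persist across calls, so a response with no
    # matching content block silently reuses the previous call's output.
    class Stateful:  # hypothetical stand-in, not the repo's class
        def respond(self, parts):
            for part in parts:
                if part["type"] == "text":
                    self.model_output = part["text"]  # persists across calls
            return self.model_output

    m = Stateful()
    print(m.respond([{"type": "text", "text": "first"}]))  # first
    print(m.respond([{"type": "thinking"}]))  # still "first" -- stale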
@@ -1355,16 +1364,24 @@ def get_response(self, request):
 
         # Loop through completion.content to find the text output
         for content in completion.content:
-            if content.type == 'text':
-                self.model_output = content.text
-            elif content.type == 'thinking':
-                self.thinking_output = content.thinking
-            elif content.type == 'redacted_thinking':
-                self.redacted_thinking_output = content.data
+            if content.type == "text":
+                model_output = content.text
+            elif content.type == "thinking":
+                thinking_output = content.thinking
+            elif content.type == "redacted_thinking":
+                redacted_thinking_output = content.data
 
-        self.response_time = end_time - start_time
+        response_time = end_time - start_time
+        response_dict = {
+            "model_output": model_output,
+            "response_time": response_time,
+            "thinking_output": thinking_output,
+            "redacted_thinking_output": redacted_thinking_output,
+        }
         if hasattr(completion, "usage"):
-            return {"usage": completion.usage.to_dict()}
+            response_dict.update({"usage": completion.usage.to_dict()})
+        return response_dict
+
 
 @dataclass
 class TestModel(Model):
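
Note: the rewritten loop dispatches on the three content-block types shown in the diff (text, thinking, redacted_thinking) and now returns everything in one response_dict, with usage merged in only when the completion provides it. A self-contained sketch of that scan, using a placeholder dataclass rather than the real anthropic SDK block type:

    from dataclasses import dataclass

    @dataclass
    class Block:  # placeholder for an SDK content block
        type: str
        text: str = None
        thinking: str = None
        data: str = None

    content = [Block("thinking", thinking="step 1..."), Block("text", text="42")]

    model_output = thinking_output = redacted_thinking_output = None
    for block in content:
        if block.type == "text":
            model_output = block.text
        elif block.type == "thinking":
            thinking_output = block.thinking
        elif block.type == "redacted_thinking":
            redacted_thinking_output = block.data

    print(model_output, thinking_output)  # 42 step 1...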
