Skip to content

Commit f89ab37

Browse files
authored
Merge pull request #443 from mikeyobrien/main
Add thinking support for claude-3-7-sonnet
2 parents ff41479 + 10242dc commit f89ab37

File tree

1 file changed

+135
-39
lines changed

1 file changed

+135
-39
lines changed

examples/pipelines/providers/anthropic_manifold_pipeline.py

Lines changed: 135 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
license: MIT
77
description: A pipeline for generating text and processing images using the Anthropic API.
88
requirements: requests, sseclient-py
9-
environment_variables: ANTHROPIC_API_KEY
9+
environment_variables: ANTHROPIC_API_KEY, ANTHROPIC_THINKING_BUDGET_TOKENS, ANTHROPIC_ENABLE_THINKING
1010
"""
1111

1212
import os
@@ -18,6 +18,17 @@
1818

1919
from utils.pipelines.main import pop_system_message
2020

21+
# Maximum combined token limit for Claude 3.7
MAX_COMBINED_TOKENS = 64_000

# Maps a named reasoning-effort level to a thinking token budget.
# "none" maps to None, i.e. no budget is configured for that level.
REASONING_EFFORT_BUDGET_TOKEN_MAP = dict(
    none=None,
    low=1 << 10,
    medium=1 << 12,
    high=1 << 14,
    max=1 << 15,
)
2132

2233
class Pipeline:
2334
class Valves(BaseModel):
@@ -29,16 +40,20 @@ def __init__(self):
2940
self.name = "anthropic/"
3041

3142
self.valves = self.Valves(
32-
**{"ANTHROPIC_API_KEY": os.getenv("ANTHROPIC_API_KEY", "your-api-key-here")}
43+
**{
44+
"ANTHROPIC_API_KEY": os.getenv(
45+
"ANTHROPIC_API_KEY", "your-api-key-here"
46+
),
47+
}
3348
)
34-
self.url = 'https://api.anthropic.com/v1/messages'
49+
self.url = "https://api.anthropic.com/v1/messages"
3550
self.update_headers()
3651

3752
def update_headers(self):
    """Rebuild ``self.headers`` from the API key currently held in the valves."""
    request_headers = {
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
    }
    # The key is read at call time, so calling this again picks up a new value.
    request_headers["x-api-key"] = self.valves.ANTHROPIC_API_KEY
    self.headers = request_headers
4358

4459
def get_anthropic_models(self):
@@ -88,7 +103,7 @@ def pipe(
88103
) -> Union[str, Generator, Iterator]:
89104
try:
90105
# Remove unnecessary keys
91-
for key in ['user', 'chat_id', 'title']:
106+
for key in ["user", "chat_id", "title"]:
92107
body.pop(key, None)
93108

94109
system_message, messages = pop_system_message(messages)
@@ -102,28 +117,40 @@ def pipe(
102117
if isinstance(message.get("content"), list):
103118
for item in message["content"]:
104119
if item["type"] == "text":
105-
processed_content.append({"type": "text", "text": item["text"]})
120+
processed_content.append(
121+
{"type": "text", "text": item["text"]}
122+
)
106123
elif item["type"] == "image_url":
107124
if image_count >= 5:
108-
raise ValueError("Maximum of 5 images per API call exceeded")
125+
raise ValueError(
126+
"Maximum of 5 images per API call exceeded"
127+
)
109128

110129
processed_image = self.process_image(item["image_url"])
111130
processed_content.append(processed_image)
112131

113132
if processed_image["source"]["type"] == "base64":
114-
image_size = len(processed_image["source"]["data"]) * 3 / 4
133+
image_size = (
134+
len(processed_image["source"]["data"]) * 3 / 4
135+
)
115136
else:
116137
image_size = 0
117138

118139
total_image_size += image_size
119140
if total_image_size > 100 * 1024 * 1024:
120-
raise ValueError("Total size of images exceeds 100 MB limit")
141+
raise ValueError(
142+
"Total size of images exceeds 100 MB limit"
143+
)
121144

122145
image_count += 1
123146
else:
124-
processed_content = [{"type": "text", "text": message.get("content", "")}]
147+
processed_content = [
148+
{"type": "text", "text": message.get("content", "")}
149+
]
125150

126-
processed_messages.append({"role": message["role"], "content": processed_content})
151+
processed_messages.append(
152+
{"role": message["role"], "content": processed_content}
153+
)
127154

128155
# Prepare the payload
129156
payload = {
@@ -139,38 +166,107 @@ def pipe(
139166
}
140167

141168
if body.get("stream", False):
169+
supports_thinking = "claude-3-7" in model_id
170+
reasoning_effort = body.get("reasoning_effort", "none")
171+
budget_tokens = REASONING_EFFORT_BUDGET_TOKEN_MAP.get(reasoning_effort)
172+
173+
# Allow users to input an integer value representing budget tokens
174+
if (
175+
not budget_tokens
176+
and reasoning_effort not in REASONING_EFFORT_BUDGET_TOKEN_MAP.keys()
177+
):
178+
try:
179+
budget_tokens = int(reasoning_effort)
180+
except ValueError as e:
181+
print("Failed to convert reasoning effort to int", e)
182+
budget_tokens = None
183+
184+
if supports_thinking and budget_tokens:
185+
# Check if the combined tokens (budget_tokens + max_tokens) exceeds the limit
186+
max_tokens = payload.get("max_tokens", 4096)
187+
combined_tokens = budget_tokens + max_tokens
188+
189+
if combined_tokens > MAX_COMBINED_TOKENS:
190+
error_message = f"Error: Combined tokens (budget_tokens {budget_tokens} + max_tokens {max_tokens} = {combined_tokens}) exceeds the maximum limit of {MAX_COMBINED_TOKENS}"
191+
print(error_message)
192+
return error_message
193+
194+
payload["max_tokens"] = combined_tokens
195+
payload["thinking"] = {
196+
"type": "enabled",
197+
"budget_tokens": budget_tokens,
198+
}
199+
# Thinking requires temperature 1.0 and does not support top_p, top_k
200+
payload["temperature"] = 1.0
201+
if "top_k" in payload:
202+
del payload["top_k"]
203+
if "top_p" in payload:
204+
del payload["top_p"]
142205
return self.stream_response(payload)
143206
else:
144207
return self.get_completion(payload)
145208
except Exception as e:
146209
return f"Error: {e}"
147210

148211
def stream_response(self, payload: dict) -> Generator:
    """Stream a Messages API response as plain-text chunks.

    POSTs ``payload`` to ``self.url`` with ``stream=True`` and walks the
    server-sent events of the reply.

    Yields:
        Text chunks in arrival order. A thinking block is wrapped in
        ``<think>`` and ``</think>`` markers. Content block and delta types
        that carry no text (for example redacted thinking or tool input)
        are skipped. Failures are never raised to the caller: they are
        printed and yielded as an ``"Error: ..."`` string.
    """
    response = None
    try:
        response = requests.post(
            self.url, headers=self.headers, json=payload, stream=True
        )
        print(f"{response} for {payload}")

        if response.status_code != 200:
            error_message = f"Error: {response.status_code} - {response.text}"
            print(error_message)
            yield error_message
            return

        # True between the opening <think> marker and its closing marker,
        # so the closing marker is emitted exactly once per thinking block.
        thinking_open = False

        for event in sseclient.SSEClient(response).events():
            try:
                data = json.loads(event.data)
                event_type = data["type"]

                if event_type == "content_block_start":
                    block = data["content_block"]
                    if block["type"] == "thinking":
                        thinking_open = True
                        yield "<think>"
                    elif block["type"] == "text":
                        yield block["text"]
                    # Other block types have no "text" field; skip them
                    # instead of failing with a KeyError.
                elif event_type == "content_block_delta":
                    delta = data["delta"]
                    if delta["type"] == "thinking_delta":
                        yield delta["thinking"]
                    elif delta["type"] == "text_delta":
                        yield delta["text"]
                    elif delta["type"] == "signature_delta" and thinking_open:
                        thinking_open = False
                        yield "\n </think> \n\n"
                elif event_type == "content_block_stop":
                    # Fallback: close the marker if no signature delta
                    # arrived for the thinking block.
                    if thinking_open:
                        thinking_open = False
                        yield "\n </think> \n\n"
                elif event_type == "error":
                    # Surface an error reported inside the stream rather
                    # than ending the output silently.
                    error_message = f"Error: {data['error']}"
                    print(error_message)
                    yield error_message
                    break
                elif event_type == "message_stop":
                    break
            except json.JSONDecodeError:
                print(f"Failed to parse JSON: {event.data}")
                yield "Error: Failed to parse JSON response"
            except KeyError as e:
                print(f"Unexpected data structure: {e} for payload {payload}")
                print(f"Full data: {data}")
                yield f"Error: Unexpected data structure: {e}"
    except Exception as e:
        error_message = f"Error: {str(e)}"
        print(error_message)
        yield error_message
    finally:
        # Release the streamed connection, including when the consumer
        # stops iterating early.
        if response is not None:
            response.close()
169253

170254
def get_completion(self, payload: dict) -> str:
    """Send a non-streaming request and return the first text block.

    POSTs ``payload`` to ``self.url`` and, on HTTP 200, returns the ``text``
    of the first content item that has a non-empty one.

    Returns:
        The text, ``""`` when the reply has no usable text content, or an
        ``"Error: ..."`` string on a non-200 status or any exception.
        Nothing is raised to the caller.
    """
    try:
        response = requests.post(self.url, headers=self.headers, json=payload)
        print(response, payload)

        if response.status_code != 200:
            error_message = f"Error: {response.status_code} - {response.text}"
            print(error_message)
            return error_message

        # A missing or null "content" yields "" instead of a KeyError that
        # would be reported to the user as an error string.
        blocks = response.json().get("content") or []
        return next(
            (
                block["text"]
                for block in blocks
                if isinstance(block, dict) and block.get("text")
            ),
            "",
        )
    except Exception as e:
        error_message = f"Error: {str(e)}"
        print(error_message)
        return error_message

0 commit comments

Comments
 (0)