Skip to content

Commit e49da15

Browse files
committed
feactor: Refactor generate script and adjust response
1 parent 44542d5 commit e49da15

File tree

2 files changed

+128
-127
lines changed

2 files changed

+128
-127
lines changed

src/viur/assistant/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
import logging
12
import typing as t
23

34
from viur.core.config import ConfigType
45

6+
ASSISTANT_LOGGER: logging.Logger = logging.getLogger("viur.assistant")
7+
58

69
class AssistantConfig(ConfigType):
710
"""Configuration for viur-assistant plugin."""
@@ -19,4 +22,3 @@ class AssistantConfig(ConfigType):
1922
"""
2023
The viur-assistent config instance.
2124
"""
22-

src/viur/assistant/modules/assistant.py

Lines changed: 125 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -7,119 +7,119 @@
77
import PIL
88
import anthropic
99
import openai
10-
from viur.core import conf, db, errors, exposed
10+
from viur.core import conf, current, errors, exposed
1111
from viur.core.decorators import access
1212
from viur.core.prototypes import List, Singleton, Tree
1313

14+
from ..config import ASSISTANT_LOGGER, CONFIG
15+
16+
logger = ASSISTANT_LOGGER.getChild(__name__)
17+
1418

1519
class Assistant(Singleton):
1620
@exposed
1721
@access("user-view")
18-
def generate_script(self,
19-
prompt: str,
20-
modules_to_include: list[str] = None,
21-
enable_caching: bool = False,
22-
max_thinking_tokens: int = 0
23-
):
24-
25-
skel = self.viewSkel()
26-
key = db.Key(skel.kindName, self.getKey())
27-
28-
if not skel.read(key):
29-
raise errors.NotFound()
22+
# @force_post
23+
def generate_script(
24+
self,
25+
prompt: str,
26+
modules_to_include: list[str] = None,
27+
enable_caching: bool = False,
28+
max_thinking_tokens: int = 0
29+
):
30+
if not (skel := self.getContents()):
31+
raise errors.InternalServerError(descr="Configuration missing")
3032

3133
llm_params = {
32-
"model": skel['anthropic_model'],
33-
"max_tokens": skel['anthropic_max_tokens'],
34-
"temperature": skel['anthropic_temperature'],
34+
"model": skel["anthropic_model"],
35+
"max_tokens": skel["anthropic_max_tokens"],
36+
"temperature": skel["anthropic_temperature"],
3537
"system": [{
3638
"type": "text",
37-
"text": skel['anthropic_system_prompt']
39+
"text": skel["anthropic_system_prompt"]
3840
}],
3941
"messages": [{
4042
"role": "user",
41-
"content": []
43+
"content": (user_content := []),
4244
}]
4345
}
4446

45-
'''
4647
# add docs to system prompt (with or without caching), should be delivered by scriptor package
4748
scriptor_doc_system_param = {
4849
"type": "text",
49-
"text": scriptor_docs_txt_data,
50-
}
51-
'''
52-
scriptor_doc_system_param = {
53-
"type": "text",
54-
"text": "" # scriptor docs#todo
50+
"text": "", # TODO: scriptor docs
5551
}
52+
# TODO: llm_params["system"].append(scriptor_doc_system_param)
5653
if enable_caching:
5754
scriptor_doc_system_param["cache_control"] = {"type": "ephemeral"}
58-
llm_params["system"].append(scriptor_doc_system_param)
5955

6056
# thinking configuration
6157
if max_thinking_tokens > 0:
6258
llm_params["thinking"] = {
6359
"type": "enabled",
64-
"budget_tokens": skel['anthropic_max_thinking_tokens']
60+
"budget_tokens": skel["anthropic_max_thinking_tokens"]
6561
}
6662

6763
# add module structures
68-
if modules_to_include is not None:
69-
structures_from_viur = {}
70-
for module_name in modules_to_include:
71-
module = getattr(conf.main_app.vi, module_name, None)
72-
if not module:
73-
continue
74-
75-
if isinstance(module, List):
76-
if module_name not in structures_from_viur:
77-
structures_from_viur[module_name] = module.structure()
78-
elif isinstance(module, Tree):
79-
if module_name not in structures_from_viur:
80-
structures_from_viur[module_name] = {
81-
"node": module.structure(skelType="node"),
82-
"leaf": module.structure(skelType="leaf")
83-
}
84-
else:
85-
raise ValueError(
86-
f"""The module should should be a instance of "tree" or "list". "{module}" is unsupported.""")
87-
88-
selected_module_structures = {
89-
"module_structures": structures_from_viur}
90-
selected_module_structures_description = json.dumps(
91-
selected_module_structures, indent=2)
92-
93-
if selected_module_structures["module_structures"]:
94-
llm_params["messages"][0]["content"].append({
95-
"type": "text",
96-
"text": selected_module_structures_description
97-
})
64+
if modules_to_include is not None and (structures := self.get_viur_structures(modules_to_include)):
65+
user_content.append({
66+
"type": "text",
67+
"text": json.dumps({
68+
"module_structures": structures
69+
}, indent=2)
70+
})
9871

99-
# finally append user prompt
100-
llm_params["messages"][0]["content"].append({
72+
# finally, append user prompt
73+
user_content.append({
10174
"type": "text",
10275
"text": prompt
10376
})
10477

105-
anthropic_client = anthropic.Anthropic(api_key=skel['anthropic_api_key'])
106-
message = anthropic_client.messages.create(**llm_params)
78+
anthropic_client = anthropic.Anthropic(api_key=CONFIG.api_anthropic_key)
79+
logger.debug(f"{llm_params=}")
80+
try:
81+
message = anthropic_client.messages.create(**llm_params)
82+
except Exception as e:
83+
logger.exception(e)
84+
logger.debug(f"{message=}")
85+
current.request.get().response.headers["Content-Type"] = "application/json"
86+
return message.model_dump_json() # TODO: parse real "code" value
10787
return message
10888

89+
def get_viur_structures(self, modules_to_include):
90+
structures_from_viur = {}
91+
for module_name in modules_to_include:
92+
module = getattr(conf.main_app.vi, module_name, None)
93+
if not module:
94+
continue
95+
96+
if isinstance(module, List):
97+
if module_name not in structures_from_viur:
98+
structures_from_viur[module_name] = module.structure()
99+
elif isinstance(module, Tree):
100+
if module_name not in structures_from_viur:
101+
structures_from_viur[module_name] = {
102+
"node": module.structure(skelType="node"),
103+
"leaf": module.structure(skelType="leaf")
104+
}
105+
else:
106+
raise ValueError(
107+
f"""The module should should be a instance of "tree" or "list". "{module}" is unsupported.""")
108+
return structures_from_viur
109+
109110
@exposed
110111
@access("user-view")
111-
def translate(self,
112-
text: str,
113-
language: str,
114-
simplified: bool = False
115-
):
116-
skel = self.viewSkel()
117-
key = db.Key(skel.kindName, self.getKey())
118-
119-
if not skel.read(key):
120-
raise errors.NotFound()
121-
logging.error(skel['openai_api_key'])
122-
openai.api_key = skel['openai_api_key']
112+
# @force_post
113+
def translate(
114+
self,
115+
text: str,
116+
language: str,
117+
simplified: bool = False
118+
):
119+
if not (skel := self.getContents()):
120+
raise errors.InternalServerError(descr="Configuration missing")
121+
122+
openai.api_key = CONFIG.api_openai_key
123123

124124
language_options = {
125125
"en": "english",
@@ -142,10 +142,10 @@ def translate(self,
142142

143143
lang_param = language_options[language]
144144
if simplified:
145-
lang_param = f"""{lang_param} ({'. '.join(simplified_language_suffixes)})"""
145+
lang_param = f"""{lang_param} ({". ".join(simplified_language_suffixes)})"""
146146

147147
response = openai.chat.completions.create(
148-
model=skel['openai_model'],
148+
model=skel["openai_model"],
149149
messages=[
150150
{
151151
"role": "user",
@@ -159,11 +159,15 @@ def translate(self,
159159

160160
@exposed
161161
@access("user-view")
162-
def describe_image(self,
163-
filekey: str,
164-
prompt: str = "",
165-
context: str = "",
166-
language: str = "de"):
162+
def describe_image(
163+
self,
164+
filekey: str,
165+
prompt: str = "",
166+
context: str = "",
167+
language: str = "de",
168+
):
169+
if not (skel := self.getContents()):
170+
raise errors.InternalServerError(descr="Configuration missing")
167171

168172
language_options = {
169173
"en": "english",
@@ -173,57 +177,15 @@ def describe_image(self,
173177
}
174178
lang_param = language_options[language]
175179

176-
def get_resized_image_bytes(image, target_pixel_count=100_000,
177-
jpeg_quality=50):
178-
# assert 0 <= jpeg_quality <= 100, "jpeg_quality must be between 0 and 100"
179-
if isinstance(image, bytes):
180-
image = io.BytesIO(image)
181-
182-
if not isinstance(image, (io.TextIOBase, io.BufferedIOBase, io.RawIOBase, io.IOBase)):
183-
raise ValueError("image must be file-like or bytes")
184-
pillow_image = PIL.Image.open(image)
185-
if pillow_image.format in ['PNG', 'SVG', 'WEBP']:
186-
jpeg_image = io.BytesIO()
187-
pillow_image.convert('RGB').save(jpeg_image, 'JPEG')
188-
jpeg_image.seek(0)
189-
pillow_image = PIL.Image.open(jpeg_image)
190-
191-
original_img_total_pixels = pillow_image.width * pillow_image.height
192-
side_ratio_to_n_pixels = (target_pixel_count / original_img_total_pixels) ** 0.5
193-
new_width = round(pillow_image.width * side_ratio_to_n_pixels)
194-
new_height = round(pillow_image.height * side_ratio_to_n_pixels)
195-
196-
if new_height > pillow_image.height or new_width > pillow_image.width:
197-
resized_img = pillow_image
198-
else:
199-
resized_img = pillow_image.resize(
200-
(new_width, new_height),
201-
PIL.Image.Resampling.LANCZOS
202-
)
203-
204-
result_bio = io.BytesIO()
205-
resized_img.save(result_bio, "jpeg", quality=jpeg_quality)
206-
result_bio.seek(0)
207-
return result_bio.read()
208-
209-
if not openai:
210-
raise errors.BadGateway("Needed Dependencies are missing.")
211-
212-
skel = self.viewSkel()
213-
key = db.Key(skel.kindName, self.getKey())
214-
215-
if not skel.read(key):
216-
raise errors.NotFound()
217-
218-
openai.api_key = skel['openai_api_key']
180+
openai.api_key = skel["openai_api_key"]
219181

220182
blob, mime = conf.main_app.vi.file.read(key=filekey)
221183

222184
if not blob:
223185
raise errors.NotFound()
224-
resized_image_bytes = get_resized_image_bytes(blob)
186+
resized_image_bytes = self._get_resized_image_bytes(blob)
225187
base64_image = base64.b64encode(resized_image_bytes).decode("utf-8")
226-
prompt = f"use the following json data as additional information to describe the image: {re.sub(r'[^a-zA-Z0-9 _-]', '', context)}\n\n" + prompt
188+
prompt = f"use the following json data as additional information to describe the image: {re.sub(r"[^a-zA-Z0-9 _-]", "", context)}\n\n" + prompt
227189
content = [
228190
{
229191
"type": "text",
@@ -243,7 +205,7 @@ def get_resized_image_bytes(image, target_pixel_count=100_000,
243205
]
244206
try:
245207
completion = openai.chat.completions.create(
246-
model=skel['openai_model'],
208+
model=skel["openai_model"],
247209
messages=[
248210
{
249211
"role": "user",
@@ -256,5 +218,42 @@ def get_resized_image_bytes(image, target_pixel_count=100_000,
256218
logging.error(e)
257219
raise errors.PreconditionFailed(e.code)
258220

221+
def _get_resized_image_bytes(
222+
self,
223+
image,
224+
target_pixel_count=100_000,
225+
jpeg_quality=50,
226+
):
227+
# assert 0 <= jpeg_quality <= 100, "jpeg_quality must be between 0 and 100"
228+
if isinstance(image, bytes):
229+
image = io.BytesIO(image)
230+
231+
if not isinstance(image, (io.TextIOBase, io.BufferedIOBase, io.RawIOBase, io.IOBase)):
232+
raise ValueError("image must be file-like or bytes")
233+
pillow_image = PIL.Image.open(image)
234+
if pillow_image.format in ["PNG", "SVG", "WEBP"]:
235+
jpeg_image = io.BytesIO()
236+
pillow_image.convert("RGB").save(jpeg_image, "JPEG")
237+
jpeg_image.seek(0)
238+
pillow_image = PIL.Image.open(jpeg_image)
239+
240+
original_img_total_pixels = pillow_image.width * pillow_image.height
241+
side_ratio_to_n_pixels = (target_pixel_count / original_img_total_pixels) ** 0.5
242+
new_width = round(pillow_image.width * side_ratio_to_n_pixels)
243+
new_height = round(pillow_image.height * side_ratio_to_n_pixels)
244+
245+
if new_height > pillow_image.height or new_width > pillow_image.width:
246+
resized_img = pillow_image
247+
else:
248+
resized_img = pillow_image.resize(
249+
(new_width, new_height),
250+
PIL.Image.Resampling.LANCZOS
251+
)
252+
253+
result_bio = io.BytesIO()
254+
resized_img.save(result_bio, "jpeg", quality=jpeg_quality)
255+
result_bio.seek(0)
256+
return result_bio.read()
257+
259258

260259
Assistant.json = True

0 commit comments

Comments
 (0)