Skip to content

Commit ce50e82

Browse files
authored
Merge pull request #127 from Aleph-Alpha/fix-optimized-prompt-parsing
Fix optimized prompt parsing
2 parents 04fcd22 + ba2cd25 commit ce50e82

File tree

4 files changed

+96
-26
lines changed

4 files changed

+96
-26
lines changed

Changelog.md

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,28 @@
11
# Changelog
22

3-
# 3.2.4
3+
## next release
4+
5+
### Features
6+
7+
- Add `PromptTemplate` to support easy creation of multi-modal prompts
8+
9+
### Bugs
10+
11+
- Fix parsing of optimized prompt returned in a `CompletionResponse`
12+
13+
## 3.2.4
414

515
- Make sure `control_factor` gets passed along with `ExplanationRequest`
616

7-
# 3.2.3
17+
## 3.2.3
818

919
- Make sure model name gets passed along for async batch semantic embed
1020

11-
# 3.2.2
21+
## 3.2.2
1222

1323
- Re-release 3.2.1 again because of deployment issue
1424

15-
# 3.2.1
25+
## 3.2.1
1626

1727
- Add progress_bar option to batch semantic embedding API
1828
- Add batch_size option to batch semantic embedding API
@@ -35,7 +45,7 @@
3545

3646
### Bug fixes
3747

38-
- Add missing import of **PromptGranularity** in *__init__.py*.
48+
- Add missing import of **PromptGranularity** in `__init__.py`.
3949

4050
## 3.1.2
4151

aleph_alpha_client/completion.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,16 +239,19 @@ def from_json(json: Dict[str, Any]) -> "CompletionResult":
239239
class CompletionResponse(NamedTuple):
240240
model_version: str
241241
completions: Sequence[CompletionResult]
242-
optimized_prompt: Optional[Sequence[str]] = None
242+
optimized_prompt: Optional[Prompt] = None
243243

244244
@staticmethod
245245
def from_json(json: Dict[str, Any]) -> "CompletionResponse":
246+
optimized_prompt_json = json.get("optimized_prompt")
246247
return CompletionResponse(
247248
model_version=json["model_version"],
248249
completions=[
249250
CompletionResult.from_json(item) for item in json["completions"]
250251
],
251-
optimized_prompt=json.get("optimized_prompt"),
252+
optimized_prompt=Prompt.from_json(optimized_prompt_json)
253+
if optimized_prompt_json
254+
else None,
252255
)
253256

254257
def to_json(self) -> Mapping[str, Any]:

aleph_alpha_client/prompt.py

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ def to_json(self) -> Mapping[str, Any]:
9999
"controls": [c.to_json() for c in self.controls],
100100
}
101101

102+
@staticmethod
103+
def from_json(json: Mapping[str, Any]) -> "Tokens":
104+
return Tokens(tokens=json["data"], controls=[])
105+
102106
@staticmethod
103107
def from_token_ids(token_ids: Sequence[int]) -> "Tokens":
104108
return Tokens(token_ids, [])
@@ -173,20 +177,23 @@ def to_json(self) -> Mapping[str, Any]:
173177
"controls": [control.to_json() for control in self.controls],
174178
}
175179

180+
@staticmethod
181+
def from_json(json: Mapping[str, Any]) -> "Text":
182+
return Text.from_text(json["data"])
183+
176184
@staticmethod
177185
def from_text(text: str) -> "Text":
178186
return Text(text, [])
179187

180188

181-
class Cropping:
189+
class Cropping(NamedTuple):
182190
"""
183191
Describes a quadratic crop of the file.
184192
"""
185193

186-
def __init__(self, upper_left_x: int, upper_left_y: int, size: int):
187-
self.upper_left_x = upper_left_x
188-
self.upper_left_y = upper_left_y
189-
self.size = size
194+
upper_left_x: int
195+
upper_left_y: int
196+
size: int
190197

191198

192199
class ImageControl(NamedTuple):
@@ -254,7 +261,7 @@ def to_json(self) -> Mapping[str, Any]:
254261
return payload
255262

256263

257-
class Image:
264+
class Image(NamedTuple):
258265
"""
259266
An image send as part of a prompt to a model. The image is represented as
260267
base64.
@@ -272,17 +279,11 @@ class Image:
272279
>>> image = Image.from_url(url)
273280
"""
274281

275-
def __init__(
276-
self,
277-
base_64: str,
278-
cropping: Optional[Cropping],
279-
controls: Sequence[ImageControl],
280-
):
281-
# We use a base_64 reperesentation, because we want to embed the image
282-
# into a prompt send in JSON.
283-
self.base_64 = base_64
284-
self.cropping = cropping
285-
self.controls: Sequence[ImageControl] = controls
282+
# We use a base_64 representation, because we want to embed the image
283+
# into a prompt sent in JSON.
284+
base_64: str
285+
cropping: Optional[Cropping]
286+
controls: Sequence[ImageControl]
286287

287288
@classmethod
288289
def from_image_source(
@@ -357,7 +358,9 @@ def from_url_with_cropping(
357358
return cls.from_bytes(bytes, cropping=cropping, controls=controls or [])
358359

359360
@classmethod
360-
def from_file(cls, path: Union[str, Path], controls: Optional[Sequence[ImageControl]] = None):
361+
def from_file(
362+
cls, path: Union[str, Path], controls: Optional[Sequence[ImageControl]] = None
363+
):
361364
"""
362365
Load an image from disk and prepare it to be used in a prompt
363366
If they are not provided then the image will be [center cropped](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.CenterCrop)
@@ -412,6 +415,10 @@ def to_json(self) -> Dict[str, Any]:
412415
"controls": [control.to_json() for control in self.controls],
413416
}
414417

418+
@staticmethod
419+
def from_json(json: Mapping[str, Any]) -> "Image":
420+
return Image(base_64=json["data"], cropping=None, controls=[])
421+
415422
def to_image(self) -> PILImage:
416423
return PIL.Image.open(io.BytesIO(base64.b64decode(self.base_64)))
417424

@@ -464,6 +471,29 @@ def from_tokens(
464471
def to_json(self) -> Sequence[Mapping[str, Any]]:
465472
return [_to_json(item) for item in self.items]
466473

474+
@staticmethod
475+
def from_json(items_json: Sequence[Mapping[str, Any]]) -> "Prompt":
476+
return Prompt(
477+
[
478+
item
479+
for item in (_prompt_item_from_json(item) for item in items_json)
480+
if item
481+
]
482+
)
483+
484+
485+
def _prompt_item_from_json(item: Mapping[str, Any]) -> Optional[PromptItem]:
486+
item_type = item.get("type")
487+
if item_type == "text":
488+
return Text.from_json(item)
489+
if item_type == "image":
490+
return Image.from_json(item)
491+
if item_type == "token_ids":
492+
return Tokens.from_json(item)
493+
# Skip item instead of raising an error to prevent failures of old clients
494+
# when item types are extended
495+
return None
496+
467497

468498
def _to_json(item: PromptItem) -> Mapping[str, Any]:
469499
if hasattr(item, "to_json"):

tests/test_complete.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
import pytest
22
from aleph_alpha_client import AsyncClient, Client
33
from aleph_alpha_client.completion import CompletionRequest
4-
from aleph_alpha_client.prompt import ControlTokenOverlap, Prompt, Text, TextControl
4+
from aleph_alpha_client.prompt import (
5+
ControlTokenOverlap,
6+
Image,
7+
Prompt,
8+
Text,
9+
TextControl,
10+
Tokens,
11+
)
512

613
from tests.common import (
714
sync_client,
815
async_client,
916
model_name,
17+
prompt_image,
1018
)
1119

1220

@@ -72,3 +80,22 @@ def test_complete_with_token_ids(sync_client: Client, model_name: str):
7280

7381
assert len(response.completions) == 1
7482
assert response.model_version is not None
83+
84+
85+
@pytest.mark.system_test
86+
def test_complete_with_optimized_prompt(
87+
sync_client: Client, model_name: str, prompt_image: Image
88+
):
89+
prompt_text = " Hello World! "
90+
prompt_tokens = Tokens.from_token_ids([1, 2])
91+
request = CompletionRequest(
92+
prompt=Prompt([Text.from_text(prompt_text), prompt_image, prompt_tokens]),
93+
maximum_tokens=5,
94+
)
95+
96+
response = sync_client.complete(request, model=model_name)
97+
98+
assert response.optimized_prompt
99+
assert response.optimized_prompt.items[0] == Text.from_text(prompt_text.strip())
100+
assert response.optimized_prompt.items[2] == prompt_tokens
101+
assert isinstance(response.optimized_prompt.items[1], Image)

0 commit comments

Comments
 (0)