Skip to content

Commit 84b956f

Browse files
committed
Add possibility to use multimodal messages in chat domain
1 parent 50a6708 commit 84b956f

File tree

8 files changed

+222
-12
lines changed

8 files changed

+222
-12
lines changed

examples/async/chat/example.png

23.4 KB
Loading

examples/async/chat/multimodal.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#!/usr/bin/env python3
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
import base64
7+
import pathlib
8+
9+
from yandex_cloud_ml_sdk import AsyncYCloudML
10+
11+
12+
def get_image_base64(image_path: pathlib.Path | None = None) -> str:
    """Return the base64-encoded contents of an image file as a str.

    :param image_path: path of the image to encode; when omitted, falls
        back to the bundled ``example.png`` next to this script, which
        preserves the original no-argument behavior.
    :returns: ASCII base64 text suitable for a ``data:`` URI payload.
    """
    if image_path is None:
        image_path = pathlib.Path(__file__).parent / 'example.png'
    image_data = image_path.read_bytes()
    # b64encode returns bytes; decode to str for embedding in an f-string URI
    return base64.b64encode(image_data).decode('utf-8')
17+
18+
19+
async def main() -> None:
    """Send a multimodal (text + image) chat request and print the reply."""
    sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64')
    sdk.setup_default_logging()

    # Currently this is the only model that supports image processing.
    model = sdk.chat.completions('gemma-3-27b-it')

    # A multimodal message mixes text and image parts inside one 'content' list.
    image_part = {
        'type': 'image_url',
        'image_url': {'url': f'data:image/png;base64,{get_image_base64()}'},
    }
    text_part = {
        'type': 'text', 'text': "What is depicted in the following image",
    }
    messages = [{'role': 'user', 'content': [text_part, image_part]}]

    result = await model.run(messages)
    print(result.text)


if __name__ == '__main__':
    asyncio.run(main())

examples/sync/chat/example.png

23.4 KB
Loading

examples/sync/chat/multimodal.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#!/usr/bin/env python3
2+
3+
from __future__ import annotations
4+
5+
import base64
6+
import pathlib
7+
8+
from yandex_cloud_ml_sdk import YCloudML
9+
10+
11+
def get_image_base64(image_path: pathlib.Path | None = None) -> str:
    """Return the base64-encoded contents of an image file as a str.

    :param image_path: path of the image to encode; when omitted, falls
        back to the bundled ``example.png`` next to this script, which
        preserves the original no-argument behavior.
    :returns: ASCII base64 text suitable for a ``data:`` URI payload.
    """
    if image_path is None:
        image_path = pathlib.Path(__file__).parent / 'example.png'
    image_data = image_path.read_bytes()
    # b64encode returns bytes; decode to str for embedding in an f-string URI
    return base64.b64encode(image_data).decode('utf-8')
16+
17+
18+
def main() -> None:
    """Send a multimodal (text + image) chat request and print the reply."""
    sdk = YCloudML(folder_id='b1ghsjum2v37c2un8h64')
    sdk.setup_default_logging()

    # Currently this is the only model that supports image processing.
    model = sdk.chat.completions('gemma-3-27b-it')

    # A multimodal message mixes text and image parts inside one 'content' list.
    image_part = {
        'type': 'image_url',
        'image_url': {'url': f'data:image/png;base64,{get_image_base64()}'},
    }
    text_part = {
        'type': 'text', 'text': "What is depicted in the following image",
    }
    messages = [{'role': 'user', 'content': [text_part, image_part]}]

    result = model.run(messages)
    print(result.text)


if __name__ == '__main__':
    main()

src/yandex_cloud_ml_sdk/_chat/completions/message.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

3-
from collections.abc import Iterable
4-
from typing import TypedDict, Union, cast
3+
from collections.abc import Iterable, Sequence
4+
from typing import Literal, TypedDict, Union, cast
55

66
from typing_extensions import NotRequired, Required
77

@@ -20,7 +20,26 @@ class ChatFunctionResultMessageDict(TypedDict):
2020
content: Required[str]
2121

2222

23-
ChatCompletionsMessageType = Union[MessageType, ChatFunctionResultMessageDict, MessageInputType]
23+
class ImageUrlDict(TypedDict):
    """Payload of an ``image_url`` content part: where the image lives.

    In the bundled examples ``url`` holds a ``data:image/png;base64,...``
    URI; presumably plain http(s) URLs are accepted too — TODO confirm
    against the service API.
    """
    url: str
25+
26+
27+
class ImageUrlContent(TypedDict):
    """Image part of a multimodal message ``content`` list.

    ``type`` is the discriminator that distinguishes this part from a
    text part; the actual image location sits in ``image_url``.
    """
    type: Literal['image_url']
    image_url: ImageUrlDict
30+
31+
32+
class TextContent(TypedDict):
    """Text part of a multimodal message ``content`` list.

    ``type`` is the discriminator that distinguishes this part from an
    image part; the message text itself sits in ``text``.
    """
    type: Literal['text']
    text: str
35+
36+
37+
class MultimodalMessageDict(TypedDict):
38+
role: NotRequired[str]
39+
content: Sequence[ImageUrlDict | TextContent]
40+
41+
42+
ChatCompletionsMessageType = Union[MessageType, ChatFunctionResultMessageDict, MessageInputType, MultimodalMessageDict]
2443
ChatMessageInputType = Union[ChatCompletionsMessageType, Iterable[ChatCompletionsMessageType]]
2544

2645

@@ -43,41 +62,46 @@ def message_to_json(message: ChatCompletionsMessageType, tool_name_ids: dict[str
4362
"content": message.text,
4463
"role": message.role,
4564
}
65+
4666
if isinstance(message, dict):
47-
text = message.get('text') or message.get('content', '')
67+
role: str | None = message.get('role')
68+
content: Sequence | str | None = message.get('content') # type: ignore[assignment]
69+
if isinstance(content, Sequence):
70+
return {
71+
'role': role or 'user',
72+
'content': list(content),
73+
}
74+
75+
text: str | None = message.get('text') or content # type: ignore[assignment]
4876
assert isinstance(text, str)
4977

5078
if tool_call_id := message.get('tool_call_id'):
5179
assert isinstance(tool_call_id, str)
5280
message = cast(ChatFunctionResultMessageDict, message)
53-
role = message.get('role', 'tool')
5481
return {
55-
'role': role,
82+
'role': role or 'tool',
5683
'content': text,
5784
'tool_call_id': tool_call_id,
5885
}
5986

6087
if tool_calls := message.get('tool_calls'):
6188
tool_calls = cast(JsonObject, tool_calls)
62-
role = message.get('role', 'assistant')
6389
return {
6490
'tool_calls': tool_calls,
65-
'role': role,
91+
'role': role or 'assistant',
6692
}
6793

6894
if text:
6995
message = cast(TextMessageDict, message)
70-
role = message.get('role', 'user')
7196
return {
7297
'content': text,
73-
'role': role
98+
'role': role or 'user'
7499
}
75100

76101
if tool_results := message.get('tool_results'):
77102
assert isinstance(tool_results, list)
78103
message = cast(FunctionResultMessageDict, message)
79104

80-
role = message.get('role', 'tool')
81105
result: list[JsonObject] = []
82106
for tool_result in tool_results:
83107
tool_result = cast(ToolResultDictType, tool_result)
@@ -91,7 +115,7 @@ def message_to_json(message: ChatCompletionsMessageType, tool_name_ids: dict[str
91115
)
92116

93117
result.append({
94-
'role': role,
118+
'role': role or 'tool',
95119
'content': content,
96120
'tool_call_id': id_,
97121
})

tests/chat/cassettes/test_completions/test_multimodal.yaml

Lines changed: 56 additions & 0 deletions
Large diffs are not rendered by default.

tests/chat/example.png

23.4 KB
Loading

tests/chat/test_completions.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

3+
import base64
34
import json
5+
import pathlib
46
from typing import cast
57

68
import pytest
@@ -359,3 +361,30 @@ async def test_tool_choice(async_sdk: AsyncYCloudML, tool, schema) -> None:
359361
model = model.configure(tool_choice=None) # type: ignore[arg-type]
360362
result = await model.run(message)
361363
assert result.status.name == 'TOOL_CALLS'
364+
365+
366+
async def test_multimodal(async_sdk: AsyncYCloudML) -> None:
    """Smoke-test a text+image request against the chat completions domain."""
    model = async_sdk.chat.completions('gemma-3-27b-it')

    image_bytes = (pathlib.Path(__file__).parent / 'example.png').read_bytes()
    image = base64.b64encode(image_bytes).decode('utf-8')

    text_part = {
        'type': 'text', 'text': "What is depicted in the following image",
    }
    image_part = {
        'type': 'image_url',
        'image_url': {'url': f'data:image/png;base64,{image}'},
    }
    request = [{'role': 'user', 'content': [text_part, image_part]}]

    result = await model.run(request)
    assert 'bricks' in result.text

0 commit comments

Comments
 (0)