Skip to content

Commit 363776a

Browse files
authored
Add tuning for embeddings (#69)
1 parent e227e93 commit 363776a

File tree

10 files changed

+442
-6
lines changed

10 files changed

+442
-6
lines changed
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#!/usr/bin/env python3
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
import pathlib
7+
import uuid
8+
9+
from yandex_cloud_ml_sdk import AsyncYCloudML
10+
11+
12+
def local_path(path: str) -> pathlib.Path:
13+
return pathlib.Path(__file__).parent / path
14+
15+
16+
async def get_datasets(sdk, name, dataset_function):
17+
"""
18+
This function represents getting or creating datasets object.
19+
20+
In real life you could use just a datasets ids, for example:
21+
22+
```
23+
dataset = sdk.datasets.get("some_id")
24+
tuning_task = base_model.tune_deferred(
25+
"dataset_id",
26+
validation_datasets=dataset
27+
)
28+
```
29+
"""
30+
31+
async for dataset in sdk.datasets.list(status='READY', name_pattern=name):
32+
print(f'using old dataset {dataset=}')
33+
break
34+
else:
35+
print('no old datasets found, creating new one')
36+
dataset_draft = dataset_function.draft_from_path(
37+
path=local_path(f'{name}.jsonlines'),
38+
upload_format='jsonlines',
39+
name=name,
40+
)
41+
42+
dataset = await dataset_draft.upload()
43+
print(f'created new dataset {dataset=}')
44+
45+
return dataset, dataset
46+
47+
48+
async def main() -> None:
49+
sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64')
50+
sdk.setup_default_logging()
51+
base_model = sdk.models.text_embeddings('yandexgpt-lite')
52+
53+
for name, tune_type, dataset_function in [
54+
('embeddings_pair', 'pair', sdk.datasets.text_embeddings_pair),
55+
('embeddings_triplet', 'triplet', sdk.datasets.text_embeddings_triplet),
56+
]:
57+
train_dataset, validation_dataset = await get_datasets(sdk, name, dataset_function)
58+
result = await base_model.run("hi")
59+
print(f'pretrain model inference result: {result}')
60+
61+
# `.tune(...)` is a shortcut for:
62+
# tuning_task = await base_model.tune_deferred(...)
63+
# new_model = await tuning_task.wait(...)
64+
# But it gives you less control on tune canceling and
65+
# reporting.
66+
new_model = await base_model.tune(
67+
train_dataset,
68+
validation_datasets=validation_dataset,
69+
embeddings_tune_type=tune_type,
70+
name=str(uuid.uuid4())
71+
)
72+
print(f'resulting {new_model}')
73+
74+
# you can save model.uri somewhere and reuse it later
75+
tuned_uri = new_model.uri
76+
model = sdk.models.text_embeddings(tuned_uri)
77+
result = await model.run("hi")
78+
print(f'posttrain model inference result: {result}')
79+
80+
81+
if __name__ == '__main__':
82+
asyncio.run(main())
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"anchor": "hello", "positive": "hi"}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"anchor": "hello", "positive": "hi", "negative": "bye"}

examples/sync/tuning/embeddings.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#!/usr/bin/env python3
2+
3+
from __future__ import annotations
4+
5+
import pathlib
6+
import uuid
7+
8+
from yandex_cloud_ml_sdk import YCloudML
9+
10+
11+
def local_path(path: str) -> pathlib.Path:
12+
return pathlib.Path(__file__).parent / path
13+
14+
15+
def get_datasets(sdk, name, dataset_function):
16+
"""
17+
This function represents getting or creating datasets object.
18+
19+
In real life you could use just a datasets ids, for example:
20+
21+
```
22+
dataset = sdk.datasets.get("some_id")
23+
tuning_task = base_model.tune_deferred(
24+
"dataset_id",
25+
validation_datasets=dataset
26+
)
27+
```
28+
"""
29+
30+
for dataset in sdk.datasets.list(status='READY', name_pattern=name):
31+
print(f'using old dataset {dataset=}')
32+
break
33+
else:
34+
print('no old datasets found, creating new one')
35+
dataset_draft = dataset_function.draft_from_path(
36+
path=local_path(f'{name}.jsonlines'),
37+
upload_format='jsonlines',
38+
name=name,
39+
)
40+
41+
dataset = dataset_draft.upload()
42+
print(f'created new dataset {dataset=}')
43+
44+
return dataset, dataset
45+
46+
47+
def main() -> None:
48+
sdk = YCloudML(folder_id='b1ghsjum2v37c2un8h64')
49+
sdk.setup_default_logging()
50+
base_model = sdk.models.text_embeddings('yandexgpt-lite')
51+
52+
for name, tune_type, dataset_function in [
53+
('embeddings_pair', 'pair', sdk.datasets.text_embeddings_pair),
54+
('embeddings_triplet', 'triplet', sdk.datasets.text_embeddings_triplet),
55+
]:
56+
train_dataset, validation_dataset = get_datasets(sdk, name, dataset_function)
57+
result = base_model.run("hi")
58+
print(f'pretrain model inference result: {result}')
59+
60+
# `.tune(...)` is a shortcut for:
61+
# tuning_task = base_model.tune_deferred(...)
62+
# new_model = tuning_task.wait(...)
63+
# But it gives you less control on tune canceling and
64+
# reporting.
65+
new_model = base_model.tune(
66+
train_dataset,
67+
validation_datasets=validation_dataset,
68+
embeddings_tune_type=tune_type,
69+
name=str(uuid.uuid4())
70+
)
71+
print(f'resulting {new_model}')
72+
73+
# you can save model.uri somewhere and reuse it later
74+
tuned_uri = new_model.uri
75+
model = sdk.models.text_embeddings(tuned_uri)
76+
result = model.run("hi")
77+
print(f'posttrain model inference result: {result}')
78+
79+
80+
if __name__ == '__main__':
81+
main()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"anchor": "hello", "positive": "hi"}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{"anchor": "hello", "positive": "hi", "negative": "bye"}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ classifiers = [
3333
requires-python = ">=3.9"
3434
dynamic = ["version"]
3535
dependencies = [
36-
"yandexcloud>=0.334.0",
36+
"yandexcloud>=0.335.0",
3737
"grpcio>=1.70.0",
3838
"get-annotations",
3939
"httpx>=0.27,<1",

0 commit comments

Comments
 (0)