|
| 1 | +#!/usr/bin/env python3 |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import asyncio |
| 6 | +import pathlib |
| 7 | +import uuid |
| 8 | + |
| 9 | +from yandex_cloud_ml_sdk import AsyncYCloudML |
| 10 | + |
| 11 | + |
| 12 | +def local_path(path: str) -> pathlib.Path: |
| 13 | + return pathlib.Path(__file__).parent / path |
| 14 | + |
| 15 | + |
| 16 | +async def get_datasets(sdk, name, dataset_function): |
| 17 | + """ |
| 18 | + This function represents getting or creating datasets object. |
| 19 | +
|
| 20 | + In real life you could use just a datasets ids, for example: |
| 21 | +
|
| 22 | + ``` |
| 23 | + dataset = sdk.datasets.get("some_id") |
| 24 | + tuning_task = base_model.tune_deferred( |
| 25 | + "dataset_id", |
| 26 | + validation_datasets=dataset |
| 27 | + ) |
| 28 | + ``` |
| 29 | + """ |
| 30 | + |
| 31 | + async for dataset in sdk.datasets.list(status='READY', name_pattern=name): |
| 32 | + print(f'using old dataset {dataset=}') |
| 33 | + break |
| 34 | + else: |
| 35 | + print('no old datasets found, creating new one') |
| 36 | + dataset_draft = dataset_function.draft_from_path( |
| 37 | + path=local_path(f'{name}.jsonlines'), |
| 38 | + upload_format='jsonlines', |
| 39 | + name=name, |
| 40 | + ) |
| 41 | + |
| 42 | + dataset = await dataset_draft.upload() |
| 43 | + print(f'created new dataset {dataset=}') |
| 44 | + |
| 45 | + return dataset, dataset |
| 46 | + |
| 47 | + |
| 48 | +async def main() -> None: |
| 49 | + sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64') |
| 50 | + sdk.setup_default_logging() |
| 51 | + base_model = sdk.models.text_embeddings('yandexgpt-lite') |
| 52 | + |
| 53 | + for name, tune_type, dataset_function in [ |
| 54 | + ('embeddings_pair', 'pair', sdk.datasets.text_embeddings_pair), |
| 55 | + ('embeddings_triplet', 'triplet', sdk.datasets.text_embeddings_triplet), |
| 56 | + ]: |
| 57 | + train_dataset, validation_dataset = await get_datasets(sdk, name, dataset_function) |
| 58 | + result = await base_model.run("hi") |
| 59 | + print(f'pretrain model inference result: {result}') |
| 60 | + |
| 61 | + # `.tune(...)` is a shortcut for: |
| 62 | + # tuning_task = await base_model.tune_deferred(...) |
| 63 | + # new_model = await tuning_task.wait(...) |
| 64 | + # But it gives you less control on tune canceling and |
| 65 | + # reporting. |
| 66 | + new_model = await base_model.tune( |
| 67 | + train_dataset, |
| 68 | + validation_datasets=validation_dataset, |
| 69 | + embeddings_tune_type=tune_type, |
| 70 | + name=str(uuid.uuid4()) |
| 71 | + ) |
| 72 | + print(f'resulting {new_model}') |
| 73 | + |
| 74 | + # you can save model.uri somewhere and reuse it later |
| 75 | + tuned_uri = new_model.uri |
| 76 | + model = sdk.models.text_embeddings(tuned_uri) |
| 77 | + result = await model.run("hi") |
| 78 | + print(f'posttrain model inference result: {result}') |
| 79 | + |
| 80 | + |
| 81 | +if __name__ == '__main__': |
| 82 | + asyncio.run(main()) |
0 commit comments