Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions examples/async/text_classifiers/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@


async def main() -> None:
sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64')
sdk = AsyncYCloudML(folder_id='yc.fomo.storage.prod.service')
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really like this folder_id in the code.
I also don't like old version(

It is okay for now, but I will bump the related ticket

sdk.setup_default_logging()

model = sdk.models.text_classifiers('cls://b1ghsjum2v37c2un8h64/bt14f74au2ap3q0f9ou4')
model = sdk.models.text_classifiers(model_name='yandexgpt-lite', model_version='rc@tamrap1sjscq6e9flit3p')

# result will contain predictions with a predefined classes
# and most powerful prediction will be "mathematics": 0.92
Expand All @@ -20,6 +20,7 @@ async def main() -> None:
for prediction in result:
print(prediction)

print("f{result.input_tokens=}")

if __name__ == '__main__':
asyncio.run(main())
4 changes: 3 additions & 1 deletion examples/async/text_classifiers/run_few_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


async def main() -> None:
sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64')
sdk = AsyncYCloudML(folder_id='yc.fomo.storage.prod.service')
sdk.setup_default_logging()

model = sdk.models.text_classifiers("yandexgpt").configure(
Expand Down Expand Up @@ -47,6 +47,8 @@ async def main() -> None:
for prediction in result:
print(prediction)

print("f{result.input_tokens=}")


if __name__ == '__main__':
asyncio.run(main())
5 changes: 3 additions & 2 deletions examples/sync/text_classifiers/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@


def main() -> None:
sdk = YCloudML(folder_id='b1ghsjum2v37c2un8h64')
sdk = YCloudML(folder_id='yc.fomo.storage.prod.service')
sdk.setup_default_logging()

model = sdk.models.text_classifiers('cls://b1ghsjum2v37c2un8h64/bt14f74au2ap3q0f9ou4')
model = sdk.models.text_classifiers(model_name='yandexgpt-lite', model_version='rc@tamrap1sjscq6e9flit3p')

# result will contain predictions with a predefined classes
# and most powerful prediction will be "mathematics": 0.92
Expand All @@ -18,6 +18,7 @@ def main() -> None:
for prediction in result:
print(prediction)

print("f{result.input_tokens=}")

if __name__ == '__main__':
main()
3 changes: 2 additions & 1 deletion examples/sync/text_classifiers/run_few_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def main() -> None:
sdk = YCloudML(folder_id='b1ghsjum2v37c2un8h64')
sdk = YCloudML(folder_id='yc.fomo.storage.prod.service')
sdk.setup_default_logging()

model = sdk.models.text_classifiers("yandexgpt").configure(
Expand Down Expand Up @@ -45,6 +45,7 @@ def main() -> None:
for prediction in result:
print(prediction)

print("f{result.input_tokens=}")

if __name__ == '__main__':
main()
3 changes: 3 additions & 0 deletions src/yandex_cloud_ml_sdk/_models/text_classifiers/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
class TextClassifiersModelResultBase(BaseResult, Sequence, Generic[TextClassificationResponseT]):
predictions: tuple[TextClassificationLabel, ...]
model_version: str
#: Number of input tokens provided to the model.
input_tokens: int

@classmethod
def _from_proto(cls, *, proto: ProtoMessage, sdk: BaseSDK) -> Self: # pylint: disable=unused-argument
Expand All @@ -42,6 +44,7 @@ def _from_proto(cls, *, proto: ProtoMessage, sdk: BaseSDK) -> Self: # pylint: d
return cls(
predictions=predictions,
model_version=proto.model_version,
input_tokens = proto.input_tokens
)

def __len__(self) -> int:
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def patch_operation(request, monkeypatch):

@pytest.fixture(name='folder_id')
def fixture_folder_id():
return 'b1ghsjum2v37c2un8h64'
return 'yc.fomo.storage.prod.service'


@pytest.fixture(name='servicers')
Expand Down
137 changes: 113 additions & 24 deletions tests/models/cassettes/test_text_classifiers/test_run.gprc.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@
"module": "yandex.cloud.endpoint.api_endpoint_service_pb2",
"message": {
"endpoints": [
{
"id": "ai-assistants",
"address": "assistant.api.cloud.yandex.net:443"
},
{
"id": "ai-files",
"address": "assistant.api.cloud.yandex.net:443"
},
{
"id": "ai-foundation-models",
"address": "llm.api.cloud.yandex.net:443"
Expand Down Expand Up @@ -71,6 +79,10 @@
"id": "backup",
"address": "backup.api.cloud.yandex.net:443"
},
{
"id": "baremetal",
"address": "baremetal.api.cloud.yandex.net:443"
},
{
"id": "billing",
"address": "billing.api.cloud.yandex.net:443"
Expand All @@ -91,6 +103,26 @@
"id": "certificate-manager-data",
"address": "data.certificate-manager.api.cloud.yandex.net:443"
},
{
"id": "certificate-manager-private-ca",
"address": "private-ca.certificate-manager.api.cloud.yandex.net:443"
},
{
"id": "certificate-manager-private-ca-data",
"address": "data.private-ca.certificate-manager.api.cloud.yandex.net:443"
},
{
"id": "cic",
"address": "cic.api.cloud.yandex.net:443"
},
{
"id": "cloud-registry",
"address": "registry.api.cloud.yandex.net:443"
},
{
"id": "cloudapps",
"address": "cloudapps.api.cloud.yandex.net:443"
},
{
"id": "cloudbackup",
"address": "backup.api.cloud.yandex.net:443"
Expand All @@ -99,6 +131,10 @@
"id": "clouddesktops",
"address": "clouddesktops.api.cloud.yandex.net:443"
},
{
"id": "cloudrouter",
"address": "cloudrouter.api.cloud.yandex.net:443"
},
{
"id": "cloudvideo",
"address": "video.api.cloud.yandex.net:443"
Expand Down Expand Up @@ -135,6 +171,18 @@
"id": "endpoint",
"address": "api.cloud.yandex.net:443"
},
{
"id": "fomo-dataset",
"address": "fomo-dataset.api.cloud.yandex.net:443"
},
{
"id": "fomo-tuning",
"address": "fomo-tuning.api.cloud.yandex.net:443"
},
{
"id": "gitlab",
"address": "gitlab.api.cloud.yandex.net:443"
},
{
"id": "iam",
"address": "iam.api.cloud.yandex.net:443"
Expand Down Expand Up @@ -163,6 +211,10 @@
"id": "kms-crypto",
"address": "kms.yandex:443"
},
{
"id": "kspm",
"address": "kspm.api.cloud.yandex.net:443"
},
{
"id": "load-balancer",
"address": "load-balancer.api.cloud.yandex.net:443"
Expand Down Expand Up @@ -195,6 +247,10 @@
"id": "logging",
"address": "logging.api.cloud.yandex.net:443"
},
{
"id": "managed-airflow",
"address": "airflow.api.cloud.yandex.net:443"
},
{
"id": "managed-clickhouse",
"address": "mdb.api.cloud.yandex.net:443"
Expand All @@ -215,6 +271,10 @@
"id": "managed-kubernetes",
"address": "mks.api.cloud.yandex.net:443"
},
{
"id": "managed-metastore",
"address": "metastore.api.cloud.yandex.net:443"
},
{
"id": "managed-mongodb",
"address": "mdb.api.cloud.yandex.net:443"
Expand All @@ -235,14 +295,34 @@
"id": "managed-redis",
"address": "mdb.api.cloud.yandex.net:443"
},
{
"id": "managed-spark",
"address": "spark.api.cloud.yandex.net:443"
},
{
"id": "managed-spqr",
"address": "mdb.api.cloud.yandex.net:443"
},
{
"id": "managed-sqlserver",
"address": "mdb.api.cloud.yandex.net:443"
},
{
"id": "managed-trino",
"address": "trino.api.cloud.yandex.net:443"
},
{
"id": "managed-ytsaurus",
"address": "ytsaurus.api.cloud.yandex.net:443"
},
{
"id": "marketplace",
"address": "marketplace.api.cloud.yandex.net:443"
},
{
"id": "marketplace-pim",
"address": "marketplace.api.cloud.yandex.net:443"
},
{
"id": "mdb-clickhouse",
"address": "mdb.api.cloud.yandex.net:443"
Expand All @@ -267,6 +347,10 @@
"id": "mdb-redis",
"address": "mdb.api.cloud.yandex.net:443"
},
{
"id": "mdb-spqr",
"address": "mdb.api.cloud.yandex.net:443"
},
{
"id": "mdbproxy",
"address": "mdbproxy.api.cloud.yandex.net:443"
Expand All @@ -287,6 +371,14 @@
"id": "organizationmanager",
"address": "organization-manager.api.cloud.yandex.net:443"
},
{
"id": "quota-manager",
"address": "quota-manager.api.cloud.yandex.net:443"
},
{
"id": "quotamanager",
"address": "quota-manager.api.cloud.yandex.net:443"
},
{
"id": "resource-manager",
"address": "resource-manager.api.cloud.yandex.net:443"
Expand All @@ -295,6 +387,10 @@
"id": "resourcemanager",
"address": "resource-manager.api.cloud.yandex.net:443"
},
{
"id": "searchapi",
"address": "searchapi.api.cloud.yandex.net:443"
},
{
"id": "serialssh",
"address": "serialssh.cloud.yandex.net:9600"
Expand All @@ -307,6 +403,10 @@
"id": "serverless-containers",
"address": "serverless-containers.api.cloud.yandex.net:443"
},
{
"id": "serverless-eventrouter",
"address": "serverless-eventrouter.api.cloud.yandex.net:443"
},
{
"id": "serverless-functions",
"address": "serverless-functions.api.cloud.yandex.net:443"
Expand All @@ -319,6 +419,14 @@
"id": "serverless-triggers",
"address": "serverless-triggers.api.cloud.yandex.net:443"
},
{
"id": "serverless-workflows",
"address": "serverless-workflows.api.cloud.yandex.net:443"
},
{
"id": "serverlesseventrouter-events",
"address": "events.eventrouter.serverless.yandexcloud.net:443"
},
{
"id": "smart-captcha",
"address": "smartcaptcha.api.cloud.yandex.net:443"
Expand Down Expand Up @@ -356,7 +464,7 @@
"cls": "TextClassificationRequest",
"module": "yandex.cloud.ai.foundation_models.v1.text_classification.text_classification_service_pb2",
"message": {
"modelUri": "cls://b1ghsjum2v37c2un8h64/bt14f74au2ap3q0f9ou4",
"modelUri": "cls://yc.fomo.storage.prod.service/yandexgpt-lite/rc@tamrap1sjscq6e9flit3p",
"text": "hello"
}
},
Expand All @@ -366,30 +474,11 @@
"message": {
"predictions": [
{
"label": "computer_science",
"confidence": 0.2230224609375
},
{
"label": "mathematics",
"confidence": 0.0177764892578125
},
{
"label": "physics",
"confidence": 0.03289794921875
},
{
"label": "quantitative_biology",
"confidence": 0.11932373046875
},
{
"label": "quantitative_finance",
"confidence": 0.01348114013671875
},
{
"label": "statistics",
"confidence": 0.044097900390625
"label": "\u043e\u043f\u0435\u0440\u0430\u0442\u043e\u0440",
"confidence": 0.83447265625
}
]
],
"inputTokens": "2"
}
}
}
Expand Down
Loading