forked from yandex-cloud/yandex-ai-studio-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathresult.py
More file actions
72 lines (53 loc) · 2.11 KB
/
result.py
File metadata and controls
72 lines (53 loc) · 2.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Sequence, TypeVar, cast, overload
from typing_extensions import Self
# pylint: disable-next=no-name-in-module
from yandex.cloud.ai.foundation_models.v1.text_classification.text_classification_service_pb2 import (
FewShotTextClassificationResponse, TextClassificationResponse
)
from yandex_cloud_ml_sdk._types.result import BaseResult, ProtoMessage
from .types import TextClassificationLabel
if TYPE_CHECKING:
from yandex_cloud_ml_sdk._sdk import BaseSDK
# Constrained type variable: every concrete result class below is
# parametrised with exactly one of these two protobuf response messages.
TextClassificationResponseT = TypeVar(
    'TextClassificationResponseT',
    TextClassificationResponse,
    FewShotTextClassificationResponse,
)
@dataclass(frozen=True)
class TextClassifiersModelResultBase(
    BaseResult,
    Sequence[TextClassificationLabel],
    Generic[TextClassificationResponseT],
):
    """Immutable result of a text classification request.

    Holds the labels predicted by the model together with response metadata.
    The object itself is a read-only sequence over :attr:`predictions`:
    ``len(result)``, ``result[i]``, ``result[a:b]`` and iteration all delegate
    to that tuple (iteration, ``in``, ``index`` and ``count`` come from the
    ``Sequence`` mixin and are typed as yielding ``TextClassificationLabel``).
    """

    #: doc Labels predicted by the model, in the order they appear in the response.
    predictions: tuple[TextClassificationLabel, ...]
    #: doc Model version reported in the response.
    model_version: str
    #: doc Number of input tokens provided to the model.
    input_tokens: int

    @classmethod
    def _from_proto(cls, *, proto: ProtoMessage, sdk: BaseSDK) -> Self:  # pylint: disable=unused-argument
        """Build a result from a classification response message.

        Each entry of ``proto.predictions`` is converted into a
        :class:`TextClassificationLabel` carrying its ``label`` and
        ``confidence``; ``model_version`` and ``input_tokens`` are copied as is.

        :param proto: response message, expected to be one of the
            ``TextClassificationResponseT`` constraint types.
        :param sdk: not used here; kept, presumably, to match the shared
            ``BaseResult._from_proto`` interface.
        """
        # ``cast`` only narrows the type for the checker; it performs no runtime check.
        proto = cast(TextClassificationResponseT, proto)
        predictions = tuple(
            TextClassificationLabel(label=p.label, confidence=p.confidence)
            for p in proto.predictions
        )
        return cls(
            predictions=predictions,
            model_version=proto.model_version,
            input_tokens=proto.input_tokens,
        )

    def __len__(self) -> int:
        """Return the number of predicted labels."""
        return len(self.predictions)

    @overload
    def __getitem__(self, index: int, /) -> TextClassificationLabel:
        ...

    @overload
    def __getitem__(self, slice_: slice, /) -> tuple[TextClassificationLabel, ...]:
        ...

    def __getitem__(self, index, /):
        """Return one label by position, or a tuple of labels for a slice.

        An out-of-range integer index raises ``IndexError`` from the underlying
        tuple, which is also what ends the ``Sequence`` mixin's iteration.
        """
        return self.predictions[index]
@dataclass(frozen=True)
class TextClassifiersModelResult(TextClassifiersModelResultBase[TextClassificationResponse]):
    """Classification result built from a ``TextClassificationResponse`` message."""
@dataclass(frozen=True)
class FewShotTextClassifiersModelResult(TextClassifiersModelResultBase[FewShotTextClassificationResponse]):
    """Classification result built from a ``FewShotTextClassificationResponse`` message."""