Skip to content

Commit 43bdeb0

Browse files
authored
Add docstrings for _models/text_classifiers (#120)
1 parent ed3d3f3 commit 43bdeb0

File tree

5 files changed

+111
-1
lines changed

5 files changed

+111
-1
lines changed

src/yandex_cloud_ml_sdk/_models/text_classifiers/function.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,36 @@
33
from typing_extensions import override
44

55
from yandex_cloud_ml_sdk._types.function import BaseModelFunction, ModelTypeT
6+
from yandex_cloud_ml_sdk._utils.doc import doc_from
67

78
from .model import AsyncTextClassifiersModel, TextClassifiersModel
89

910

1011
class BaseTextClassifiers(BaseModelFunction[ModelTypeT]):
12+
"""A class for text classifiers.
13+
14+
It provides a common interface for text classification models and
15+
constructs the model URI based on the provided model name and version.
16+
"""
1117
@override
1218
def __call__(
1319
self,
1420
model_name: str,
1521
*,
1622
model_version: str = 'latest',
1723
):
24+
"""Call the text classification model.
25+
26+
Constructs the URI for the model based on the provided model's name
27+
and version. If the name contains ``://``, it is treated as a
28+
complete URI. Otherwise, it looks up the model name in
29+
the well-known names dictionary. After the lookup, a URI is always constructed
30+
in the form ``cls://<folder_id>/<model>/<version>``.
31+
32+
:param model_name: the name or URI of the model to call.
33+
:param model_version: the version of the model to be used.
34+
Defaults to 'latest'.
35+
"""
1836
if '://' in model_name:
1937
uri = model_name
2038
else:
@@ -27,9 +45,10 @@ def __call__(
2745
)
2846

2947

48+
@doc_from(BaseTextClassifiers)
3049
class TextClassifiers(BaseTextClassifiers):
3150
_model_type = TextClassifiersModel
3251

33-
52+
@doc_from(BaseTextClassifiers)
3453
class AsyncTextClassifiers(BaseTextClassifiers):
3554
_model_type = AsyncTextClassifiersModel

src/yandex_cloud_ml_sdk/_models/text_classifiers/model.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from yandex_cloud_ml_sdk._types.tuning.optimizers import BaseOptimizer
2424
from yandex_cloud_ml_sdk._types.tuning.schedulers import BaseScheduler
2525
from yandex_cloud_ml_sdk._types.tuning.tuning_types import BaseTuningType
26+
from yandex_cloud_ml_sdk._utils.doc import doc_from
2627
from yandex_cloud_ml_sdk._utils.sync import run_sync
2728

2829
from .config import TextClassifiersModelConfig
@@ -40,6 +41,11 @@ class BaseTextClassifiersModel(
4041
TuningTaskTypeT
4142
],
4243
):
44+
"""
45+
A class for text classifiers models.
46+
It provides the foundational structure for building text classification models,
47+
including configuration and execution of classification tasks.
48+
"""
4349
_config_type = TextClassifiersModelConfig
4450
_result_type = TextClassifiersModelResultBase
4551
_tuning_params_type = TextClassifiersModelTuneParams
@@ -134,6 +140,7 @@ async def _run_few_shot(
134140
return FewShotTextClassifiersModelResult._from_proto(proto=response, sdk=self._sdk)
135141

136142

143+
@doc_from(BaseTextClassifiersModel)
137144
class AsyncTextClassifiersModel(BaseTextClassifiersModel[AsyncTuningTask['AsyncTextClassifiersModel']]):
138145
_tune_operation_type = AsyncTuningTask['AsyncTextClassifiersModel']
139146

@@ -143,6 +150,18 @@ async def run(
143150
*,
144151
timeout: float = 60,
145152
) -> TextClassifiersModelResultBase:
153+
"""Execute the text classification on the provided input text.
154+
155+
If only labels are specified, apply a zero-shot classifier.
156+
If samples are also specified, a few-shot classifier is applied.
157+
If neither is specified, the classify method is used; it is only available for pre-trained models.
158+
159+
Read more about the classifiers in `the documentation <https://yandex.cloud/docs/foundation-models/concepts/classifier/>`_.
160+
161+
:param text: the input text to classify.
162+
:param timeout: the timeout, or the maximum time to wait for the request to complete in seconds.
163+
Defaults to 60 seconds.
164+
"""
146165
return await self._run(
147166
text=text,
148167
timeout=timeout
@@ -167,6 +186,24 @@ async def tune_deferred(
167186
optimizer: UndefinedOr[BaseOptimizer] = UNDEFINED,
168187
timeout: float = 60,
169188
) -> AsyncTuningTask['AsyncTextClassifiersModel']:
189+
"""Initiate a deferred tuning process for the model.
190+
191+
:param train_datasets: the dataset objects and/or dataset ids used for training of the model.
192+
:param validation_datasets: the dataset objects and/or dataset ids used for validation of the model.
193+
:param classification_type: the type of classification to perform during tuning (multilabel, multiclass, or binary).
194+
:param name: the name of the tuning task.
195+
:param description: the description of the tuning task.
196+
:param labels: labels for the tuning task.
197+
:param timeout: the timeout, or the maximum time to wait for the request to complete in seconds.
198+
Defaults to 60 seconds.
199+
:param seed: a random seed for reproducibility.
200+
:param lr: a learning rate for tuning.
201+
:param n_samples: a number of samples for tuning.
202+
:param additional_arguments: additional arguments for tuning.
203+
:param tuning_type: a type of tuning to be applied.
204+
:param scheduler: a scheduler for tuning.
205+
:param optimizer: an optimizer for tuning.
206+
"""
170207
return await self._tune_deferred(
171208
train_datasets=train_datasets,
172209
validation_datasets=validation_datasets,
@@ -205,6 +242,28 @@ async def tune(
205242
poll_timeout: int = 72 * 60 * 60,
206243
poll_interval: float = 60,
207244
) -> Self:
245+
"""Tune the model with the specified training datasets and parameters.
246+
247+
:param train_datasets: the dataset objects and/or dataset ids used for training of the model.
248+
:param validation_datasets: the dataset objects and/or dataset ids used for validation of the model.
249+
:param classification_type: the type of classification to perform during tuning (multilabel, multiclass, or binary).
250+
:param name: the name of the tuning task.
251+
:param description: the description of the tuning task.
252+
:param labels: labels for the tuning task.
253+
:param timeout: the timeout, or the maximum time to wait for the request to complete in seconds.
254+
Defaults to 60 seconds.
255+
:param seed: a random seed for reproducibility.
256+
:param lr: a learning rate for tuning.
257+
:param n_samples: a number of samples for tuning.
258+
:param additional_arguments: additional arguments for tuning.
259+
:param tuning_type: a type of tuning to be applied.
260+
:param scheduler: a scheduler for tuning.
261+
:param optimizer: an optimizer for tuning.
262+
:param poll_timeout: the maximum time to wait while polling for completion of the tuning task.
263+
Defaults to 259200 seconds (72 hours).
264+
:param poll_interval: the interval between polling attempts during the tuning process.
265+
Defaults to 60 seconds.
266+
"""
208267
return await self._tune(
209268
train_datasets=train_datasets,
210269
validation_datasets=validation_datasets,
@@ -230,16 +289,24 @@ async def attach_tune_deferred(
230289
*,
231290
timeout: float = 60
232291
) -> AsyncTuningTask['AsyncTextClassifiersModel']:
292+
"""Attach a deferred tuning task using its task ID.
293+
294+
:param task_id: the ID of the deferred tuning task to attach to.
295+
:param timeout: the timeout, or the maximum time to wait for the request to complete in seconds.
296+
Defaults to 60 seconds.
297+
"""
233298
return await self._attach_tune_deferred(task_id=task_id, timeout=timeout)
234299

235300

301+
@doc_from(BaseTextClassifiersModel)
236302
class TextClassifiersModel(BaseTextClassifiersModel[TuningTask['TextClassifiersModel']]):
237303
_tune_operation_type = TuningTask['TextClassifiersModel']
238304
__run = run_sync(BaseTextClassifiersModel._run)
239305
__tune_deferred = run_sync(BaseTextClassifiersModel._tune_deferred)
240306
__tune = run_sync(BaseTextClassifiersModel._tune)
241307
__attach_tune_deferred = run_sync(BaseTextClassifiersModel._attach_tune_deferred)
242308

309+
@doc_from(AsyncTextClassifiersModel.run)
243310
def run(
244311
self,
245312
text: str,
@@ -252,6 +319,7 @@ def run(
252319
)
253320

254321
# pylint: disable=too-many-locals
322+
@doc_from(AsyncTextClassifiersModel.tune_deferred)
255323
def tune_deferred(
256324
self,
257325
train_datasets: TuningDatasetsType,
@@ -289,6 +357,7 @@ def tune_deferred(
289357
return cast(TuningTask[TextClassifiersModel], result)
290358

291359
# pylint: disable=too-many-locals
360+
@doc_from(AsyncTextClassifiersModel.tune)
292361
def tune(
293362
self,
294363
train_datasets: TuningDatasetsType,
@@ -328,6 +397,7 @@ def tune(
328397
poll_interval=poll_interval,
329398
)
330399

400+
@doc_from(AsyncTextClassifiersModel.attach_tune_deferred)
331401
def attach_tune_deferred(self, task_id: str, *, timeout: float = 60) -> TuningTask[TextClassifiersModel]:
332402
return cast(
333403
TuningTask[TextClassifiersModel],

src/yandex_cloud_ml_sdk/_models/text_classifiers/result.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@
2626

2727
@dataclass(frozen=True)
2828
class TextClassifiersModelResultBase(BaseResult, Sequence, Generic[TextClassificationResponseT]):
29+
"""A class for text classifiers model results.
30+
It represents the common structure for the results returned by text classification models.
31+
"""
32+
#: a tuple containing the predicted labels
2933
predictions: tuple[TextClassificationLabel, ...]
34+
#: the version of the model used for prediction
3035
model_version: str
3136
#: Number of input tokens provided to the model.
3237
input_tokens: int

src/yandex_cloud_ml_sdk/_models/text_classifiers/tune_params.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515

1616
@dataclass(frozen=True)
1717
class TextClassifiersModelTuneParams(BaseTuningParams):
18+
"""This class encapsulates the parameters used for tuning text classification models,
19+
supporting multilabel, multiclass, and binary classification types.
20+
"""
1821
@property
1922
def _proto_tuning_params_type(
2023
self
@@ -43,8 +46,13 @@ def __post_init__(self):
4346
f'classification_type must be {ClassificationTuningTypes}, got {self.classification_type}'
4447
)
4548

49+
#: the type of classification to be used (should be one of 'multilabel', 'multiclass', or 'binary')
4650
classification_type: ClassificationTuningTypes | None = None
51+
#: random seed for reproducibility
4752
seed: int | None = None
53+
#: a learning rate for the tuning process
4854
lr: float | None = None
55+
#: a number of samples to use for tuning
4956
n_samples: int | None = None
57+
#: any additional arguments required for tuning
5058
additional_arguments: str | None = None

src/yandex_cloud_ml_sdk/_models/text_classifiers/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66

77
@dataclass(frozen=True)
88
class TextClassificationLabel(Mapping):
9+
"""This class represents a label for text classification
10+
with an associated confidence score.
11+
"""
12+
#: the label for the classification
913
label: str
14+
#: the confidence score associated with the label
1015
confidence: float
1116

1217
def __getitem__(self, key):
@@ -20,5 +25,8 @@ def __len__(self):
2025

2126

2227
class TextClassificationSample(TypedDict):
28+
"""This class represents a sample of text for classification."""
29+
#: the text to be classified
2330
text: str
31+
#: the expected label for the classification
2432
label: str

0 commit comments

Comments
 (0)