Skip to content

Commit 198ea01

Browse files
authored
Add get_prediction method and add async support in create_predictions (#144)
* Add statistics to training and async bool to create_prediction * Add sync as argument and endpoint for get predictionId * Update version and changelog * add some tests * add missing function * change name of sync call * be consequent of the use of run_async
1 parent 11ebb2e commit 198ea01

File tree

5 files changed

+50
-5
lines changed

5 files changed

+50
-5
lines changed

Diff for: CHANGELOG.md

+6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# Changelog
22

3+
## Version 11.4.0 - 2024-11-21
4+
5+
- Added `get_prediction`
6+
- Support run_async argument for `create_predictions`
7+
- Support `statistics_last_n_days` for `get_training`
8+
39
## Version 11.3.0 - 2024-03-27
410

511
- Support profile argument when creating a Client

Diff for: las/__version__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@
77
__maintainer_email__ ='[email protected]'
88
__title__ = 'lucidtech-las'
99
__url__ = 'https://github.com/LucidtechAI/las-sdk-python'
10-
__version__ = '11.3.0'
10+
__version__ = '11.4.0'

Diff for: las/client.py

+27-3
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,7 @@ def delete_dataset(self, dataset_id: str, delete_documents: bool = False) -> Dic
646646
647647
:param dataset_id: Id of the dataset
648648
:type dataset_id: str
649-
:param delete_documents: Set to true to delete documents in dataset before deleting dataset
649+
:param delete_documents: Set to True to delete documents in dataset before deleting dataset
650650
:type delete_documents: bool
651651
:return: Dataset response from REST API
652652
:rtype: dict
@@ -1342,20 +1342,23 @@ def create_training(
13421342
body.update(**optional_args)
13431343
return self._make_request(requests.post, f'/models/{model_id}/trainings', body=body)
13441344

1345-
def get_training(self, model_id: str, training_id: str) -> Dict:
1345+
def get_training(self, model_id: str, training_id: str, statistics_last_n_days: Optional[int] = None) -> Dict:
13461346
"""Get training, calls the GET /models/{modelId}/trainings/{trainingId} endpoint.
13471347
13481348
:param model_id: ID of the model
13491349
:type model_id: str
13501350
:param training_id: ID of the training
13511351
:type training_id: str
1352+
:param statistics_last_n_days: Integer between 1 and 30
1353+
:type statistics_last_n_days: int, optional
13521354
:return: Training response from REST API
13531355
:rtype: dict
13541356
13551357
:raises: :py:class:`~las.InvalidCredentialsException`, :py:class:`~las.TooManyRequestsException`,\
13561358
:py:class:`~las.LimitExceededException`, :py:class:`requests.exception.RequestException`
13571359
"""
1358-
return self._make_request(requests.get, f'/models/{model_id}/trainings/{training_id}')
1360+
params = {'statisticsLastNDays': statistics_last_n_days}
1361+
return self._make_request(requests.get, f'/models/{model_id}/trainings/{training_id}', params=params)
13591362

13601363
def list_trainings(self, model_id, *, max_results: Optional[int] = None, next_token: Optional[str] = None) -> Dict:
13611364
"""List trainings available, calls the GET /models/{modelId}/trainings endpoint.
@@ -1528,6 +1531,7 @@ def create_prediction(
15281531
training_id: Optional[str] = None,
15291532
preprocess_config: Optional[dict] = None,
15301533
postprocess_config: Optional[dict] = None,
1534+
run_async: Optional[bool] = None,
15311535
) -> Dict:
15321536
"""Create a prediction on a document using specified model, calls the POST /predictions endpoint.
15331537
@@ -1568,6 +1572,8 @@ def create_prediction(
15681572
{'strategy': 'BEST_N_PAGES', 'parameters': {'n': 3}}
15691573
{'strategy': 'BEST_N_PAGES', 'parameters': {'n': 3, 'collapse': False}}
15701574
:type postprocess_config: dict, optional
1575+
:param run_async: If True run the prediction async, if False run sync. if omitted run synchronously.
1576+
:type run_async: bool
15711577
:return: Prediction response from REST API
15721578
:rtype: dict
15731579
@@ -1580,6 +1586,7 @@ def create_prediction(
15801586
'trainingId': training_id,
15811587
'preprocessConfig': preprocess_config,
15821588
'postprocessConfig': postprocess_config,
1589+
'async': run_async,
15831590
}
15841591
return self._make_request(requests.post, '/predictions', body=dictstrip(body))
15851592

@@ -1623,6 +1630,23 @@ def list_predictions(
16231630
}
16241631
return self._make_request(requests.get, '/predictions', params=dictstrip(params))
16251632

1633+
def get_prediction(self, prediction_id: str) -> Dict:
1634+
"""Get prediction, calls the GET /predictions/{predictionId} endpoint.
1635+
1636+
>>> from las.client import Client
1637+
>>> client = Client()
1638+
>>> client.get_prediction(prediction_id='<prediction id>')
1639+
1640+
:param prediction_id: Id of the prediction
1641+
:type prediction_id: str
1642+
:return: Asset response from REST API with content
1643+
:rtype: dict
1644+
1645+
:raises: :py:class:`~las.InvalidCredentialsException`, :py:class:`~las.TooManyRequestsException`,\
1646+
:py:class:`~las.LimitExceededException`, :py:class:`requests.exception.RequestException`
1647+
"""
1648+
return self._make_request(requests.get, f'/predictions/{prediction_id}')
1649+
16261650
def get_plan(self, plan_id: str) -> Dict:
16271651
"""Get information about a specific plan, calls the GET /plans/{plan_id} endpoint.
16281652

Diff for: tests/service.py

+4
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ def create_payment_method_id():
3434
return f'las:payment-method:{uuid4().hex}'
3535

3636

37+
def create_prediction_id():
38+
return f'las:prediction:{uuid4().hex}'
39+
40+
3741
def create_deployment_environment_id():
3842
return f'las:deployment-environment:{uuid4().hex}'
3943

Diff for: tests/test_predictions.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,16 @@
2020
{'strategy': 'BEST_N_PAGES', 'parameters': {'n': 3, 'collapse': False}},
2121
None,
2222
])
23-
def test_create_prediction(client: Client, preprocess_config, postprocess_config):
23+
@pytest.mark.parametrize('run_async', [True, False, None])
24+
def test_create_prediction(client: Client, preprocess_config, postprocess_config, run_async):
2425
document_id = service.create_document_id()
2526
model_id = service.create_model_id()
2627
response = client.create_prediction(
2728
document_id,
2829
model_id,
2930
preprocess_config=dictstrip(preprocess_config) if preprocess_config else None,
3031
postprocess_config=postprocess_config,
32+
run_async=run_async,
3133
)
3234
assert 'predictionId' in response, 'Missing predictionId in response'
3335

@@ -39,3 +41,12 @@ def test_list_predictions(client: Client, sort_by, order, model_id):
3941
response = client.list_predictions(sort_by=sort_by, order=order, model_id=model_id)
4042
logging.info(response)
4143
assert 'predictions' in response, 'Missing predictions in response'
44+
45+
46+
@pytest.mark.parametrize('prediction_id', [service.create_prediction_id(), None])
47+
def test_get_prediction(client: Client, prediction_id):
48+
response = client.get_prediction(prediction_id)
49+
logging.info(response)
50+
assert 'predictionId' in response, 'Missing prediction in response'
51+
assert 'inferenceTime' in response, 'Missing inferenceTime in response'
52+
assert 'predictions' in response, 'Missing predictions in response'

0 commit comments

Comments
 (0)