Skip to content

Commit 179cec9

Browse files
feat: DIA-1685: [sdk] Create example predictions and annotations from a LabelConfig (#360)
1 parent 8093e3d commit 179cec9

File tree

4 files changed

+215
-16
lines changed

4 files changed

+215
-16
lines changed

poetry.lock

+150-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ typing_extensions = ">= 4.0.0"
4949
ujson = ">=5.8.0"
5050
xmljson = "0.2.1"
5151

52+
jsf = "^0.11.2"
5253
[tool.poetry.dev-dependencies]
5354
mypy = "1.0.1"
5455
pytest = "^7.4.0"

src/label_studio_sdk/label_interface/interface.py

+63-4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from collections import defaultdict, OrderedDict
1616
from lxml import etree
1717
import xmljson
18+
from jsf import JSF
1819

1920
from label_studio_sdk._legacy.exceptions import (
2021
LSConfigParseException,
@@ -770,7 +771,7 @@ def validate_region(self, region) -> bool:
770771
return False
771772

772773
# type of the region should match the tag name
773-
if control.tag.lower() != region["type"]:
774+
if control.tag.lower() != region["type"].lower():
774775
return False
775776

776777
# make sure that in config it connects to the same tag as
@@ -839,9 +840,67 @@ def generate_sample_task(self, mode="upload", secure_mode=False):
839840

840841
return task
841842

842-
def generate_sample_annotation(self):
843-
""" """
844-
raise NotImplemented()
843+
def _generate_sample_regions(self):
844+
""" Generate an example of each control tag's JSON schema and validate it as a region"""
845+
return self.create_regions({
846+
control.name: JSF(control.to_json_schema()).generate()
847+
for control in self.controls
848+
})
849+
850+
def generate_sample_prediction(self) -> Optional[dict]:
851+
"""Generates a sample prediction that is valid for this label config.
852+
853+
Example:
854+
{'model_version': 'sample model version',
855+
'score': 0.0,
856+
'result': [{'id': 'e7bd76e6-4e88-4eb3-b433-55e03661bf5d',
857+
'from_name': 'sentiment',
858+
'to_name': 'text',
859+
'type': 'choices',
860+
'value': {'choices': ['Neutral']}}]}
861+
862+
NOTE: `id` field in result is not required when importing predictions; it will be generated automatically.
863+
NOTE: for each control tag, depends on tag.to_json_schema() being implemented correctly
864+
"""
865+
prediction = PredictionValue(
866+
model_version='sample model version',
867+
result=self._generate_sample_regions()
868+
)
869+
prediction_dct = prediction.model_dump()
870+
if self.validate_prediction(prediction_dct):
871+
return prediction_dct
872+
else:
873+
logger.debug(f'Sample prediction {prediction_dct} failed validation for label config {self.config}')
874+
return None
875+
876+
def generate_sample_annotation(self) -> Optional[dict]:
877+
"""Generates a sample annotation that is valid for this label config.
878+
879+
Example:
880+
{'was_cancelled': False,
881+
'ground_truth': False,
882+
'lead_time': 0.0,
883+
'result_count': 0,
884+
'completed_by': -1,
885+
'result': [{'id': 'b05da11d-3ffc-4657-8b8d-f5bc37cd59ac',
886+
'from_name': 'sentiment',
887+
'to_name': 'text',
888+
'type': 'choices',
889+
'value': {'choices': ['Negative']}}]}
890+
891+
NOTE: `id` field in result is not required when importing predictions; it will be generated automatically.
892+
NOTE: for each control tag, depends on tag.to_json_schema() being implemented correctly
893+
"""
894+
annotation = AnnotationValue(
895+
completed_by=-1, # annotator's user id
896+
result=self._generate_sample_regions()
897+
)
898+
annotation_dct = annotation.model_dump()
899+
if self.validate_annotation(annotation_dct):
900+
return annotation_dct
901+
else:
902+
logger.debug(f'Sample annotation {annotation_dct} failed validation for label config {self.config}')
903+
return None
845904

846905
#####
847906
##### COMPATIBILITY LAYER

src/label_studio_sdk/label_interface/region.py

+1-10
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,7 @@
1010

1111
class Region(BaseModel):
1212
"""
13-
Class for Region Tag
14-
15-
Attributes:
16-
-----------
17-
id: str
18-
The unique identifier of the region
19-
x: int
20-
The x coordinate of the region
21-
y: int
22-
13+
A Region is an item in the `result` list of a PredictionValue or AnnotationValue.
2314
"""
2415

2516
id: str = Field(default_factory=lambda: str(uuid4()))

0 commit comments

Comments
 (0)