diff --git a/src/citrine/__version__.py b/src/citrine/__version__.py index a3207ca50..1b052a2ba 100644 --- a/src/citrine/__version__.py +++ b/src/citrine/__version__.py @@ -1 +1 @@ -__version__ = "3.17.0" +__version__ = "3.18.0" diff --git a/src/citrine/resources/material_run.py b/src/citrine/resources/material_run.py index e57de4569..2c404587f 100644 --- a/src/citrine/resources/material_run.py +++ b/src/citrine/resources/material_run.py @@ -7,6 +7,7 @@ from citrine._serialization.properties import String, LinkOrElse from citrine._utils.functions import format_escaped_url from citrine.resources.data_concepts import _make_link_by_uid +from citrine.resources.data_concepts import CITRINE_TAG_PREFIX from citrine.resources.material_spec import MaterialSpecCollection from citrine.resources.object_runs import ObjectRun, ObjectRunCollection from gemd.entity.file_link import FileLink @@ -49,6 +50,16 @@ class MaterialRun( The material specification of which this is an instance. file_links: List[FileLink], optional Links to associated files, with resource paths into the files API. + default_labels: List[str], optional + An optional set of default labels to apply to this material run. + Default labels are used to: + - Populate labels on the ingredient run, if none are explicitly + specified, when the material run is later used as an ingredient + - Marking the material run as a potential replacement ingredient for a + particular label when generating new candidates using a + design space. Note that during design, default labels are only applicable + if the material run has no associated ingredient run within the + training set in question. """ @@ -74,12 +85,14 @@ def __init__(self, process: Optional[GEMDProcessRun] = None, sample_type: Optional[str] = "unknown", spec: Optional[GEMDMaterialSpec] = None, - file_links: Optional[List[FileLink]] = None): + file_links: Optional[List[FileLink]] = None, + default_labels: Optional[List[str]] = None): if uids is None: uids = dict() + all_tags = _inject_default_label_tags(tags, default_labels) super(ObjectRun, self).__init__() GEMDMaterialRun.__init__(self, name=name, uids=uids, - tags=tags, process=process, + tags=all_tags, process=process, sample_type=sample_type, spec=spec, file_links=file_links, notes=notes) @@ -216,3 +229,22 @@ def list_by_template(self, specs = spec_collection.list_by_template(uid=_make_link_by_uid(uid)) return (run for runs in (self.list_by_spec(spec) for spec in specs) for run in runs) + + +_CITRINE_DEFAULT_LABEL_PREFIX = f'{CITRINE_TAG_PREFIX}::mat_label' + + +def _inject_default_label_tags( + original_tags: Optional[List[str]], default_labels: Optional[List[str]] +) -> Optional[List[str]]: + if default_labels is None: + all_tags = original_tags + else: + labels_as_tags = [ + f"{_CITRINE_DEFAULT_LABEL_PREFIX}::{label}" for label in default_labels + ] + if original_tags is None: + all_tags = labels_as_tags + else: + all_tags = list(original_tags) + labels_as_tags + return all_tags diff --git a/tests/resources/test_material_run.py b/tests/resources/test_material_run.py index 00b05decd..3921b3bf1 100644 --- a/tests/resources/test_material_run.py +++ b/tests/resources/test_material_run.py @@ -9,6 +9,7 @@ from citrine.resources.data_concepts import CITRINE_SCOPE from citrine.resources.material_run import MaterialRunCollection from citrine.resources.material_run import MaterialRun as CitrineRun +from citrine.resources.material_run import _inject_default_label_tags from citrine.resources.gemd_resource import GEMDResourceCollection from gemd.demo.cake import make_cake, change_scope @@ -53,6 +54,44 @@ def test_invalid_collection_construction(): mr = MaterialRunCollection(dataset_id=UUID('8da51e93-8b55-4dd3-8489-af8f65d4ad9a'), session=session) + +@pytest.mark.parametrize( + "original_tags, default_labels, expected", + [ + (None, None, None), + (None, [], []), + ([], None, []), + ([], [], []), + ( + None, + ["label 0", "label 1"], + ["citr_auto::mat_label::label 0", "citr_auto::mat_label::label 1"], + ), + ( + [], + ["label 0", "label 1"], + ["citr_auto::mat_label::label 0", "citr_auto::mat_label::label 1"], + ), + (["alpha", "beta", "gamma"], None, ["alpha", "beta", "gamma"]), + (["alpha", "beta", "gamma"], [], ["alpha", "beta", "gamma"]), + ( + ["alpha", "beta", "gamma"], + ["label 0", "label 1"], + [ + "alpha", + "beta", + "gamma", + "citr_auto::mat_label::label 0", + "citr_auto::mat_label::label 1", + ], + ), + ], +) +def test_inject_default_label_tags(original_tags, default_labels, expected): + result = _inject_default_label_tags(original_tags, default_labels) + assert result == expected + + def test_register_material_run(collection, session): # Given session.set_response(MaterialRunDataFactory(name='Test MR 123'))