From 4617536746dd5a62499f1ab9fd8ddc1b39fee71a Mon Sep 17 00:00:00 2001 From: Jamie Heller Date: Tue, 25 Mar 2025 11:06:26 -0400 Subject: [PATCH 1/3] PNE-6482 Support `default_labels` in MaterialRun constructor --- src/citrine/__version__.py | 2 +- src/citrine/resources/material_run.py | 32 +++++++++++++++++++++++++-- tests/resources/test_material_run.py | 15 +++++++++++++ tests/utils/factories.py | 3 +++ 4 files changed, 49 insertions(+), 3 deletions(-) diff --git a/src/citrine/__version__.py b/src/citrine/__version__.py index a3207ca50..1b052a2ba 100644 --- a/src/citrine/__version__.py +++ b/src/citrine/__version__.py @@ -1 +1 @@ -__version__ = "3.17.0" +__version__ = "3.18.0" diff --git a/src/citrine/resources/material_run.py b/src/citrine/resources/material_run.py index e57de4569..1109e3891 100644 --- a/src/citrine/resources/material_run.py +++ b/src/citrine/resources/material_run.py @@ -7,6 +7,7 @@ from citrine._serialization.properties import String, LinkOrElse from citrine._utils.functions import format_escaped_url from citrine.resources.data_concepts import _make_link_by_uid +from citrine.resources.data_concepts import CITRINE_TAG_PREFIX from citrine.resources.material_spec import MaterialSpecCollection from citrine.resources.object_runs import ObjectRun, ObjectRunCollection from gemd.entity.file_link import FileLink @@ -49,6 +50,16 @@ class MaterialRun( The material specification of which this is an instance. file_links: List[FileLink], optional Links to associated files, with resource paths into the files API. + default_labels: List[str], optional + An optional set of default labels to apply to this material run. + Default labels are used to: + - Populate labels on the ingredient run, if none are explicitly + specified, when the material run is later used as an ingredient + - Marking the material run as a potential replacement ingredient for a + particular label when generating new candidates using a + design space. Note that during design, default labels are only applicable + if the material run has no associated ingredient run within the + training set in question. """ @@ -74,12 +85,14 @@ def __init__(self, process: Optional[GEMDProcessRun] = None, sample_type: Optional[str] = "unknown", spec: Optional[GEMDMaterialSpec] = None, - file_links: Optional[List[FileLink]] = None): + file_links: Optional[List[FileLink]] = None, + default_labels: Optional[List[str]] = None): if uids is None: uids = dict() + all_tags = _inject_default_label_tags(tags, default_labels) super(ObjectRun, self).__init__() GEMDMaterialRun.__init__(self, name=name, uids=uids, - tags=tags, process=process, + tags=all_tags, process=process, sample_type=sample_type, spec=spec, file_links=file_links, notes=notes) @@ -216,3 +229,18 @@ def list_by_template(self, specs = spec_collection.list_by_template(uid=_make_link_by_uid(uid)) return (run for runs in (self.list_by_spec(spec) for spec in specs) for run in runs) + + +_CITRINE_DEFAULT_LABEL_PREFIX = f'{CITRINE_TAG_PREFIX}::mat_label' + + +def _inject_default_label_tags( + original_tags: Optional[List[str]], + default_labels: Optional[List[str]]) -> List[str]: + all_tags: List[str] = [] + if original_tags is not None: + all_tags.extend(original_tags) + if default_labels is not None: + all_tags.extend([f"{_CITRINE_DEFAULT_LABEL_PREFIX}::{label}" + for label in default_labels]) + return all_tags diff --git a/tests/resources/test_material_run.py b/tests/resources/test_material_run.py index 00b05decd..257007b6d 100644 --- a/tests/resources/test_material_run.py +++ b/tests/resources/test_material_run.py @@ -9,6 +9,7 @@ from citrine.resources.data_concepts import CITRINE_SCOPE from citrine.resources.material_run import MaterialRunCollection from citrine.resources.material_run import MaterialRun as CitrineRun +from citrine.resources.material_run import _inject_default_label_tags from citrine.resources.gemd_resource import GEMDResourceCollection from gemd.demo.cake import make_cake, change_scope @@ -53,6 +54,20 @@ def test_invalid_collection_construction(): mr = MaterialRunCollection(dataset_id=UUID('8da51e93-8b55-4dd3-8489-af8f65d4ad9a'), session=session) +def test_inject_default_label_tags(): + original_tags = ["alpha", "beta", "gamma"] + default_labels = ["label 0", "label 1"] + all_tags = _inject_default_label_tags(original_tags, default_labels) + expected = [ + "alpha", + "beta", + "gamma", + "citr_auto::mat_label::label 0", + "citr_auto::mat_label::label 1" + ] + assert set(all_tags) == set(expected) + + def test_register_material_run(collection, session): # Given session.set_response(MaterialRunDataFactory(name='Test MR 123')) diff --git a/tests/utils/factories.py b/tests/utils/factories.py index d6d5c989a..d1a4c5d3d 100644 --- a/tests/utils/factories.py +++ b/tests/utils/factories.py @@ -666,6 +666,9 @@ class Meta: sample_type = factory.Faker("enum", enum_cls=SampleType) spec = factory.SubFactory(LinkByUIDFactory) file_links = factory.List([factory.SubFactory(FileLinkFactory)]) + default_labels = factory.List( + [factory.Faker("color_name"), factory.Faker("color_name")] + ) class LinkByUIDDataFactory(factory.DictFactory): From 8aff25a2f56cb65e913733fc09258ffd3fe820bf Mon Sep 17 00:00:00 2001 From: Jamie Heller Date: Wed, 26 Mar 2025 13:26:40 -0400 Subject: [PATCH 2/3] PNE-6482 Ensure `None`s are propagated as expected, expand test --- src/citrine/resources/material_run.py | 20 ++++++----- tests/resources/test_material_run.py | 48 ++++++++++++++++++++------- 2 files changed, 48 insertions(+), 20 deletions(-) diff --git a/src/citrine/resources/material_run.py b/src/citrine/resources/material_run.py index 1109e3891..2c404587f 100644 --- a/src/citrine/resources/material_run.py +++ b/src/citrine/resources/material_run.py @@ -235,12 +235,16 @@ def list_by_template(self, def _inject_default_label_tags( - original_tags: Optional[List[str]], - default_labels: Optional[List[str]]) -> List[str]: - all_tags: List[str] = [] - if original_tags is not None: - all_tags.extend(original_tags) - if default_labels is not None: - all_tags.extend([f"{_CITRINE_DEFAULT_LABEL_PREFIX}::{label}" - for label in default_labels]) + original_tags: Optional[List[str]], default_labels: Optional[List[str]] +) -> Optional[List[str]]: + if default_labels is None: + all_tags = original_tags + else: + labels_as_tags = [ + f"{_CITRINE_DEFAULT_LABEL_PREFIX}::{label}" for label in default_labels + ] + if original_tags is None: + all_tags = labels_as_tags + else: + all_tags = list(original_tags) + labels_as_tags return all_tags diff --git a/tests/resources/test_material_run.py b/tests/resources/test_material_run.py index 257007b6d..3921b3bf1 100644 --- a/tests/resources/test_material_run.py +++ b/tests/resources/test_material_run.py @@ -54,18 +54,42 @@ def test_invalid_collection_construction(): mr = MaterialRunCollection(dataset_id=UUID('8da51e93-8b55-4dd3-8489-af8f65d4ad9a'), session=session) -def test_inject_default_label_tags(): - original_tags = ["alpha", "beta", "gamma"] - default_labels = ["label 0", "label 1"] - all_tags = _inject_default_label_tags(original_tags, default_labels) - expected = [ - "alpha", - "beta", - "gamma", - "citr_auto::mat_label::label 0", - "citr_auto::mat_label::label 1" - ] - assert set(all_tags) == set(expected) + +@pytest.mark.parametrize( + "original_tags, default_labels, expected", + [ + (None, None, None), + (None, [], []), + ([], None, []), + ([], [], []), + ( + None, + ["label 0", "label 1"], + ["citr_auto::mat_label::label 0", "citr_auto::mat_label::label 1"], + ), + ( + [], + ["label 0", "label 1"], + ["citr_auto::mat_label::label 0", "citr_auto::mat_label::label 1"], + ), + (["alpha", "beta", "gamma"], None, ["alpha", "beta", "gamma"]), + (["alpha", "beta", "gamma"], [], ["alpha", "beta", "gamma"]), + ( + ["alpha", "beta", "gamma"], + ["label 0", "label 1"], + [ + "alpha", + "beta", + "gamma", + "citr_auto::mat_label::label 0", + "citr_auto::mat_label::label 1", + ], + ), + ], +) +def test_inject_default_label_tags(original_tags, default_labels, expected): + result = _inject_default_label_tags(original_tags, default_labels) + assert result == expected def test_register_material_run(collection, session): From 4d9caff1658696d8f5a1e6585873f2a3434f2692 Mon Sep 17 00:00:00 2001 From: Jamie Heller Date: Wed, 26 Mar 2025 13:28:24 -0400 Subject: [PATCH 3/3] PNE-6482 Remove default_labels from MaterialRunFactory --- tests/utils/factories.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/utils/factories.py b/tests/utils/factories.py index d1a4c5d3d..d6d5c989a 100644 --- a/tests/utils/factories.py +++ b/tests/utils/factories.py @@ -666,9 +666,6 @@ class Meta: sample_type = factory.Faker("enum", enum_cls=SampleType) spec = factory.SubFactory(LinkByUIDFactory) file_links = factory.List([factory.SubFactory(FileLinkFactory)]) - default_labels = factory.List( - [factory.Faker("color_name"), factory.Faker("color_name")] - ) class LinkByUIDDataFactory(factory.DictFactory):