From dcd18e0051da840e3d7e9ea7c13c529331df3515 Mon Sep 17 00:00:00 2001 From: PProfizi Date: Fri, 15 Nov 2024 14:39:04 +0100 Subject: [PATCH 1/7] Make CollectionBase.__getitem__ support slices --- src/ansys/dpf/core/collection_base.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/src/ansys/dpf/core/collection_base.py b/src/ansys/dpf/core/collection_base.py index 35c93a3e2b..eb2c8adc23 100644 --- a/src/ansys/dpf/core/collection_base.py +++ b/src/ansys/dpf/core/collection_base.py @@ -409,19 +409,36 @@ def get_label_scoping(self, label="time"): scoping = Scoping(self._api.collection_get_label_scoping(self, label), server=self._server) return scoping - def __getitem__(self, index): - """Retrieves the entry at a requested index value. + def __getitem__(self, index: int | slice): + """Retrieves the entry at a requested index value or build a new collection from a slice. Parameters ---------- - index : int + index: Index value. Returns ------- - entry : Field , Scoping - Entry at the index value. + entry: + Entry at the index value or new collection for entries at requested slice. """ + if isinstance(index, slice): + # handle slice + indices = list(range( + index.start if index.start else 0, + index.stop, + index.step if index.step else 1 + )) + out_collection = self.__class__() + out_collection.set_labels(labels=self._get_labels()) + entries = [ + out_collection._add_entry( + label_space=self.get_label_space(index=i), + entry=self._get_entries(label_space_or_index=i) + ) + for i in indices + ] + return out_collection self_len = len(self) if index < 0: # convert to a positive index From 370bc267305dfdbc1bdc24d1ad3de6b454dfb7df Mon Sep 17 00:00:00 2001 From: PProfizi Date: Fri, 15 Nov 2024 14:39:04 +0100 Subject: [PATCH 2/7] Make CollectionBase.__getitem__ support slices --- src/ansys/dpf/core/collection_base.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/src/ansys/dpf/core/collection_base.py b/src/ansys/dpf/core/collection_base.py index 35c93a3e2b..0240fb2d4b 100644 --- a/src/ansys/dpf/core/collection_base.py +++ b/src/ansys/dpf/core/collection_base.py @@ -409,19 +409,36 @@ def get_label_scoping(self, label="time"): scoping = Scoping(self._api.collection_get_label_scoping(self, label), server=self._server) return scoping - def __getitem__(self, index): - """Retrieves the entry at a requested index value. + def __getitem__(self, index: int | slice): + """Retrieves the entry at a requested index value or build a new collection from a slice. Parameters ---------- - index : int + index: Index value. Returns ------- - entry : Field , Scoping - Entry at the index value. + entry: + Entry at the index value or new collection for entries at requested slice. """ + if isinstance(index, slice): + # handle slice + indices = list(range( + index.start if index.start else 0, + index.stop, + index.step if index.step else 1 + )) + out_collection = self.__class__() + out_collection.set_labels(labels=self._get_labels()) + [ + out_collection._add_entry( + label_space=self.get_label_space(index=i), + entry=self._get_entries(label_space_or_index=i) + ) + for i in indices + ] + return out_collection self_len = len(self) if index < 0: # convert to a positive index From 41420dd25c1f968afdab28eafc7e577ef68fa3c7 Mon Sep 17 00:00:00 2001 From: PProfizi Date: Fri, 15 Nov 2024 20:40:48 +0100 Subject: [PATCH 3/7] Improve typehint and autocomplete for Collections --- src/ansys/dpf/core/__init__.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/ansys/dpf/core/__init__.py b/src/ansys/dpf/core/__init__.py index 1c885d1d70..6b649261d0 100644 --- a/src/ansys/dpf/core/__init__.py +++ b/src/ansys/dpf/core/__init__.py @@ -115,10 +115,14 @@ # register classes for collection types: -CustomTypeFieldsCollection:type = _CollectionFactory(CustomTypeField) -GenericDataContainersCollection:type = _CollectionFactory(GenericDataContainer) -StringFieldsCollection:type = _CollectionFactory(StringField) -AnyCollection:type = _Collection +class CustomTypeFieldsCollection(_Collection[CustomTypeField]): + entries_type = CustomTypeField +class GenericDataContainersCollection(_Collection[GenericDataContainer]): + entries_type = GenericDataContainer +class StringFieldsCollection(_Collection[StringField]): + entries_type = StringField +class AnyCollection(_Collection[Any]): + entries_type = Any # for matplotlib # solves "QApplication: invalid style override passed, ignoring it." From b25519f21d08bfbd0aed90d84d95970bcdccb6bc Mon Sep 17 00:00:00 2001 From: PProfizi Date: Mon, 18 Nov 2024 10:27:24 +0100 Subject: [PATCH 4/7] Make compatible with FieldsContainers, ScopingsContainers and MeshesContainers. --- src/ansys/dpf/core/collection_base.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/ansys/dpf/core/collection_base.py b/src/ansys/dpf/core/collection_base.py index 0240fb2d4b..e660a92de0 100644 --- a/src/ansys/dpf/core/collection_base.py +++ b/src/ansys/dpf/core/collection_base.py @@ -431,8 +431,15 @@ def __getitem__(self, index: int | slice): )) out_collection = self.__class__() out_collection.set_labels(labels=self._get_labels()) + if hasattr(out_collection, "add_entry"): + # For any direct subclass of Collection + func = out_collection.add_entry + else: + # For FieldsContainers, ScopingsContainers and MeshesContainers + # because they have dedicated APIs + func = out_collection._add_entry [ - out_collection._add_entry( + func( label_space=self.get_label_space(index=i), entry=self._get_entries(label_space_or_index=i) ) From febd4c3a3613032fddb938c34ee37d8225056f1d Mon Sep 17 00:00:00 2001 From: PProfizi Date: Mon, 18 Nov 2024 10:27:38 +0100 Subject: [PATCH 5/7] Add tests --- tests/test_collection.py | 33 ++++++++++++++++++++++++++++++++- tests/test_fieldscontainer.py | 6 ++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/tests/test_collection.py b/tests/test_collection.py index 304b007b8a..cb73753ca3 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -82,7 +82,7 @@ def create_dummy_gdc(server_type, prop="hi"): @dataclass class CollectionTypeHelper: type: type - instance_creator: object + instance_creator: callable kwargs: dict = field(default_factory=dict) @property @@ -244,3 +244,34 @@ def test_connect_collection_workflow(server_type, subtype_creator): out = op.get_output(0, subtype_creator.type) assert out is not None assert len(out) == 1 + +def test_generic_data_containers_collection_slice(server_type): + coll = GenericDataContainersCollection(server=server_type) + + coll.labels = ["id1", "id2"] + for i in range(5): + coll.add_entry( + label_space={"id1": i, "id2": 0}, + entry=create_dummy_gdc(server_type=server_type) + ) + assert len(coll) == 5 + print(coll) + sliced_coll = coll[:3] + assert len(sliced_coll) == 3 + print(sliced_coll) + + +def test_string_containers_collection_slice(server_type): + coll = StringFieldsCollection(server=server_type) + + coll.labels = ["id1", "id2"] + for i in range(5): + coll.add_entry( + label_space={"id1": i, "id2": 0}, + entry=create_dummy_string_field(server_type=server_type) + ) + assert len(coll) == 5 + print(coll) + sliced_coll = coll[:3] + assert len(sliced_coll) == 3 + print(sliced_coll) diff --git a/tests/test_fieldscontainer.py b/tests/test_fieldscontainer.py index 42b643bb3b..10c6961d2b 100644 --- a/tests/test_fieldscontainer.py +++ b/tests/test_fieldscontainer.py @@ -598,3 +598,9 @@ def test_get_entries_indices_fields_container(server_type): assert np.allclose(fc.get_entries_indices({"time": 1, "complex": 0}), [0]) assert np.allclose(fc.get_entries_indices({"time": 2}), [1]) assert np.allclose(fc.get_entries_indices({"complex": 0}), range(0, 20)) + + +def test_fields_container_slice(server_type, disp_fc): + print(disp_fc) + fc = disp_fc[:1] + assert len(fc) == 1 From 6224cb354f53164da0aacee4ce3ff72310bf303f Mon Sep 17 00:00:00 2001 From: PProfizi Date: Tue, 19 Nov 2024 14:58:13 +0100 Subject: [PATCH 6/7] Fix Code Quality check --- src/ansys/dpf/core/collection_base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/ansys/dpf/core/collection_base.py b/src/ansys/dpf/core/collection_base.py index e660a92de0..c98cd93cf9 100644 --- a/src/ansys/dpf/core/collection_base.py +++ b/src/ansys/dpf/core/collection_base.py @@ -438,13 +438,11 @@ def __getitem__(self, index: int | slice): # For FieldsContainers, ScopingsContainers and MeshesContainers # because they have dedicated APIs func = out_collection._add_entry - [ + for i in indices: func( label_space=self.get_label_space(index=i), entry=self._get_entries(label_space_or_index=i) ) - for i in indices - ] return out_collection self_len = len(self) if index < 0: From d8c7b1e1e4f1cf20fdf09ee5a6e487895d9c260e Mon Sep 17 00:00:00 2001 From: PProfizi Date: Tue, 19 Nov 2024 15:05:52 +0100 Subject: [PATCH 7/7] Fix Code Quality check --- src/ansys/dpf/core/collection_base.py | 12 ++++++------ tests/test_collection.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/ansys/dpf/core/collection_base.py b/src/ansys/dpf/core/collection_base.py index c98cd93cf9..18bc3219ed 100644 --- a/src/ansys/dpf/core/collection_base.py +++ b/src/ansys/dpf/core/collection_base.py @@ -424,11 +424,11 @@ def __getitem__(self, index: int | slice): """ if isinstance(index, slice): # handle slice - indices = list(range( - index.start if index.start else 0, - index.stop, - index.step if index.step else 1 - )) + indices = list( + range( + index.start if index.start else 0, index.stop, index.step if index.step else 1 + ) + ) out_collection = self.__class__() out_collection.set_labels(labels=self._get_labels()) if hasattr(out_collection, "add_entry"): @@ -441,7 +441,7 @@ def __getitem__(self, index: int | slice): for i in indices: func( label_space=self.get_label_space(index=i), - entry=self._get_entries(label_space_or_index=i) + entry=self._get_entries(label_space_or_index=i), ) return out_collection self_len = len(self) diff --git a/tests/test_collection.py b/tests/test_collection.py index cb73753ca3..017ec3211a 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -245,14 +245,14 @@ def test_connect_collection_workflow(server_type, subtype_creator): assert out is not None assert len(out) == 1 + def test_generic_data_containers_collection_slice(server_type): coll = GenericDataContainersCollection(server=server_type) coll.labels = ["id1", "id2"] for i in range(5): coll.add_entry( - label_space={"id1": i, "id2": 0}, - entry=create_dummy_gdc(server_type=server_type) + label_space={"id1": i, "id2": 0}, entry=create_dummy_gdc(server_type=server_type) ) assert len(coll) == 5 print(coll) @@ -268,7 +268,7 @@ def test_string_containers_collection_slice(server_type): for i in range(5): coll.add_entry( label_space={"id1": i, "id2": 0}, - entry=create_dummy_string_field(server_type=server_type) + entry=create_dummy_string_field(server_type=server_type), ) assert len(coll) == 5 print(coll)