diff --git a/src/ansys/dpf/core/__init__.py b/src/ansys/dpf/core/__init__.py index b33f850a3a..81b911d12f 100644 --- a/src/ansys/dpf/core/__init__.py +++ b/src/ansys/dpf/core/__init__.py @@ -114,11 +114,16 @@ # register classes for collection types: -CustomTypeFieldsCollection:type = _CollectionFactory(CustomTypeField) -GenericDataContainersCollection:type = _CollectionFactory(GenericDataContainer) -StringFieldsCollection:type = _CollectionFactory(StringField) -OperatorsCollection: type = _CollectionFactory(Operator) -AnyCollection:type = _Collection +class CustomTypeFieldsCollection(_Collection[CustomTypeField]): + entries_type = CustomTypeField +class GenericDataContainersCollection(_Collection[GenericDataContainer]): + entries_type = GenericDataContainer +class StringFieldsCollection(_Collection[StringField]): + entries_type = StringField +class OperatorsCollection(_Collection[Operator]): + entries_type = Operator +class AnyCollection(_Collection[Any]): + entries_type = Any # for matplotlib # solves "QApplication: invalid style override passed, ignoring it." diff --git a/src/ansys/dpf/core/collection_base.py b/src/ansys/dpf/core/collection_base.py index 20f32e2de5..c1fd95ca30 100644 --- a/src/ansys/dpf/core/collection_base.py +++ b/src/ansys/dpf/core/collection_base.py @@ -414,19 +414,41 @@ def get_label_scoping(self, label="time"): scoping = Scoping(self._api.collection_get_label_scoping(self, label), server=self._server) return scoping - def __getitem__(self, index): - """Retrieve the entry at a requested index value. + def __getitem__(self, index: int | slice): + """Retrieve the entry at a requested index value or build a new collection from a slice. Parameters ---------- - index : int + index: Index value. Returns ------- - entry : Field , Scoping - Entry at the index value. + entry: + Entry at the index value or new collection for entries at requested slice. """ + if isinstance(index, slice): + # handle slice + indices = list( + range( + index.start if index.start else 0, index.stop, index.step if index.step else 1 + ) + ) + out_collection = self.__class__() + out_collection.set_labels(labels=self._get_labels()) + if hasattr(out_collection, "add_entry"): + # For any direct subclass of Collection + func = out_collection.add_entry + else: + # For FieldsContainers, ScopingsContainers and MeshesContainers + # because they have dedicated APIs + func = out_collection._add_entry + for i in indices: + func( + label_space=self.get_label_space(index=i), + entry=self._get_entries(label_space_or_index=i), + ) + return out_collection self_len = len(self) if index < 0: # convert to a positive index diff --git a/tests/test_collection.py b/tests/test_collection.py index 799d567321..61413ef5b4 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -83,7 +83,7 @@ def create_dummy_gdc(server_type, prop="hi"): @dataclass class CollectionTypeHelper: type: type - instance_creator: object + instance_creator: callable kwargs: dict = field(default_factory=dict) @property @@ -245,3 +245,34 @@ def test_connect_collection_workflow(server_type, subtype_creator): out = op.get_output(0, subtype_creator.type) assert out is not None assert len(out) == 1 + + +def test_generic_data_containers_collection_slice(server_type): + coll = GenericDataContainersCollection(server=server_type) + + coll.labels = ["id1", "id2"] + for i in range(5): + coll.add_entry( + label_space={"id1": i, "id2": 0}, entry=create_dummy_gdc(server_type=server_type) + ) + assert len(coll) == 5 + print(coll) + sliced_coll = coll[:3] + assert len(sliced_coll) == 3 + print(sliced_coll) + + +def test_string_containers_collection_slice(server_type): + coll = StringFieldsCollection(server=server_type) + + coll.labels = ["id1", "id2"] + for i in range(5): + coll.add_entry( + label_space={"id1": i, "id2": 0}, + entry=create_dummy_string_field(server_type=server_type), + ) + assert len(coll) == 5 + print(coll) + sliced_coll = coll[:3] + assert len(sliced_coll) == 3 + print(sliced_coll) diff --git a/tests/test_fieldscontainer.py b/tests/test_fieldscontainer.py index 325fdb4da3..3dbe977a08 100644 --- a/tests/test_fieldscontainer.py +++ b/tests/test_fieldscontainer.py @@ -602,3 +602,9 @@ def test_get_entries_indices_fields_container(server_type): assert np.allclose(fc.get_entries_indices({"time": 1, "complex": 0}), [0]) assert np.allclose(fc.get_entries_indices({"time": 2}), [1]) assert np.allclose(fc.get_entries_indices({"complex": 0}), range(0, 20)) + + +def test_fields_container_slice(server_type, disp_fc): + print(disp_fc) + fc = disp_fc[:1] + assert len(fc) == 1