Skip to content

Commit 1fc6212

Browse files
rly and oruebel authored
Fix resolution of extension classes that have references (#1183)
* Fix resolution of extension classes that have references
* Update changelog
* Remove unnecessary if
* Update CHANGELOG.md

Co-authored-by: Oliver Ruebel <[email protected]>

---------

Co-authored-by: Oliver Ruebel <[email protected]>
1 parent d378dec commit 1fc6212

File tree

4 files changed

+196
-5
lines changed

4 files changed

+196
-5
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
### Bug fixes
1313
- Fixed issue where scalar datasets with a compound data type were being written as non-scalar datasets @stephprince [#1176](https://github.com/hdmf-dev/hdmf/pull/1176)
1414
- Fixed H5DataIO not exposing `maxshape` on non-dci dsets. @cboulay [#1149](https://github.com/hdmf-dev/hdmf/pull/1149)
15+
- Fixed generation of classes in an extension that contain attributes or datasets storing references to other types defined in the extension.
16+
@rly [#1183](https://github.com/hdmf-dev/hdmf/pull/1183)
1517

1618
## HDMF 3.14.3 (July 29, 2024)
1719

src/hdmf/build/manager.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .classgenerator import ClassGenerator, CustomClassGenerator, MCIClassGenerator
88
from ..container import AbstractContainer, Container, Data
99
from ..term_set import TypeConfigurator
10-
from ..spec import DatasetSpec, GroupSpec, NamespaceCatalog
10+
from ..spec import DatasetSpec, GroupSpec, NamespaceCatalog, RefSpec
1111
from ..spec.spec import BaseStorageSpec
1212
from ..utils import docval, getargs, ExtenderMeta, get_docval
1313

@@ -480,6 +480,7 @@ def load_namespaces(self, **kwargs):
480480
load_namespaces here has the advantage of being able to keep track of type dependencies across namespaces.
481481
'''
482482
deps = self.__ns_catalog.load_namespaces(**kwargs)
483+
# register container types for each dependent type in each dependent namespace
483484
for new_ns, ns_deps in deps.items():
484485
for src_ns, types in ns_deps.items():
485486
for dt in types:
@@ -529,7 +530,7 @@ def get_dt_container_cls(self, **kwargs):
529530
namespace = ns_key
530531
break
531532
if namespace is None:
532-
raise ValueError("Namespace could not be resolved.")
533+
raise ValueError(f"Namespace could not be resolved for data type '{data_type}'.")
533534

534535
cls = self.__get_container_cls(namespace, data_type)
535536

@@ -549,6 +550,8 @@ def get_dt_container_cls(self, **kwargs):
549550

550551
def __check_dependent_types(self, spec, namespace):
551552
"""Ensure that classes for all types used by this type exist in this namespace and generate them if not.
553+
554+
`spec` should be a GroupSpec or DatasetSpec in the `namespace`
552555
"""
553556
def __check_dependent_types_helper(spec, namespace):
554557
if isinstance(spec, (GroupSpec, DatasetSpec)):
@@ -564,6 +567,16 @@ def __check_dependent_types_helper(spec, namespace):
564567

565568
if spec.data_type_inc is not None:
566569
self.get_dt_container_cls(spec.data_type_inc, namespace)
570+
571+
# handle attributes that have a reference dtype
572+
for attr_spec in spec.attributes:
573+
if isinstance(attr_spec.dtype, RefSpec):
574+
self.get_dt_container_cls(attr_spec.dtype.target_type, namespace)
575+
# handle datasets that have a reference dtype
576+
if isinstance(spec, DatasetSpec):
577+
if isinstance(spec.dtype, RefSpec):
578+
self.get_dt_container_cls(spec.dtype.target_type, namespace)
579+
# recurse into nested types
567580
if isinstance(spec, GroupSpec):
568581
for child_spec in (spec.groups + spec.datasets + spec.links):
569582
__check_dependent_types_helper(child_spec, namespace)

tests/unit/build_tests/test_classgenerator.py

Lines changed: 178 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
from hdmf.build import TypeMap, CustomClassGenerator
88
from hdmf.build.classgenerator import ClassGenerator, MCIClassGenerator
99
from hdmf.container import Container, Data, MultiContainerInterface, AbstractContainer
10-
from hdmf.spec import GroupSpec, AttributeSpec, DatasetSpec, SpecCatalog, SpecNamespace, NamespaceCatalog, LinkSpec
10+
from hdmf.spec import (
11+
GroupSpec, AttributeSpec, DatasetSpec, SpecCatalog, SpecNamespace, NamespaceCatalog, LinkSpec, RefSpec
12+
)
1113
from hdmf.testing import TestCase
1214
from hdmf.utils import get_docval, docval
1315

@@ -734,9 +736,18 @@ def _build_separate_namespaces(self):
734736
GroupSpec(data_type_inc='Bar', doc='a bar', quantity='?')
735737
]
736738
)
739+
moo_spec = DatasetSpec(
740+
doc='A test dataset that is a 1D array of object references of Baz',
741+
data_type_def='Moo',
742+
shape=(None,),
743+
dtype=RefSpec(
744+
reftype='object',
745+
target_type='Baz'
746+
)
747+
)
737748
create_load_namespace_yaml(
738749
namespace_name='ndx-test',
739-
specs=[baz_spec],
750+
specs=[baz_spec, moo_spec],
740751
output_dir=self.test_dir,
741752
incl_types={
742753
CORE_NAMESPACE: ['Bar'],
@@ -828,6 +839,171 @@ def test_get_class_include_from_separate_ns_4(self):
828839

829840
self._check_classes(baz_cls, bar_cls, bar_cls2, qux_cls, qux_cls2)
830841

842+
class TestGetClassObjectReferences(TestCase):
843+
844+
def setUp(self):
845+
self.test_dir = tempfile.mkdtemp()
846+
if os.path.exists(self.test_dir): # start clean
847+
self.tearDown()
848+
os.mkdir(self.test_dir)
849+
self.type_map = TypeMap()
850+
851+
def tearDown(self):
852+
shutil.rmtree(self.test_dir)
853+
854+
def test_get_class_include_dataset_of_references(self):
855+
"""Test that get_class resolves datasets of object references."""
856+
qux_spec = DatasetSpec(
857+
doc='A test extension',
858+
data_type_def='Qux'
859+
)
860+
moo_spec = DatasetSpec(
861+
doc='A test dataset that is a 1D array of object references of Qux',
862+
data_type_def='Moo',
863+
shape=(None,),
864+
dtype=RefSpec(
865+
reftype='object',
866+
target_type='Qux'
867+
),
868+
)
869+
870+
create_load_namespace_yaml(
871+
namespace_name='ndx-test',
872+
specs=[qux_spec, moo_spec],
873+
output_dir=self.test_dir,
874+
incl_types={},
875+
type_map=self.type_map
876+
)
877+
# no types should be resolved to start
878+
assert self.type_map.get_container_classes('ndx-test') == []
879+
880+
self.type_map.get_dt_container_cls('Moo', 'ndx-test')
881+
# now, Moo and Qux should be resolved
882+
assert len(self.type_map.get_container_classes('ndx-test')) == 2
883+
assert "Moo" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]
884+
assert "Qux" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]
885+
886+
def test_get_class_include_attribute_object_reference(self):
887+
"""Test that get_class resolves data types with an attribute that is an object reference."""
888+
qux_spec = DatasetSpec(
889+
doc='A test extension',
890+
data_type_def='Qux'
891+
)
892+
woo_spec = DatasetSpec(
893+
doc='A test dataset that has a scalar object reference to a Qux',
894+
data_type_def='Woo',
895+
attributes=[
896+
AttributeSpec(
897+
name='attr1',
898+
doc='a string attribute',
899+
dtype=RefSpec(reftype='object', target_type='Qux')
900+
),
901+
]
902+
)
903+
create_load_namespace_yaml(
904+
namespace_name='ndx-test',
905+
specs=[qux_spec, woo_spec],
906+
output_dir=self.test_dir,
907+
incl_types={},
908+
type_map=self.type_map
909+
)
910+
# no types should be resolved to start
911+
assert self.type_map.get_container_classes('ndx-test') == []
912+
913+
self.type_map.get_dt_container_cls('Woo', 'ndx-test')
914+
# now, Woo and Qux should be resolved
915+
assert len(self.type_map.get_container_classes('ndx-test')) == 2
916+
assert "Woo" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]
917+
assert "Qux" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]
918+
919+
def test_get_class_include_nested_object_reference(self):
920+
"""Test that get_class resolves nested datasets that are object references."""
921+
qux_spec = DatasetSpec(
922+
doc='A test extension',
923+
data_type_def='Qux'
924+
)
925+
spam_spec = DatasetSpec(
926+
doc='A test extension',
927+
data_type_def='Spam',
928+
shape=(None,),
929+
dtype=RefSpec(
930+
reftype='object',
931+
target_type='Qux'
932+
),
933+
)
934+
goo_spec = GroupSpec(
935+
doc='A test dataset that has a nested dataset (Spam) that has a scalar object reference to a Qux',
936+
data_type_def='Goo',
937+
datasets=[
938+
DatasetSpec(
939+
doc='a dataset',
940+
data_type_inc='Spam',
941+
),
942+
],
943+
)
944+
945+
create_load_namespace_yaml(
946+
namespace_name='ndx-test',
947+
specs=[qux_spec, spam_spec, goo_spec],
948+
output_dir=self.test_dir,
949+
incl_types={},
950+
type_map=self.type_map
951+
)
952+
# no types should be resolved to start
953+
assert self.type_map.get_container_classes('ndx-test') == []
954+
955+
self.type_map.get_dt_container_cls('Goo', 'ndx-test')
956+
# now, Goo, Spam, and Qux should be resolved
957+
assert len(self.type_map.get_container_classes('ndx-test')) == 3
958+
assert "Goo" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]
959+
assert "Spam" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]
960+
assert "Qux" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]
961+
962+
def test_get_class_include_nested_attribute_object_reference(self):
963+
"""Test that get_class resolves nested datasets that have an attribute that is an object reference."""
964+
qux_spec = DatasetSpec(
965+
doc='A test extension',
966+
data_type_def='Qux'
967+
)
968+
bam_spec = DatasetSpec(
969+
doc='A test extension',
970+
data_type_def='Bam',
971+
attributes=[
972+
AttributeSpec(
973+
name='attr1',
974+
doc='a string attribute',
975+
dtype=RefSpec(reftype='object', target_type='Qux')
976+
),
977+
],
978+
)
979+
boo_spec = GroupSpec(
980+
doc='A test dataset that has a nested dataset (Spam) that has a scalar object reference to a Qux',
981+
data_type_def='Boo',
982+
datasets=[
983+
DatasetSpec(
984+
doc='a dataset',
985+
data_type_inc='Bam',
986+
),
987+
],
988+
)
989+
990+
create_load_namespace_yaml(
991+
namespace_name='ndx-test',
992+
specs=[qux_spec, bam_spec, boo_spec],
993+
output_dir=self.test_dir,
994+
incl_types={},
995+
type_map=self.type_map
996+
)
997+
# no types should be resolved to start
998+
assert self.type_map.get_container_classes('ndx-test') == []
999+
1000+
self.type_map.get_dt_container_cls('Boo', 'ndx-test')
1001+
# now, Boo, Bam, and Qux should be resolved
1002+
assert len(self.type_map.get_container_classes('ndx-test')) == 3
1003+
assert "Boo" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]
1004+
assert "Bam" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]
1005+
assert "Qux" in [c.__name__ for c in self.type_map.get_container_classes('ndx-test')]
1006+
8311007

8321008
class EmptyBar(Container):
8331009
pass

tests/unit/build_tests/test_io_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def test_get_dt_container_cls(self):
341341
self.assertIs(ret, Foo)
342342

343343
def test_get_dt_container_cls_no_namespace(self):
344-
with self.assertRaisesWith(ValueError, "Namespace could not be resolved."):
344+
with self.assertRaisesWith(ValueError, "Namespace could not be resolved for data type 'Unknown'."):
345345
self.type_map.get_dt_container_cls(data_type="Unknown")
346346

347347

0 commit comments

Comments
 (0)