Skip to content

Commit ab8cf3b

Browse files
rlymavaylon1
andauthored
Don't override col_cls in DynamicTable.add_column (#1091)
Co-authored-by: Matthew Avaylon <[email protected]>
1 parent 775fa3b commit ab8cf3b

File tree

4 files changed

+83
-23
lines changed

4 files changed

+83
-23
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
### Added
2121
- Added script to check Python version support for HDMF dependencies. @rly [#1230](https://github.com/hdmf-dev/hdmf/pull/1230)
2222

23+
### Fixed
24+
- Fixed issue with `DynamicTable.add_column` not allowing subclasses of `DynamicTableRegion` or `EnumData`. @rly [#1091](https://github.com/hdmf-dev/hdmf/pull/1091)
25+
2326
## HDMF 3.14.6 (December 20, 2024)
2427

2528
### Enhancements

src/hdmf/common/io/table.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,11 @@ def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_i
7878
required=field_spec.required
7979
)
8080
dtype = cls._get_type(field_spec, type_map)
81+
column_conf['class'] = dtype
8182
if issubclass(dtype, DynamicTableRegion):
8283
# the spec does not know which table this DTR points to
8384
# the user must specify the table attribute on the DTR after it is generated
8485
column_conf['table'] = True
85-
else:
86-
column_conf['class'] = dtype
8786

8887
index_counter = 0
8988
index_name = attr_name

src/hdmf/common/table.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ def _init_class_columns(self):
521521
description=col['description'],
522522
index=col.get('index', False),
523523
table=col.get('table', False),
524-
col_cls=col.get('class', VectorData),
524+
col_cls=col.get('class'),
525525
# Pass through extra kwargs for add_column that subclasses may have added
526526
**{k: col[k] for k in col.keys()
527527
if k not in DynamicTable.__reserved_colspec_keys})
@@ -564,10 +564,13 @@ def _set_dtr_targets(self, target_tables: dict):
564564
if not column_conf.get('table', False):
565565
raise ValueError("Column '%s' must be a DynamicTableRegion to have a target table."
566566
% colname)
567-
self.add_column(name=column_conf['name'],
568-
description=column_conf['description'],
569-
index=column_conf.get('index', False),
570-
table=True)
567+
self.add_column(
568+
name=column_conf['name'],
569+
description=column_conf['description'],
570+
index=column_conf.get('index', False),
571+
table=True,
572+
col_cls=column_conf.get('class'),
573+
)
571574
if isinstance(self[colname], VectorIndex):
572575
col = self[colname].target
573576
else:
@@ -681,7 +684,7 @@ def add_row(self, **kwargs):
681684
index=col.get('index', False),
682685
table=col.get('table', False),
683686
enum=col.get('enum', False),
684-
col_cls=col.get('class', VectorData),
687+
col_cls=col.get('class'),
685688
# Pass through extra keyword arguments for add_column that
686689
# subclasses may have added
687690
**{k: col[k] for k in col.keys()
@@ -753,7 +756,7 @@ def __eq__(self, other):
753756
'default': False},
754757
{'name': 'enum', 'type': (bool, 'array_data'), 'default': False,
755758
'doc': ('whether or not this column contains data from a fixed set of elements')},
756-
{'name': 'col_cls', 'type': type, 'default': VectorData,
759+
{'name': 'col_cls', 'type': type, 'default': None,
757760
'doc': ('class to use to represent the column data. If table=True, this field is ignored and a '
758761
'DynamicTableRegion object is used. If enum=True, this field is ignored and a EnumData '
759762
'object is used.')},
@@ -805,29 +808,39 @@ def add_column(self, **kwargs): # noqa: C901
805808
% (name, self.__class__.__name__, spec_index))
806809
warn(msg, stacklevel=3)
807810

808-
spec_col_cls = self.__uninit_cols[name].get('class', VectorData)
809-
if col_cls != spec_col_cls:
810-
msg = ("Column '%s' is predefined in %s with class=%s which does not match the entered "
811-
"col_cls argument. The predefined class spec will be ignored. "
812-
"Please ensure the new column complies with the spec. "
813-
"This will raise an error in a future version of HDMF."
814-
% (name, self.__class__.__name__, spec_col_cls))
815-
warn(msg, stacklevel=2)
816-
817811
ckwargs = dict(kwargs)
818812

819813
# Add table if it's been specified
820814
if table and enum:
821815
raise ValueError("column '%s' cannot be both a table region "
822816
"and come from an enumerable set of elements" % name)
817+
# Update col_cls if table is specified
823818
if table is not False:
824-
col_cls = DynamicTableRegion
819+
if col_cls is None:
820+
col_cls = DynamicTableRegion
825821
if isinstance(table, DynamicTable):
826822
ckwargs['table'] = table
823+
# Update col_cls if enum is specified
827824
if enum is not False:
828-
col_cls = EnumData
825+
if col_cls is None:
826+
col_cls = EnumData
829827
if isinstance(enum, (list, tuple, np.ndarray, VectorData)):
830828
ckwargs['elements'] = enum
829+
# Update col_cls to the default VectorData if col_cls is None
830+
if col_cls is None:
831+
col_cls = VectorData
832+
833+
if name in self.__uninit_cols: # column is a predefined optional column from the spec
834+
# check the given values against the predefined optional column spec. if they do not match, raise a warning
835+
# and ignore the given arguments. users should not be able to override these values
836+
spec_col_cls = self.__uninit_cols[name].get('class')
837+
if spec_col_cls is not None and col_cls != spec_col_cls:
838+
msg = ("Column '%s' is predefined in %s with class=%s which does not match the entered "
839+
"col_cls argument. The predefined class spec will be ignored. "
840+
"Please ensure the new column complies with the spec. "
841+
"This will raise an error in a future version of HDMF."
842+
% (name, self.__class__.__name__, spec_col_cls))
843+
warn(msg, stacklevel=2)
831844

832845
# If the user provided a list of lists that needs to be indexed, then we now need to flatten the data
833846
# We can only create the index actual VectorIndex once we have the VectorData column so we compute
@@ -873,7 +886,7 @@ def add_column(self, **kwargs): # noqa: C901
873886
if col in self.__uninit_cols:
874887
self.__uninit_cols.pop(col)
875888

876-
if col_cls is EnumData:
889+
if issubclass(col_cls, EnumData):
877890
columns.append(col.elements)
878891
col.elements.parent = self
879892

tests/unit/common/test_generate_table.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@
1616
class TestDynamicDynamicTable(TestCase):
1717

1818
def setUp(self):
19+
20+
self.dtr_spec = DatasetSpec(
21+
data_type_def='CustomDTR',
22+
data_type_inc='DynamicTableRegion',
23+
doc='a test DynamicTableRegion column', # this is overridden where it is used
24+
)
25+
1926
self.dt_spec = GroupSpec(
2027
'A test extension that contains a dynamic table',
2128
data_type_def='TestTable',
@@ -99,14 +106,21 @@ def setUp(self):
99106
doc='a test column',
100107
dtype='float',
101108
quantity='?',
102-
)
109+
),
110+
DatasetSpec(
111+
data_type_inc='CustomDTR',
112+
name='optional_custom_dtr_col',
113+
doc='a test DynamicTableRegion column',
114+
quantity='?'
115+
),
103116
]
104117
)
105118

106119
from hdmf.spec.write import YAMLSpecWriter
107120
writer = YAMLSpecWriter(outdir='.')
108121

109122
self.spec_catalog = SpecCatalog()
123+
self.spec_catalog.register_spec(self.dtr_spec, 'test.yaml')
110124
self.spec_catalog.register_spec(self.dt_spec, 'test.yaml')
111125
self.spec_catalog.register_spec(self.dt_spec2, 'test.yaml')
112126
self.namespace = SpecNamespace(
@@ -124,7 +138,7 @@ def setUp(self):
124138
self.test_dir = tempfile.mkdtemp()
125139
spec_fpath = os.path.join(self.test_dir, 'test.yaml')
126140
namespace_fpath = os.path.join(self.test_dir, 'test-namespace.yaml')
127-
writer.write_spec(dict(groups=[self.dt_spec, self.dt_spec2]), spec_fpath)
141+
writer.write_spec(dict(datasets=[self.dtr_spec], groups=[self.dt_spec, self.dt_spec2]), spec_fpath)
128142
writer.write_namespace(self.namespace, namespace_fpath)
129143
self.namespace_catalog = NamespaceCatalog()
130144
hdmf_typemap = get_type_map()
@@ -133,6 +147,7 @@ def setUp(self):
133147
self.type_map.load_namespaces(namespace_fpath)
134148
self.manager = BuildManager(self.type_map)
135149

150+
self.CustomDTR = self.type_map.get_dt_container_cls('CustomDTR', CORE_NAMESPACE)
136151
self.TestTable = self.type_map.get_dt_container_cls('TestTable', CORE_NAMESPACE)
137152
self.TestDTRTable = self.type_map.get_dt_container_cls('TestDTRTable', CORE_NAMESPACE)
138153

@@ -228,6 +243,22 @@ def test_dynamic_table_region_non_dtr_target(self):
228243
self.TestDTRTable(name='test_dtr_table', description='my table',
229244
target_tables={'optional_col3': test_table})
230245

246+
def test_custom_dtr_class(self):
247+
test_table = self.TestTable(name='test_table', description='my test table')
248+
test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], optional_col2=.5)
249+
test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], optional_col2=.5)
250+
251+
test_dtr_table = self.TestDTRTable(name='test_dtr_table', description='my table',
252+
target_tables={'optional_custom_dtr_col': test_table})
253+
254+
self.assertIsInstance(test_dtr_table['optional_custom_dtr_col'], self.CustomDTR)
255+
self.assertEqual(test_dtr_table['optional_custom_dtr_col'].description, "a test DynamicTableRegion column")
256+
self.assertIs(test_dtr_table['optional_custom_dtr_col'].table, test_table)
257+
258+
test_dtr_table.add_row(ref_col=0, indexed_ref_col=[0, 1], optional_custom_dtr_col=0)
259+
test_dtr_table.add_row(ref_col=0, indexed_ref_col=[0, 1], optional_custom_dtr_col=1)
260+
self.assertEqual(test_dtr_table['optional_custom_dtr_col'].data, [0, 1])
261+
231262
def test_attribute(self):
232263
test_table = self.TestTable(name='test_table', description='my test table')
233264
assert test_table.my_col is not None
@@ -266,3 +297,17 @@ def test_roundtrip(self):
266297
for err in errors:
267298
raise Exception(err)
268299
self.reader.close()
300+
301+
def test_add_custom_dtr_column(self):
302+
test_table = self.TestTable(name='test_table', description='my test table')
303+
test_table.add_column(
304+
name='custom_dtr_column',
305+
description='this is a custom DynamicTableRegion column',
306+
col_cls=self.CustomDTR,
307+
)
308+
self.assertIsInstance(test_table['custom_dtr_column'], self.CustomDTR)
309+
self.assertEqual(test_table['custom_dtr_column'].description, 'this is a custom DynamicTableRegion column')
310+
311+
test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], custom_dtr_column=0)
312+
test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], custom_dtr_column=1)
313+
self.assertEqual(test_table['custom_dtr_column'].data, [0, 1])

0 commit comments

Comments
 (0)