Skip to content

Commit ee854d9

Browse files
committed
Support extension types (like pa.FixedShapeTensorType)
1 parent 5e130d9 commit ee854d9

File tree

2 files changed

+134
-1
lines changed

2 files changed

+134
-1
lines changed

quivr/columns.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,8 @@ def _nulls(self, n: int) -> pa.Array:
257257
# manually.
258258
#
259259
# See: https://github.com/apache/arrow/issues/37072
260-
return pa.array([None] * n, type=self.dtype)
260+
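# Use pa.nulls rather than pa.array([None] * n, type=self.dtype): the latter raises ArrowNotImplementedError for extension types such as FixedShapeTensorType.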
261+
return pa.nulls(n, type=self.dtype)
261262

262263

263264
class Int8Column(Column):

test/test_tables.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,138 @@ class Wrapper(qv.Table):
884884
assert have.pairs.y.null_count == 0
885885

886886

887+
def test_nullable_subtable_with_extension_type():
888+
"""Test that nullable subtables with extension type columns can be created with nulls.
889+
890+
This test covers the issue where pa.array([None] * n, type=extension_type)
891+
would raise ArrowNotImplementedError for extension types like FixedShapeTensorType.
892+
The fix was to use pa.nulls(n, type=self.dtype) instead in SubTableColumn._nulls().
893+
"""
894+
class TableWithExtension(qv.Table):
895+
id = qv.Int64Column()
896+
tensor = qv.Column(pa.fixed_shape_tensor(pa.float64(), (3, 3)))
897+
898+
class Wrapper(qv.Table):
899+
name = qv.StringColumn()
900+
data = TableWithExtension.as_column(nullable=True)
901+
902+
# Test 1: All nulls - no data provided for subtable
903+
# This should not raise ArrowNotImplementedError when creating nulls
904+
# for the extension type column
905+
result = Wrapper.from_kwargs(name=["a", "b", "c"])
906+
907+
# Verify the nulls are created correctly
908+
assert result.name.null_count == 0
909+
assert result.data.id.null_count == 3
910+
# The tensor field should also have nulls
911+
assert result.data.tensor.null_count == 3
912+
913+
# Test 2: Round-trip with all nulls
914+
result2 = Wrapper.from_kwargs(name=result.name, data=result.data)
915+
assert result2.name.null_count == 0
916+
assert result2.data.id.null_count == 3
917+
assert result2.data.tensor.null_count == 3
918+
919+
# Test 3: All data provided (no nulls)
920+
tensors = pa.FixedShapeTensorArray.from_numpy_ndarray(np.random.rand(3, 3, 3))
921+
data_with_values = TableWithExtension.from_kwargs(id=[1, 2, 3], tensor=tensors)
922+
result3 = Wrapper.from_kwargs(name=["a", "b", "c"], data=data_with_values)
923+
assert result3.name.null_count == 0
924+
assert result3.data.id.null_count == 0
925+
assert result3.data.tensor.null_count == 0
926+
927+
# Test 4: Row filtering - keep a subset of rows from a fully populated table
928+
# Create a subtable with 5 rows, then use a mask to drop some rows
929+
tensors_partial = pa.FixedShapeTensorArray.from_numpy_ndarray(np.random.rand(5, 3, 3))
930+
data_partial = TableWithExtension.from_kwargs(id=[1, 2, 3, 4, 5], tensor=tensors_partial)
931+
wrapper_partial = Wrapper.from_kwargs(name=["a", "b", "c", "d", "e"], data=data_partial)
932+
933+
# Apply a mask to filter rows (drop rows 1 and 3, keep 0, 2, 4); no nulls are introduced
934+
mask = pa.array([True, False, True, False, True])
935+
result4 = wrapper_partial.apply_mask(mask)
936+
937+
assert result4.name.null_count == 0
938+
assert len(result4) == 3
939+
assert result4.data.id.null_count == 0 # All remaining rows have data
940+
assert result4.data.tensor.null_count == 0
941+
942+
# Verify the correct rows were kept
943+
assert result4.name.to_pylist() == ["a", "c", "e"]
944+
assert result4.data.id.to_pylist() == [1, 3, 5]
945+
946+
# Test 5: Nullable extension type column with partial nulls within the column
947+
class TableWithNullableExtension(qv.Table):
948+
id = qv.Int64Column()
949+
tensor = qv.Column(pa.fixed_shape_tensor(pa.float64(), (3, 3)), nullable=True)
950+
951+
class WrapperNullable(qv.Table):
952+
name = qv.StringColumn()
953+
data = TableWithNullableExtension.as_column(nullable=True)
954+
955+
# Create a FixedShapeTensorArray with some null tensor entries
956+
# Note: NumPy doesn't have nulls (only NaN), so we need to create the array
957+
# and then add nulls at the PyArrow level to make entire tensors null
958+
959+
# Build the array by concatenating valid tensors and nulls
960+
# Create single-entry arrays: two valid tensors and two null entries
961+
tensor1 = pa.FixedShapeTensorArray.from_numpy_ndarray(np.random.rand(1, 3, 3))
962+
tensor2 = pa.nulls(1, type=pa.fixed_shape_tensor(pa.float64(), (3, 3)))
963+
tensor3 = pa.FixedShapeTensorArray.from_numpy_ndarray(np.random.rand(1, 3, 3))
964+
tensor4 = pa.nulls(1, type=pa.fixed_shape_tensor(pa.float64(), (3, 3)))
965+
966+
# Concatenate them to create array: [valid, null, valid, null]
967+
tensor_array_with_nulls = pa.concat_arrays([tensor1, tensor2, tensor3, tensor4])
968+
969+
data_nullable = TableWithNullableExtension.from_kwargs(
970+
id=[10, 20, 30, 40], tensor=tensor_array_with_nulls
971+
)
972+
result5 = WrapperNullable.from_kwargs(name=["x", "y", "z", "w"], data=data_nullable)
973+
974+
assert result5.name.null_count == 0
975+
assert result5.data.id.null_count == 0 # id column has no nulls
976+
assert result5.data.tensor.null_count == 2 # Tensor has 2 null entries (at indices 1 and 3)
977+
assert result5.data.id.to_pylist() == [10, 20, 30, 40]
978+
979+
# Verify which entries are null
980+
tensor_column = result5.table.column("data").flatten()[1] # Get the tensor field
981+
assert tensor_column.is_null().to_pylist() == [False, True, False, True]
982+
983+
# Test 6: Tensors with NaN values within the tensor data (not null entries)
984+
# This tests that we can handle tensors containing NaN values
985+
class TableWithNanTensors(qv.Table):
986+
id = qv.Int64Column()
987+
tensor = qv.Column(pa.fixed_shape_tensor(pa.float64(), (3, 3)))
988+
989+
class WrapperWithNans(qv.Table):
990+
name = qv.StringColumn()
991+
data = TableWithNanTensors.as_column(nullable=True)
992+
993+
# Create tensors with NaN values inside them
994+
tensors_with_nans = np.random.rand(4, 3, 3)
995+
# Set some elements to NaN (not the entire tensor, just some values within)
996+
tensors_with_nans[0, 1, 1] = np.nan # First tensor has a NaN at position [1,1]
997+
tensors_with_nans[2, 0, 2] = np.nan # Third tensor has a NaN at position [0,2]
998+
tensors_with_nans[3, 2, 2] = np.nan # Fourth tensor has a NaN at position [2,2]
999+
1000+
tensor_array_with_nans = pa.FixedShapeTensorArray.from_numpy_ndarray(tensors_with_nans)
1001+
data_with_nans = TableWithNanTensors.from_kwargs(id=[100, 200, 300, 400], tensor=tensor_array_with_nans)
1002+
result6 = WrapperWithNans.from_kwargs(name=["p", "q", "r", "s"], data=data_with_nans)
1003+
1004+
# All entries should be non-null (the tensors exist, they just contain NaN values)
1005+
assert result6.name.null_count == 0
1006+
assert result6.data.id.null_count == 0
1007+
assert result6.data.tensor.null_count == 0
1008+
1009+
# Verify the data can be retrieved and contains NaN values
1010+
retrieved_tensors = result6.data.tensor.combine_chunks().to_numpy_ndarray()
1011+
assert np.isnan(retrieved_tensors[0, 1, 1])
1012+
assert np.isnan(retrieved_tensors[2, 0, 2])
1013+
assert np.isnan(retrieved_tensors[3, 2, 2])
1014+
# Verify some non-NaN values exist
1015+
assert not np.isnan(retrieved_tensors[0, 0, 0])
1016+
assert not np.isnan(retrieved_tensors[1, 1, 1])
1017+
1018+
8871019
class TableWithAttributes(qv.Table):
8881020
x = qv.Int64Column()
8891021
y = qv.Int64Column()

0 commit comments

Comments
 (0)