@@ -884,6 +884,138 @@ class Wrapper(qv.Table):
884884 assert have .pairs .y .null_count == 0
885885
886886
def test_nullable_subtable_with_extension_type():
    """Nullable subtables whose schema has an extension-type column accept nulls.

    Regression coverage for the case where
    ``pa.array([None] * n, type=extension_type)`` raised
    ``ArrowNotImplementedError`` for extension types such as
    ``FixedShapeTensorType``. ``SubTableColumn._nulls()`` was changed to build
    its placeholder with ``pa.nulls(n, type=self.dtype)`` instead.
    """

    def tensor_type():
        # Fixed-shape tensor extension type used by every table in this test.
        return pa.fixed_shape_tensor(pa.float64(), (3, 3))

    def random_tensors(count):
        # `count` random 3x3 tensors wrapped as a FixedShapeTensorArray.
        return pa.FixedShapeTensorArray.from_numpy_ndarray(np.random.rand(count, 3, 3))

    def check_null_counts(wrapper, expected_id, expected_tensor):
        # `name` is always fully populated in this test; only the subtable
        # columns vary between scenarios.
        assert wrapper.name.null_count == 0
        assert wrapper.data.id.null_count == expected_id
        assert wrapper.data.tensor.null_count == expected_tensor

    class TableWithExtension(qv.Table):
        id = qv.Int64Column()
        tensor = qv.Column(tensor_type())

    class Wrapper(qv.Table):
        name = qv.StringColumn()
        data = TableWithExtension.as_column(nullable=True)

    # Scenario 1: the subtable is omitted entirely. Building the null
    # placeholder for the extension-type column must not raise
    # ArrowNotImplementedError, and both subtable columns must be all-null.
    all_null = Wrapper.from_kwargs(name=["a", "b", "c"])
    check_null_counts(all_null, 3, 3)

    # Scenario 2: feed the all-null columns back in and expect the same counts.
    round_tripped = Wrapper.from_kwargs(name=all_null.name, data=all_null.data)
    check_null_counts(round_tripped, 3, 3)

    # Scenario 3: every row of the subtable carries data.
    populated_sub = TableWithExtension.from_kwargs(id=[1, 2, 3], tensor=random_tensors(3))
    populated = Wrapper.from_kwargs(name=["a", "b", "c"], data=populated_sub)
    check_null_counts(populated, 0, 0)

    # Scenario 4: start from five populated rows and apply a boolean mask.
    # The assertions below expect only the rows flagged True (0, 2 and 4) to
    # remain, each still carrying its subtable data.
    five_row_sub = TableWithExtension.from_kwargs(id=[1, 2, 3, 4, 5], tensor=random_tensors(5))
    five_rows = Wrapper.from_kwargs(name=["a", "b", "c", "d", "e"], data=five_row_sub)
    keep = pa.array([True, False, True, False, True])
    masked = five_rows.apply_mask(keep)

    assert len(masked) == 3
    check_null_counts(masked, 0, 0)
    assert masked.name.to_pylist() == ["a", "c", "e"]
    assert masked.data.id.to_pylist() == [1, 3, 5]

    # Scenario 5: the extension-type column is itself nullable and holds a mix
    # of valid and null tensors.
    class TableWithNullableExtension(qv.Table):
        id = qv.Int64Column()
        tensor = qv.Column(tensor_type(), nullable=True)

    class WrapperNullable(qv.Table):
        name = qv.StringColumn()
        data = TableWithNullableExtension.as_column(nullable=True)

    # NumPy has NaN but no notion of a null entry, so whole-tensor nulls are
    # produced at the PyArrow level: one-element arrays are concatenated in
    # the pattern [valid, null, valid, null].
    segments = [
        random_tensors(1),
        pa.nulls(1, type=tensor_type()),
        random_tensors(1),
        pa.nulls(1, type=tensor_type()),
    ]
    nullable_sub = TableWithNullableExtension.from_kwargs(
        id=[10, 20, 30, 40],
        tensor=pa.concat_arrays(segments),
    )
    with_null_tensors = WrapperNullable.from_kwargs(name=["x", "y", "z", "w"], data=nullable_sub)

    # Only the tensor column has nulls (two of them); ids are all present.
    check_null_counts(with_null_tensors, 0, 2)
    assert with_null_tensors.data.id.to_pylist() == [10, 20, 30, 40]

    # Confirm the nulls sit at indices 1 and 3. The tensor field is the second
    # child of the flattened "data" struct column.
    flattened_fields = with_null_tensors.table.column("data").flatten()
    tensor_field = flattened_fields[1]
    assert tensor_field.is_null().to_pylist() == [False, True, False, True]

    # Scenario 6: tensors that are present but contain NaN elements. NaN inside
    # a tensor must not be counted as a null entry.
    class TableWithNanTensors(qv.Table):
        id = qv.Int64Column()
        tensor = qv.Column(tensor_type())

    class WrapperWithNans(qv.Table):
        name = qv.StringColumn()
        data = TableWithNanTensors.as_column(nullable=True)

    # (tensor index, row, column) of each element overwritten with NaN; the
    # second tensor (index 1) is left untouched.
    nan_positions = [(0, 1, 1), (2, 0, 2), (3, 2, 2)]
    raw_values = np.random.rand(4, 3, 3)
    for position in nan_positions:
        raw_values[position] = np.nan

    nan_sub = TableWithNanTensors.from_kwargs(
        id=[100, 200, 300, 400],
        tensor=pa.FixedShapeTensorArray.from_numpy_ndarray(raw_values),
    )
    with_nans = WrapperWithNans.from_kwargs(name=["p", "q", "r", "s"], data=nan_sub)

    # Every entry exists; the NaNs live inside the tensors.
    check_null_counts(with_nans, 0, 0)

    # The NaNs survive retrieval at exactly the positions they were written,
    # and untouched elements stay finite.
    recovered = with_nans.data.tensor.combine_chunks().to_numpy_ndarray()
    for position in nan_positions:
        assert np.isnan(recovered[position])
    assert not np.isnan(recovered[0, 0, 0])
    assert not np.isnan(recovered[1, 1, 1])
1018+
8871019class TableWithAttributes (qv .Table ):
8881020 x = qv .Int64Column ()
8891021 y = qv .Int64Column ()
0 commit comments