Skip to content

Commit 7095b19

Browse files
gforsythcpcloud
andauthored
fix(field_index): get field index w.r.t. pre-join table schemata (#1078)
* fix(field_index): get field index w.r.t. pre-join table schemata JoinChains provide the schema of the joined table (which is great for Ibis) but for substrait we need the Field index computed with respect to the original table schemata. In practice, this means rolling through the tables in a JoinChain and computing the field index _without_ removing the join key Given Table 1 a: int b: int Table 2 a: int c: int JoinChain[r0] JoinLink[inner, r1] r0.a == r1.a values: a: r0.a b: r0.b c: r1.c If we ask for the field index of `c`, the JoinChain schema will give us an index of `2`, but it should be `3` because 0: table 1 a 1: table 1 b 2: table 2 a 3: table 2 c So now we pull out the correct JoinReference object and use that to index into the tables in the JoinChain and offset by the length of the schema of those preceding tables. * test(snapshots): update snapshots for fixed join indexing * fix: apply suggestions from review Co-authored-by: Phillip Cloud <[email protected]> --------- Co-authored-by: Phillip Cloud <[email protected]>
1 parent 3a8b0db commit 7095b19

File tree

6 files changed

+129
-19
lines changed

6 files changed

+129
-19
lines changed

ibis_substrait/compiler/translate.py

+52-2
Original file line numberDiff line numberDiff line change
@@ -674,8 +674,58 @@ def table_column(
674674
else:
675675
base_offset = 0
676676

677-
schema = op.rel.schema
678-
relative_offset = schema._name_locs[op.name]
677+
if isinstance(op.rel, ops.JoinChain):
678+
# JoinChains provide the schema of the joined table (which is great for Ibis)
679+
# but for substrait we need the Field index computed with respect to
680+
# the original table schemas. In practice, this means rolling through
681+
# the tables in a JoinChain and computing the field index _without_
682+
# removing the join key
683+
#
684+
# Given
685+
# Table 1
686+
# a: int
687+
# b: int
688+
#
689+
# Table 2
690+
# a: int
691+
# c: int
692+
#
693+
# JoinChain[r0]
694+
# JoinLink[inner, r1]
695+
# r0.a == r1.a
696+
# values:
697+
# a: r0.a
698+
# b: r0.b
699+
# c: r1.c
700+
#
701+
# If we ask for the field index of `c`, the JoinChain schema will give
702+
# us an index of `2`, but it should be `3` because
703+
#
704+
# 0: table 1 a
705+
# 1: table 1 b
706+
# 2: table 2 a
707+
# 3: table 2 c
708+
#
709+
710+
# List of join reference objects
711+
join_tables = op.rel.tables
712+
# Join reference containing the field we care about
713+
field_table = op.rel.values.get(op.name).rel
714+
# Index of that join reference in the list of join references
715+
field_table_index = join_tables.index(field_table)
716+
717+
# Offset by the number of columns in each preceding table
718+
join_table_offset = sum(
719+
len(join_tables[i].schema) for i in range(field_table_index)
720+
)
721+
# Then add on the index of the column in the table
722+
# Also in the event of renaming due to join collisions, resolve
723+
# the renamed column to the original name so we can pull it off the parent table
724+
orig_name = op.rel.values[op.name].name
725+
relative_offset = join_table_offset + field_table.schema._name_locs[orig_name]
726+
else:
727+
schema = op.rel.schema
728+
relative_offset = schema._name_locs[op.name]
679729
absolute_offset = base_offset + relative_offset
680730
return stalg.Expression(
681731
selection=stalg.Expression.FieldReference(

ibis_substrait/tests/compiler/snapshots/test_tpch/test_compile/tpc_h07/tpc_h07.json

+9-5
Original file line numberDiff line numberDiff line change
@@ -732,7 +732,7 @@
732732
"selection": {
733733
"directReference": {
734734
"structField": {
735-
"field": 1
735+
"field": 45
736736
}
737737
},
738738
"rootReference": {}
@@ -764,7 +764,9 @@
764764
"value": {
765765
"selection": {
766766
"directReference": {
767-
"structField": {}
767+
"structField": {
768+
"field": 41
769+
}
768770
},
769771
"rootReference": {}
770772
}
@@ -810,7 +812,7 @@
810812
"selection": {
811813
"directReference": {
812814
"structField": {
813-
"field": 1
815+
"field": 45
814816
}
815817
},
816818
"rootReference": {}
@@ -842,7 +844,9 @@
842844
"value": {
843845
"selection": {
844846
"directReference": {
845-
"structField": {}
847+
"structField": {
848+
"field": 41
849+
}
846850
},
847851
"rootReference": {}
848852
}
@@ -882,7 +886,7 @@
882886
"selection": {
883887
"directReference": {
884888
"structField": {
885-
"field": 2
889+
"field": 17
886890
}
887891
},
888892
"rootReference": {}

ibis_substrait/tests/compiler/snapshots/test_tpch/test_compile/tpc_h08/tpc_h08.json

+3-3
Original file line numberDiff line numberDiff line change
@@ -957,7 +957,7 @@
957957
"selection": {
958958
"directReference": {
959959
"structField": {
960-
"field": 3
960+
"field": 54
961961
}
962962
},
963963
"rootReference": {}
@@ -990,7 +990,7 @@
990990
"selection": {
991991
"directReference": {
992992
"structField": {
993-
"field": 4
993+
"field": 36
994994
}
995995
},
996996
"rootReference": {}
@@ -1034,7 +1034,7 @@
10341034
"selection": {
10351035
"directReference": {
10361036
"structField": {
1037-
"field": 5
1037+
"field": 4
10381038
}
10391039
},
10401040
"rootReference": {}

ibis_substrait/tests/compiler/snapshots/test_tpch/test_compile/tpc_h09/tpc_h09.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,7 @@
767767
"selection": {
768768
"directReference": {
769769
"structField": {
770-
"field": 3
770+
"field": 29
771771
}
772772
},
773773
"rootReference": {}

ibis_substrait/tests/compiler/snapshots/test_tpch/test_compile/tpc_h21/tpc_h21.json

+12-8
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@
555555
"selection": {
556556
"directReference": {
557557
"structField": {
558-
"field": 1
558+
"field": 25
559559
}
560560
},
561561
"rootReference": {}
@@ -588,7 +588,7 @@
588588
"selection": {
589589
"directReference": {
590590
"structField": {
591-
"field": 2
591+
"field": 19
592592
}
593593
},
594594
"rootReference": {}
@@ -600,7 +600,7 @@
600600
"selection": {
601601
"directReference": {
602602
"structField": {
603-
"field": 3
603+
"field": 18
604604
}
605605
},
606606
"rootReference": {}
@@ -630,7 +630,7 @@
630630
"selection": {
631631
"directReference": {
632632
"structField": {
633-
"field": 6
633+
"field": 33
634634
}
635635
},
636636
"rootReference": {}
@@ -817,7 +817,9 @@
817817
"value": {
818818
"selection": {
819819
"directReference": {
820-
"structField": {}
820+
"structField": {
821+
"field": 7
822+
}
821823
},
822824
"rootReference": {}
823825
}
@@ -854,7 +856,7 @@
854856
"selection": {
855857
"directReference": {
856858
"structField": {
857-
"field": 4
859+
"field": 9
858860
}
859861
},
860862
"rootReference": {}
@@ -1063,7 +1065,9 @@
10631065
"value": {
10641066
"selection": {
10651067
"directReference": {
1066-
"structField": {}
1068+
"structField": {
1069+
"field": 7
1070+
}
10671071
},
10681072
"rootReference": {}
10691073
}
@@ -1100,7 +1104,7 @@
11001104
"selection": {
11011105
"directReference": {
11021106
"structField": {
1103-
"field": 4
1107+
"field": 9
11041108
}
11051109
},
11061110
"rootReference": {}

ibis_substrait/tests/compiler/test_compiler.py

+52
Original file line numberDiff line numberDiff line change
@@ -532,3 +532,55 @@ def test_groupby_multiple_keys(compiler):
532532
# There should be one grouping with two separate expressions inside
533533
assert len(plan.aggregate.groupings) == 1
534534
assert len(plan.aggregate.groupings[0].grouping_expressions) == 2
535+
536+
537+
def test_join_chain_indexing_in_group_by(compiler):
538+
t1 = ibis.table([("a", int), ("b", int)], name="t1")
539+
t2 = ibis.table([("a", int), ("c", int)], name="t2")
540+
t3 = ibis.table([("a", int), ("d", int)], name="t3")
541+
t4 = ibis.table([("a", int), ("c", int)], name="t4")
542+
543+
join_chain = t1.join(t2, "a").join(t3, "a").join(t4, "a")
544+
# Indexing for chained join
545+
# t1: a: 0
546+
# t1: b: 1
547+
# t2: a: 2
548+
# t2: c: 3
549+
# t3: a: 4
550+
# t3: d: 5
551+
# t4: a: 6
552+
# t4: c: 7
553+
554+
expr = join_chain.group_by("d").count().select("d")
555+
plan = compiler.compile(expr)
556+
# Check that the field index for the group_by key is correctly indexed
557+
assert (
558+
plan.relations[0]
559+
.root.input.project.input.aggregate.groupings[0]
560+
.grouping_expressions[0]
561+
.selection.direct_reference.struct_field.field
562+
== 5
563+
)
564+
565+
expr = join_chain.group_by("c").count().select("c")
566+
plan = compiler.compile(expr)
567+
# Check that the field index for the group_by key is correctly indexed
568+
assert (
569+
plan.relations[0]
570+
.root.input.project.input.aggregate.groupings[0]
571+
.grouping_expressions[0]
572+
.selection.direct_reference.struct_field.field
573+
== 3
574+
)
575+
576+
# Group-by on a column that will be renamed by the joinchain
577+
expr = join_chain.group_by(t4.c).count().select("c")
578+
plan = compiler.compile(expr)
579+
# Check that the field index for the group_by key is correctly indexed
580+
assert (
581+
plan.relations[0]
582+
.root.input.project.input.aggregate.groupings[0]
583+
.grouping_expressions[0]
584+
.selection.direct_reference.struct_field.field
585+
== 7
586+
)

0 commit comments

Comments
 (0)