Skip to content

Commit 8ca89e3

Browse files
authored
feat: support subquery expressions (#134)
1 parent ab158b5 commit 8ca89e3

File tree

8 files changed

+1963
-2
lines changed

8 files changed

+1963
-2
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ This is work in progress still, things still to do:
1818
- [x] MultiOrList
1919
- [x] Cast
2020
- [x] Nested
21-
- [ ] Subquery
21+
- [x] Subquery
2222
- [ ] Serialization/Deserialization of Plan and Relations
2323
- [x] Plan
2424
- [x] PlanRel

expr/expression.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,10 @@ func ExprFromProto(e *proto.Expression, baseSchema *types.RecordType, reg Extens
317317
case *proto.Expression_Enum_:
318318
return nil, fmt.Errorf("%w: deprecated", substraitgo.ErrNotImplemented)
319319
case *proto.Expression_Subquery_:
320+
if reg.subqueryConverter == nil {
321+
return nil, fmt.Errorf("%w: subquery expressions require a subquery converter to be configured", substraitgo.ErrNotImplemented)
322+
}
323+
return reg.SubqueryFromProto(et.Subquery, baseSchema, reg)
320324
}
321325

322326
return nil, fmt.Errorf("%w: ExprFromProto: %s", substraitgo.ErrNotImplemented, e)

expr/expressions_test.go

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,3 +449,172 @@ func TestCastVisit(t *testing.T) {
449449
})
450450
}
451451
}
452+
453+
func TestSubqueryExpressionRoundtrip(t *testing.T) {
454+
const substraitExtURI = "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
455+
// define extensions with no plan for now
456+
const planExt = `{
457+
"extensionUris": [
458+
{
459+
"extensionUriAnchor": 1,
460+
"uri": "` + substraitExtURI + `"
461+
}
462+
],
463+
"extensions": [],
464+
"relations": []
465+
}`
466+
467+
var planProto proto.Plan
468+
if err := protojson.Unmarshal([]byte(planExt), &planProto); err != nil {
469+
panic(err)
470+
}
471+
472+
// get the extension set and create registry with subquery handler
473+
extSet := ext.GetExtensionSet(&planProto)
474+
c := ext.GetDefaultCollectionWithNoError()
475+
476+
// Create extension registry with subquery handler properly
477+
baseReg := expr.NewExtensionRegistry(extSet, c)
478+
subqueryReg := &plan.ExpressionConverter{ExtensionRegistry: baseReg}
479+
480+
// Create a simple mock relation for subqueries - single column of int32
481+
mockSchema := types.NamedStruct{
482+
Names: []string{"col1"},
483+
Struct: types.StructType{
484+
Types: []types.Type{&types.Int32Type{}},
485+
},
486+
}
487+
mockRel := plan.NewBuilderDefault().NamedScan([]string{"test_table"}, mockSchema)
488+
489+
// Create base schema for needle expressions
490+
baseSchema := types.NewRecordTypeFromTypes([]types.Type{&types.Int32Type{}, &types.StringType{}})
491+
492+
tests := []struct {
493+
name string
494+
subExpr expr.Expression
495+
}{
496+
{
497+
name: "ScalarSubquery",
498+
subExpr: plan.NewScalarSubquery(mockRel),
499+
},
500+
{
501+
name: "InPredicateSubquery",
502+
subExpr: plan.NewInPredicateSubquery(
503+
[]expr.Expression{expr.NewPrimitiveLiteral(int32(42), false)},
504+
mockRel,
505+
),
506+
},
507+
{
508+
name: "InPredicateSubquery_MultipleNeedles",
509+
subExpr: func() expr.Expression {
510+
// Create a 2-column relation for multi-needle test
511+
twoColSchema := types.NamedStruct{
512+
Names: []string{"col1", "col2"},
513+
Struct: types.StructType{
514+
Types: []types.Type{&types.Int32Type{}, &types.StringType{}},
515+
},
516+
}
517+
twoColRel := plan.NewBuilderDefault().NamedScan([]string{"two_col_table"}, twoColSchema)
518+
519+
return plan.NewInPredicateSubquery(
520+
[]expr.Expression{
521+
expr.NewPrimitiveLiteral(int32(42), false),
522+
expr.NewPrimitiveLiteral("test", false),
523+
},
524+
twoColRel,
525+
)
526+
}(),
527+
},
528+
{
529+
name: "SetPredicateSubquery_EXISTS",
530+
subExpr: plan.NewSetPredicateSubquery(
531+
proto.Expression_Subquery_SetPredicate_PREDICATE_OP_EXISTS,
532+
mockRel,
533+
),
534+
},
535+
{
536+
name: "SetPredicateSubquery_UNIQUE",
537+
subExpr: plan.NewSetPredicateSubquery(
538+
proto.Expression_Subquery_SetPredicate_PREDICATE_OP_UNIQUE,
539+
mockRel,
540+
),
541+
},
542+
{
543+
name: "SetComparisonSubquery_ANY_EQ",
544+
subExpr: plan.NewSetComparisonSubquery(
545+
proto.Expression_Subquery_SetComparison_REDUCTION_OP_ANY,
546+
proto.Expression_Subquery_SetComparison_COMPARISON_OP_EQ,
547+
expr.NewPrimitiveLiteral(int32(42), false),
548+
mockRel,
549+
),
550+
},
551+
{
552+
name: "SetComparisonSubquery_ALL_GT",
553+
subExpr: plan.NewSetComparisonSubquery(
554+
proto.Expression_Subquery_SetComparison_REDUCTION_OP_ALL,
555+
proto.Expression_Subquery_SetComparison_COMPARISON_OP_GT,
556+
expr.NewPrimitiveLiteral(int32(100), false),
557+
mockRel,
558+
),
559+
},
560+
{
561+
name: "SetComparisonSubquery_ANY_NE",
562+
subExpr: plan.NewSetComparisonSubquery(
563+
proto.Expression_Subquery_SetComparison_REDUCTION_OP_ANY,
564+
proto.Expression_Subquery_SetComparison_COMPARISON_OP_NE,
565+
expr.NewPrimitiveLiteral(int32(0), false),
566+
mockRel,
567+
),
568+
},
569+
{
570+
name: "SetComparisonSubquery_ALL_LE",
571+
subExpr: plan.NewSetComparisonSubquery(
572+
proto.Expression_Subquery_SetComparison_REDUCTION_OP_ALL,
573+
proto.Expression_Subquery_SetComparison_COMPARISON_OP_LE,
574+
expr.NewPrimitiveLiteral(int32(50), false),
575+
mockRel,
576+
),
577+
},
578+
}
579+
580+
for _, tt := range tests {
581+
t.Run(tt.name, func(t *testing.T) {
582+
// Convert expression to protobuf
583+
protoExpr := tt.subExpr.ToProto()
584+
require.NotNil(t, protoExpr)
585+
require.NotNil(t, protoExpr.GetSubquery())
586+
587+
// Convert back from protobuf using ExprFromProto with subquery handler
588+
baseReg.SetSubqueryConverter(subqueryReg)
589+
fromProto, err := expr.ExprFromProto(protoExpr, baseSchema, baseReg)
590+
require.NoError(t, err)
591+
require.NotNil(t, fromProto)
592+
593+
// Verify that we got the right type of subquery back
594+
switch tt.subExpr.(type) {
595+
case *plan.ScalarSubquery:
596+
assert.IsType(t, &plan.ScalarSubquery{}, fromProto)
597+
case *plan.InPredicateSubquery:
598+
assert.IsType(t, &plan.InPredicateSubquery{}, fromProto)
599+
case *plan.SetPredicateSubquery:
600+
assert.IsType(t, &plan.SetPredicateSubquery{}, fromProto)
601+
case *plan.SetComparisonSubquery:
602+
assert.IsType(t, &plan.SetComparisonSubquery{}, fromProto)
603+
}
604+
605+
// Verify protobuf roundtrip
606+
roundtripProto := fromProto.ToProto()
607+
assert.True(t, pb.Equal(protoExpr, roundtripProto), "protobuf roundtrip failed")
608+
609+
// Verify basic properties
610+
assert.Equal(t, tt.subExpr.IsScalar(), fromProto.IsScalar())
611+
assert.True(t, tt.subExpr.GetType().Equals(fromProto.GetType()))
612+
613+
// Note: We don't test Equals() here because the current implementation
614+
// of isRelEqual() only does pointer equality, so relations created from
615+
// protobuf will never be equal to the original relations, even if they
616+
// have identical content. This is a known limitation noted in the TODO
617+
// comment in plan/subquery.go
618+
})
619+
}
620+
}

expr/field_reference.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@ import (
1212
"golang.org/x/exp/slices"
1313
)
1414

15+
// RootRefType is a marker interface for types that can be used as a Root
16+
// reference in a FieldReference.
17+
//
18+
// A field reference is composed of two parts: a Root reference, which is the
19+
// output of an expression in this relation or a previous one, and a
20+
// ReferenceSegment or MaskedExpression, which allows referencing data within
21+
// that data structure - e.g. a field in a struct, or a value in a list or map.
1522
type RootRefType interface {
1623
isRootRef()
1724
}

expr/utils.go

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,41 @@
22

33
package expr
44

5-
import "github.com/substrait-io/substrait-go/v4/extensions"
5+
import (
6+
"github.com/substrait-io/substrait-go/v4/extensions"
7+
"github.com/substrait-io/substrait-go/v4/types"
8+
proto "github.com/substrait-io/substrait-protobuf/go/substraitpb"
9+
)
610

11+
// ExtensionRegistry provides functionality to resolve extension references and handle subquery expressions.
12+
// It combines an extensions.Set for looking up extension definitions with a Collection for extension metadata.
713
type ExtensionRegistry struct {
814
extensions.Set
915
c *extensions.Collection
16+
17+
// subqueryConverter is injected by the plan package to handle subquery expressions
18+
// TODO: We may want to consider refactoring to make a cleaner interface here
19+
subqueryConverter
20+
}
21+
22+
// subqueryConverter converts subqueries and the Relations within from the native
23+
// protobuf format into an Expression.
24+
//
25+
// This interface is private to avoid exposing the dependency cycle - a Subquery
26+
// contains a Plan, so the implementor of this has to exist in / import the plan
27+
// package, which we can't do here without creating a cycle with the expr
28+
// package.
29+
//
30+
// TODO: We may want to refactor this interface to be more generic or use a
31+
// different approach to avoid the cycle.
32+
type subqueryConverter interface {
33+
SubqueryFromProto(sub *proto.Expression_Subquery, baseSchema *types.RecordType, reg ExtensionRegistry) (Expression, error)
34+
}
35+
36+
// SetSubqueryConverter allows the plan package to inject a subquery converter.
37+
// This is an internal function used to break the dependency cycle between expr and plan packages.
38+
func (e *ExtensionRegistry) SetSubqueryConverter(converter subqueryConverter) {
39+
e.subqueryConverter = converter
1040
}
1141

1242
// NewExtensionRegistry creates a new registry. If you have an existing plan you can use GetExtensionSet() to

plan/builders.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,23 @@ type Builder interface {
149149
// GetRelBuilder returns an expr.RelBuilder that can be used to construct
150150
// relations which need multiple stages to build them.
151151
GetRelBuilder() *RelBuilder
152+
153+
// Subquery expression builder methods
154+
155+
// InPredicateSubquery creates an IN predicate subquery expression that checks
156+
// if the needles (left expressions) are contained in the haystack (right subquery).
157+
InPredicateSubquery(needles []expr.Expression, haystack Rel) (*InPredicateSubquery, error)
158+
159+
// SetPredicateSubquery creates a set predicate subquery expression that checks
160+
// if the subquery returns any rows.
161+
SetPredicateSubquery(input Rel, predicateOp SetPredicateOp) (*SetPredicateSubquery, error)
162+
163+
// ScalarSubquery creates a scalar subquery expression that returns a single value.
164+
ScalarSubquery(input Rel) (*ScalarSubquery, error)
165+
166+
// SetComparisonSubquery creates a set comparison subquery expression that checks
167+
// if the left expression is contained in the right subquery.
168+
SetComparisonSubquery(left expr.Expression, right Rel, reductionOp SetComparisonReductionOp, comparisonOp SetComparisonComparisonOp) (*SetComparisonSubquery, error)
152169
}
153170

154171
const FETCH_COUNT_ALL_RECORDS = -1
@@ -870,3 +887,87 @@ func (arb *AggregateRelBuilder) validate() error {
870887

871888
return nil
872889
}
890+
891+
func (b *builder) InPredicateSubquery(needles []expr.Expression, haystack Rel) (*InPredicateSubquery, error) {
892+
if haystack == nil {
893+
return nil, errNilInputRel
894+
}
895+
896+
if len(needles) == 0 {
897+
return nil, fmt.Errorf("%w: IN predicate subquery must have at least one needle expression",
898+
substraitgo.ErrInvalidExpr)
899+
}
900+
901+
for i, needle := range needles {
902+
if needle == nil {
903+
return nil, fmt.Errorf("%w: needle expression %d cannot be nil",
904+
substraitgo.ErrInvalidExpr, i)
905+
}
906+
}
907+
908+
// Validate that the number of needle expressions matches the number of columns in the haystack
909+
haystackSchema := haystack.RecordType()
910+
if len(needles) != int(haystackSchema.FieldCount()) {
911+
return nil, fmt.Errorf("%w: number of needle expressions (%d) must match number of columns in haystack (%d)",
912+
substraitgo.ErrInvalidExpr, len(needles), haystackSchema.FieldCount())
913+
}
914+
915+
return NewInPredicateSubquery(needles, haystack), nil
916+
}
917+
918+
// SetPredicateSubquery creates a subquery that tests for the existence or uniqueness of rows
919+
// in the input relation.
920+
func (b *builder) SetPredicateSubquery(input Rel, predicateOp SetPredicateOp) (*SetPredicateSubquery, error) {
921+
if input == nil {
922+
return nil, errNilInputRel
923+
}
924+
925+
if predicateOp == SetPredicateOpUnspecified {
926+
return nil, fmt.Errorf("predicateOp must be specified")
927+
}
928+
929+
return NewSetPredicateSubquery(
930+
predicateOp,
931+
input,
932+
), nil
933+
}
934+
935+
func (b *builder) ScalarSubquery(input Rel) (*ScalarSubquery, error) {
936+
if input == nil {
937+
return nil, errNilInputRel
938+
}
939+
940+
return NewScalarSubquery(input), nil
941+
}
942+
943+
// SetComparisonSubquery creates a subquery that compares a single expression against
944+
// a set of values from a relation using ANY or ALL operations with comparison operators.
945+
// The reductionOp determines whether to use ANY or ALL semantics, and the comparisonOp
946+
// specifies the comparison operator (e.g., =, !=, <, >, <=, >=).
947+
func (b *builder) SetComparisonSubquery(
948+
left expr.Expression,
949+
right Rel,
950+
reductionOp SetComparisonReductionOp,
951+
comparisonOp SetComparisonComparisonOp,
952+
) (*SetComparisonSubquery, error) {
953+
if reductionOp == SetComparisonReductionOpUnspecified {
954+
return nil, fmt.Errorf("reductionOp must be specified")
955+
}
956+
if comparisonOp == SetComparisonComparisonOpUnspecified {
957+
return nil, fmt.Errorf("comparisonOp must be specified")
958+
}
959+
960+
if left == nil {
961+
return nil, errNilInputRel
962+
}
963+
if right == nil {
964+
return nil, errNilInputRel
965+
}
966+
967+
return NewSetComparisonSubquery(
968+
reductionOp,
969+
comparisonOp,
970+
left,
971+
right,
972+
), nil
973+
}

0 commit comments

Comments
 (0)