Skip to content

Commit aa68de1

Browse files
authored
feat: add SetAdvancedExtension method to Rel interface (#138)
1 parent 15c6916 commit aa68de1

File tree

4 files changed

+164
-0
lines changed

4 files changed

+164
-0
lines changed

plan/common.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ func (rc *RelCommon) GetAdvancedExtension() *extensions.AdvancedExtension {
6565
return rc.advExtension
6666
}
6767

68+
func (rc *RelCommon) SetAdvancedExtension(advExtension *extensions.AdvancedExtension) *extensions.AdvancedExtension {
69+
existing := rc.advExtension
70+
rc.advExtension = advExtension
71+
return existing
72+
}
73+
6874
func (rc *RelCommon) Hint() *Hint {
6975
return rc.hint
7076
}

plan/plan.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,9 @@ type Rel interface {
323323
RecordType() types.RecordType
324324

325325
GetAdvancedExtension() *extensions.AdvancedExtension
326+
// SetAdvancedExtension sets an AdvancedExtension on this Rel, returning any existing one on this Rel. Use `nil` to remove any existing AdvancedExtension.
327+
SetAdvancedExtension(extension *extensions.AdvancedExtension) (existing *extensions.AdvancedExtension)
328+
326329
ToProto() *proto.Rel
327330
ToProtoPlanRel() *proto.PlanRel
328331

plan/relations.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,11 @@ func (b *baseReadRel) Filter() expr.Expression { ret
104104
func (b *baseReadRel) BestEffortFilter() expr.Expression { return b.bestEffortFilter }
105105
func (b *baseReadRel) Projection() *expr.MaskExpression { return b.projection }
106106
func (b *baseReadRel) GetAdvancedExtension() *extensions.AdvancedExtension { return b.advExtension }
107+
func (b *baseReadRel) SetAdvancedExtension(advExtension *extensions.AdvancedExtension) *extensions.AdvancedExtension {
108+
existing := b.advExtension
109+
b.advExtension = advExtension
110+
return existing
111+
}
107112

108113
func (b *baseReadRel) SetProjection(p *expr.MaskExpression) {
109114
b.projection = p
@@ -632,6 +637,12 @@ func (lf *LocalFileReadRel) GetAdvancedExtension() *extensions.AdvancedExtension
632637
return lf.advExtension
633638
}
634639

640+
func (lf *LocalFileReadRel) SetAdvancedExtension(advExtension *extensions.AdvancedExtension) *extensions.AdvancedExtension {
641+
existing := lf.advExtension
642+
lf.advExtension = advExtension
643+
return existing
644+
}
645+
635646
func (lf *LocalFileReadRel) ToProto() *proto.Rel {
636647
items := make([]*proto.ReadRel_LocalFiles_FileOrFiles, len(lf.items))
637648
for i, f := range lf.items {
@@ -710,6 +721,11 @@ func (p *ProjectRel) Expressions() []expr.Expression { return p.exprs }
710721
func (p *ProjectRel) GetAdvancedExtension() *extensions.AdvancedExtension {
711722
return p.advExtension
712723
}
724+
func (p *ProjectRel) SetAdvancedExtension(advExtension *extensions.AdvancedExtension) *extensions.AdvancedExtension {
725+
existing := p.advExtension
726+
p.advExtension = advExtension
727+
return existing
728+
}
713729

714730
func (p *ProjectRel) ToProto() *proto.Rel {
715731
exprs := make([]*proto.Expression, len(p.exprs))
@@ -862,6 +878,11 @@ func (j *JoinRel) Type() JoinType { return j.joinType }
862878
func (j *JoinRel) GetAdvancedExtension() *extensions.AdvancedExtension {
863879
return j.advExtension
864880
}
881+
func (j *JoinRel) SetAdvancedExtension(advExtension *extensions.AdvancedExtension) *extensions.AdvancedExtension {
882+
existing := j.advExtension
883+
j.advExtension = advExtension
884+
return existing
885+
}
865886

866887
func (j *JoinRel) ToProto() *proto.Rel {
867888
outRel := &proto.JoinRel{
@@ -949,6 +970,11 @@ func (c *CrossRel) Right() Rel { return c.right }
949970
func (c *CrossRel) GetAdvancedExtension() *extensions.AdvancedExtension {
950971
return c.advExtension
951972
}
973+
func (c *CrossRel) SetAdvancedExtension(advExtension *extensions.AdvancedExtension) *extensions.AdvancedExtension {
974+
existing := c.advExtension
975+
c.advExtension = advExtension
976+
return existing
977+
}
952978

953979
func (c *CrossRel) ToProto() *proto.Rel {
954980
return &proto.Rel{
@@ -1015,6 +1041,11 @@ func (f *FetchRel) Count() int64 { return f.count }
10151041
func (f *FetchRel) GetAdvancedExtension() *extensions.AdvancedExtension {
10161042
return f.advExtension
10171043
}
1044+
func (f *FetchRel) SetAdvancedExtension(advExtension *extensions.AdvancedExtension) *extensions.AdvancedExtension {
1045+
existing := f.advExtension
1046+
f.advExtension = advExtension
1047+
return existing
1048+
}
10181049

10191050
func (f *FetchRel) ToProto() *proto.Rel {
10201051
return &proto.Rel{
@@ -1131,6 +1162,11 @@ func (ar *AggregateRel) Measures() []AggRelMeasure { return ar.meas
11311162
func (ar *AggregateRel) GetAdvancedExtension() *extensions.AdvancedExtension {
11321163
return ar.advExtension
11331164
}
1165+
func (ar *AggregateRel) SetAdvancedExtension(advExtension *extensions.AdvancedExtension) *extensions.AdvancedExtension {
1166+
existing := ar.advExtension
1167+
ar.advExtension = advExtension
1168+
return existing
1169+
}
11341170

11351171
func (ar *AggregateRel) ToProto() *proto.Rel {
11361172
groupingExpressionsProto := make([]*proto.Expression, len(ar.groupingExpressions))
@@ -1297,6 +1333,11 @@ func (sr *SortRel) Sorts() []expr.SortField { return sr.sorts }
12971333
func (sr *SortRel) GetAdvancedExtension() *extensions.AdvancedExtension {
12981334
return sr.advExtension
12991335
}
1336+
func (sr *SortRel) SetAdvancedExtension(advExtension *extensions.AdvancedExtension) *extensions.AdvancedExtension {
1337+
existing := sr.advExtension
1338+
sr.advExtension = advExtension
1339+
return existing
1340+
}
13001341

13011342
func (sr *SortRel) ToProto() *proto.Rel {
13021343
sorts := make([]*proto.SortField, len(sr.sorts))
@@ -1382,6 +1423,11 @@ func (fr *FilterRel) Condition() expr.Expression { return fr.cond }
13821423
func (fr *FilterRel) GetAdvancedExtension() *extensions.AdvancedExtension {
13831424
return fr.advExtension
13841425
}
1426+
func (fr *FilterRel) SetAdvancedExtension(advExtension *extensions.AdvancedExtension) *extensions.AdvancedExtension {
1427+
existing := fr.advExtension
1428+
fr.advExtension = advExtension
1429+
return existing
1430+
}
13851431

13861432
func (fr *FilterRel) ToProto() *proto.Rel {
13871433
return &proto.Rel{
@@ -1468,6 +1514,11 @@ func (s *SetRel) Op() SetOp { return s.op }
14681514
func (s *SetRel) GetAdvancedExtension() *extensions.AdvancedExtension {
14691515
return s.advExtension
14701516
}
1517+
func (s *SetRel) SetAdvancedExtension(advExtension *extensions.AdvancedExtension) *extensions.AdvancedExtension {
1518+
existing := s.advExtension
1519+
s.advExtension = advExtension
1520+
return existing
1521+
}
14711522

14721523
func (s *SetRel) ToProto() *proto.Rel {
14731524
inputs := make([]*proto.Rel, len(s.inputs))
@@ -1733,6 +1784,11 @@ func (hr *HashJoinRel) Type() HashMergeJoinType { return hr.joinType }
17331784
func (hr *HashJoinRel) GetAdvancedExtension() *extensions.AdvancedExtension {
17341785
return hr.advExtension
17351786
}
1787+
func (hr *HashJoinRel) SetAdvancedExtension(advExtension *extensions.AdvancedExtension) *extensions.AdvancedExtension {
1788+
existing := hr.advExtension
1789+
hr.advExtension = advExtension
1790+
return existing
1791+
}
17361792

17371793
func (hr *HashJoinRel) ToProto() *proto.Rel {
17381794
keysLeft := make([]*proto.Expression_FieldReference, len(hr.leftKeys))
@@ -1841,6 +1897,11 @@ func (mr *MergeJoinRel) Type() HashMergeJoinType { return mr.joinType }
18411897
func (mr *MergeJoinRel) GetAdvancedExtension() *extensions.AdvancedExtension {
18421898
return mr.advExtension
18431899
}
1900+
func (mr *MergeJoinRel) SetAdvancedExtension(advExtension *extensions.AdvancedExtension) *extensions.AdvancedExtension {
1901+
existing := mr.advExtension
1902+
mr.advExtension = advExtension
1903+
return existing
1904+
}
18441905

18451906
func (mr *MergeJoinRel) ToProto() *proto.Rel {
18461907
keysLeft := make([]*proto.Expression_FieldReference, len(mr.leftKeys))

plan/relations_test.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/substrait-io/substrait-go/v4/extensions"
1010
"github.com/substrait-io/substrait-go/v4/types"
1111
proto "github.com/substrait-io/substrait-protobuf/go/substraitpb"
12+
"google.golang.org/protobuf/types/known/anypb"
1213
)
1314

1415
func noOpRewrite(e expr.Expression) (expr.Expression, error) {
@@ -413,6 +414,99 @@ func TestRelations_Copy(t *testing.T) {
413414
}
414415
}
415416

417+
func TestRelations_AdvancedExtensions(t *testing.T) {
418+
extReg := expr.NewExtensionRegistry(extensions.NewSet(), extensions.GetDefaultCollectionWithNoError())
419+
aggregateFnID := extensions.ID{
420+
URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml",
421+
Name: "avg",
422+
}
423+
aggregateFn, err := expr.NewAggregateFunc(extReg,
424+
aggregateFnID, nil, types.AggInvocationAll,
425+
types.AggPhaseInitialToResult, nil, createPrimitiveFloat(1.0))
426+
require.NoError(t, err)
427+
428+
aggregateRel := &AggregateRel{input: createVirtualTableReadRel(1),
429+
groupingExpressions: []expr.Expression{createPrimitiveFloat(1.0)},
430+
groupingReferences: [][]uint32{{0}},
431+
measures: []AggRelMeasure{{measure: aggregateFn, filter: expr.NewPrimitiveLiteral(false, false)}}}
432+
crossRel := &CrossRel{left: createVirtualTableReadRel(1), right: createVirtualTableReadRel(2)}
433+
extensionLeafRel := &ExtensionLeafRel{}
434+
extensionMultiRel := &ExtensionMultiRel{inputs: []Rel{createVirtualTableReadRel(1), createVirtualTableReadRel(2)}}
435+
fetchRel := &FetchRel{input: createVirtualTableReadRel(1), offset: 1, count: 2}
436+
filterRel := &FilterRel{input: createVirtualTableReadRel(1), cond: expr.NewPrimitiveLiteral(true, false)}
437+
hashJoinRel := &HashJoinRel{left: createVirtualTableReadRel(1), right: createVirtualTableReadRel(2), joinType: HashMergeInner, leftKeys: []*expr.FieldReference{}, rightKeys: []*expr.FieldReference{}, postJoinFilter: expr.NewPrimitiveLiteral(true, false)}
438+
joinRel := &JoinRel{left: createVirtualTableReadRel(1), right: createVirtualTableReadRel(2), joinType: JoinTypeInner, expr: expr.NewPrimitiveLiteral(true, false), postJoinFilter: expr.NewPrimitiveLiteral(true, false)}
439+
localFileReadRel := &LocalFileReadRel{items: []FileOrFiles{{Path: "path"}}, baseReadRel: baseReadRel{filter: expr.NewPrimitiveLiteral(true, false), bestEffortFilter: expr.NewPrimitiveLiteral(true, false)}}
440+
mergeJoinRel := &MergeJoinRel{left: createVirtualTableReadRel(1), right: createVirtualTableReadRel(2), joinType: HashMergeInner, leftKeys: []*expr.FieldReference{}, rightKeys: []*expr.FieldReference{}, postJoinFilter: expr.NewPrimitiveLiteral(true, false)}
441+
namedTableReadRel := &NamedTableReadRel{names: []string{"mytest"}, baseReadRel: baseReadRel{filter: expr.NewPrimitiveLiteral(true, false), bestEffortFilter: expr.NewPrimitiveLiteral(true, false)}}
442+
projectRel := &ProjectRel{input: createVirtualTableReadRel(1), exprs: []expr.Expression{createPrimitiveFloat(1.0), createPrimitiveFloat(2.0)}}
443+
setRel := &SetRel{inputs: []Rel{createVirtualTableReadRel(1), createVirtualTableReadRel(2), createVirtualTableReadRel(3)}, op: SetOpUnionAll}
444+
sortRel := &SortRel{input: createVirtualTableReadRel(1), sorts: []expr.SortField{{Expr: createPrimitiveFloat(1.0), Kind: types.SortAscNullsFirst}}}
445+
virtualTableReadRel := &VirtualTableReadRel{values: []expr.VirtualTableExpressionValue{[]expr.Expression{&expr.PrimitiveLiteral[int64]{Value: 1}}}}
446+
namedTableWriteRel := &NamedTableWriteRel{input: namedTableReadRel}
447+
icebergTableReadRel := &IcebergTableReadRel{
448+
baseReadRel: baseReadRel{filter: expr.NewPrimitiveLiteral(true, false)},
449+
tableType: &Direct{
450+
MetadataUri: "s3://bucket/path/to/metadata.json",
451+
},
452+
}
453+
454+
relations := []Rel{
455+
aggregateRel,
456+
crossRel,
457+
extensionLeafRel,
458+
extensionMultiRel,
459+
fetchRel,
460+
filterRel,
461+
hashJoinRel,
462+
joinRel,
463+
localFileReadRel,
464+
mergeJoinRel,
465+
namedTableReadRel,
466+
projectRel,
467+
setRel,
468+
sortRel,
469+
virtualTableReadRel,
470+
namedTableWriteRel,
471+
icebergTableReadRel,
472+
}
473+
474+
val1, err := anypb.New(expr.NewPrimitiveLiteral("foo", false).ToProto())
475+
assert.NoError(t, err)
476+
477+
exampleAdvancedExtension1 := &extensions.AdvancedExtension{
478+
Optimization: []*anypb.Any{val1},
479+
Enhancement: val1,
480+
}
481+
482+
val2, err := anypb.New(expr.NewPrimitiveLiteral("bar", false).ToProto())
483+
assert.NoError(t, err)
484+
485+
exampleAdvancedExtension2 := &extensions.AdvancedExtension{
486+
Optimization: []*anypb.Any{val2},
487+
Enhancement: val2,
488+
}
489+
490+
for _, relation := range relations {
491+
// setting an extension should return the old/existing extension
492+
// setting an extension for the first time means the old extension should be nil
493+
oldExtension := relation.SetAdvancedExtension(exampleAdvancedExtension1)
494+
assert.Nil(t, oldExtension)
495+
assert.Equal(t, exampleAdvancedExtension1, relation.GetAdvancedExtension())
496+
497+
// setting it again
498+
oldExtension = relation.SetAdvancedExtension(exampleAdvancedExtension2)
499+
assert.Equal(t, exampleAdvancedExtension1, oldExtension)
500+
assert.Equal(t, exampleAdvancedExtension2, relation.GetAdvancedExtension())
501+
502+
// setting it to nil
503+
oldExtension = relation.SetAdvancedExtension(nil)
504+
assert.Equal(t, exampleAdvancedExtension2, oldExtension)
505+
assert.Nil(t, relation.GetAdvancedExtension())
506+
507+
}
508+
}
509+
416510
func TestAggregateRelToBuilder(t *testing.T) {
417511
extReg := expr.NewExtensionRegistry(extensions.NewSet(), extensions.GetDefaultCollectionWithNoError())
418512
aggregateFnID := extensions.ID{

0 commit comments

Comments
 (0)