Skip to content

Commit fd45ef9

Browse files
authored
feat: enable handling of URNs alongside URIs (#166)
BREAKING CHANGE: GetExtensionSet now consumes Collection BREAKING CHANGE: GetExtensionSet now returns an error BREAKING CHANGE: Set FindURI is now FindURN BREAKING CHANGE: Set ToProto now consumes Collection BREAKING CHANGE: Set ToProto now returns additional []SimpleExtensionURN BREAKING CHANGE: TopLevel interface has new GetExtensionUrns method
1 parent 1803339 commit fd45ef9

33 files changed

+1460
-437
lines changed

errors.go

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@ package substraitgo
55
import "errors"
66

77
var (
8-
ErrNotImplemented = errors.New("not implemented")
9-
ErrInvalidType = errors.New("invalid type")
10-
ErrInvalidExpr = errors.New("invalid expression")
11-
ErrNotFound = errors.New("not found")
12-
ErrKeyExists = errors.New("key already exists")
13-
ErrInvalidRel = errors.New("invalid relation")
14-
ErrInvalidArg = errors.New("invalid argument")
15-
ErrInvalidInputCount = errors.New("invalid input count")
16-
ErrInvalidDialect = errors.New("invalid dialect")
8+
ErrNotImplemented = errors.New("not implemented")
9+
ErrInvalidType = errors.New("invalid type")
10+
ErrInvalidExpr = errors.New("invalid expression")
11+
ErrNotFound = errors.New("not found")
12+
ErrKeyExists = errors.New("key already exists")
13+
ErrInvalidRel = errors.New("invalid relation")
14+
ErrInvalidArg = errors.New("invalid argument")
15+
ErrInvalidInputCount = errors.New("invalid input count")
16+
ErrInvalidDialect = errors.New("invalid dialect")
17+
ErrInvalidSimpleExtention = errors.New("invalid simple extension")
18+
ErrInvalidPlan = errors.New("invalid plan")
19+
ErrExtensionURINotResolvable = errors.New("extension URI not resolvable")
1720
)

expr/binding_test.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,33 +14,33 @@ import (
1414
var (
1515
extReg = NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError())
1616
uPointRef = extReg.GetTypeAnchor(extensions.ID{
17-
URI: extensions.SubstraitDefaultURIPrefix + "extension_types.yaml",
17+
URN: extensions.SubstraitDefaultURNPrefix + "extension_types",
1818
Name: "point",
1919
})
2020

2121
subID = extensions.ID{
22-
URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml",
22+
URN: extensions.SubstraitDefaultURNPrefix + "functions_arithmetic",
2323
Name: "subtract"}
2424
addID = extensions.ID{
25-
URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml",
25+
URN: extensions.SubstraitDefaultURNPrefix + "functions_arithmetic",
2626
Name: "add"}
2727
indexInID = extensions.ID{
28-
URI: extensions.SubstraitDefaultURIPrefix + "functions_set.yaml",
28+
URN: extensions.SubstraitDefaultURNPrefix + "functions_set",
2929
Name: "index_in"}
3030
rankID = extensions.ID{
31-
URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml",
31+
URN: extensions.SubstraitDefaultURNPrefix + "functions_arithmetic",
3232
Name: "rank"}
3333
firstValueID = extensions.ID{
34-
URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml",
34+
URN: extensions.SubstraitDefaultURNPrefix + "functions_arithmetic",
3535
Name: "first_value"}
3636
extractID = extensions.ID{
37-
URI: extensions.SubstraitDefaultURIPrefix + "functions_datetime.yaml",
37+
URN: extensions.SubstraitDefaultURNPrefix + "functions_datetime",
3838
Name: "extract"}
3939
ntileID = extensions.ID{
40-
URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml",
40+
URN: extensions.SubstraitDefaultURNPrefix + "functions_arithmetic",
4141
Name: "ntile"}
4242
sumID = extensions.ID{
43-
URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml",
43+
URN: extensions.SubstraitDefaultURNPrefix + "functions_arithmetic",
4444
Name: "sum"}
4545

4646
boringSchema = types.NamedStruct{

expr/builder_test.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func TestExprBuilder(t *testing.T) {
6666
b.ScalarFunc(subID).Args(b.RootRef(expr.NewStructFieldRef(3)),
6767
b.Wrap(expr.NewLiteral(int32(3), false))), ""},
6868
{"window func", "",
69-
b.WindowFunc(rankID), "invalid expression: non-decomposable window or agg function '{https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml rank}' must use InitialToResult phase"},
69+
b.WindowFunc(rankID), "invalid expression: non-decomposable window or agg function '{extension:io.substrait:functions_arithmetic rank}' must use InitialToResult phase"},
7070
{"window func", "rank(; phase: AGGREGATION_PHASE_INITIAL_TO_RESULT, invocation: AGGREGATION_INVOCATION_UNSPECIFIED) => i64?",
7171
b.WindowFunc(rankID).Phase(types.AggPhaseInitialToResult), ""},
7272
{"window func",
@@ -118,6 +118,7 @@ func TestExprBuilder(t *testing.T) {
118118
func TestCustomTypesInFunctionOutput(t *testing.T) {
119119
custom := `%YAML 1.2
120120
---
121+
urn: extension:test:custom
121122
types:
122123
- name: custom_type1
123124
- name: custom_type2
@@ -159,9 +160,9 @@ window_functions:
159160

160161
planBuilder := plan.NewBuilder(&collection)
161162

162-
customType1 := planBuilder.UserDefinedType("custom", "custom_type1")
163-
customType2 := planBuilder.UserDefinedType("custom", "custom_type2")
164-
customType3 := planBuilder.UserDefinedType("custom", "custom_type3")
163+
customType1 := planBuilder.UserDefinedType("extension:test:custom", "custom_type1")
164+
customType2 := planBuilder.UserDefinedType("extension:test:custom", "custom_type2")
165+
customType3 := planBuilder.UserDefinedType("extension:test:custom", "custom_type3")
165166

166167
anyVal, err := anypb.New(expr.NewPrimitiveLiteral("foo", false).ToProto())
167168
require.NoError(t, err)
@@ -173,7 +174,7 @@ window_functions:
173174

174175
// check scalar function
175176
scalar, err := planBuilder.GetExprBuilder().ScalarFunc(extensions.ID{
176-
URI: "custom",
177+
URN: "extension:test:custom",
177178
Name: "custom_function",
178179
}).Args(
179180
customLiteral,
@@ -188,7 +189,7 @@ window_functions:
188189

189190
// check aggregate function
190191
aggr, err := planBuilder.GetExprBuilder().AggFunc(extensions.ID{
191-
URI: "custom",
192+
URN: "extension:test:custom",
192193
Name: "custom_aggr",
193194
}).Args(
194195
customLiteral,
@@ -202,7 +203,7 @@ window_functions:
202203

203204
// check window function
204205
window, err := planBuilder.GetExprBuilder().WindowFunc(extensions.ID{
205-
URI: "custom",
206+
URN: "extension:test:custom",
206207
Name: "custom_window",
207208
}).Args(
208209
customLiteral,

expr/expression.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ type Expression interface {
373373
// // it's a pre-order traversal
374374
// if f, ok := e.(*ScalarFunction); ok {
375375
// return &ScalarFunction{
376-
// ID: ExtID{URI: "some other uri", Name: "some other func"},
376+
// ID: ExtID{URN: "some other urn", Name: "some other func"},
377377
// Args: f.Args,
378378
// Options: f.Options,
379379
// OutputType: f.OutputType,
@@ -1565,14 +1565,19 @@ type Extended struct {
15651565
BaseSchema types.NamedStruct
15661566
AdvancedExts *extensions.AdvancedExtension
15671567
ExpectedTypeURLs []string
1568+
1569+
reg ExtensionRegistry
15681570
}
15691571

15701572
func ExtendedFromProto(ex *proto.ExtendedExpression, c *extensions.Collection) (*Extended, error) {
1573+
extSet, err := extensions.GetExtensionSet(ex, c)
1574+
if err != nil {
1575+
return nil, err
1576+
}
15711577
var (
1572-
base = types.NewNamedStructFromProto(ex.BaseSchema)
1573-
extSet = extensions.GetExtensionSet(ex)
1574-
reg = NewExtensionRegistry(extSet, c)
1575-
refs = make([]ExpressionReference, len(ex.ReferredExpr))
1578+
base = types.NewNamedStructFromProto(ex.BaseSchema)
1579+
reg = NewExtensionRegistry(extSet, c)
1580+
refs = make([]ExpressionReference, len(ex.ReferredExpr))
15761581
)
15771582

15781583
for i, r := range ex.ReferredExpr {
@@ -1602,18 +1607,20 @@ func ExtendedFromProto(ex *proto.ExtendedExpression, c *extensions.Collection) (
16021607
BaseSchema: base,
16031608
AdvancedExts: ex.AdvancedExtensions,
16041609
ExpectedTypeURLs: ex.ExpectedTypeUrls,
1610+
reg: reg,
16051611
}, nil
16061612
}
16071613

16081614
func (ex *Extended) ToProto() *proto.ExtendedExpression {
1609-
uris, decls := ex.Extensions.ToProto()
1615+
urns, uris, decls := ex.reg.ExtensionsToProto()
16101616
refs := make([]*proto.ExpressionReference, len(ex.ReferredExpr))
16111617
for i, ref := range ex.ReferredExpr {
16121618
refs[i] = ref.ToProto()
16131619
}
16141620

16151621
return &proto.ExtendedExpression{
16161622
Version: ex.Version,
1623+
ExtensionUrns: urns,
16171624
ExtensionUris: uris,
16181625
Extensions: decls,
16191626
BaseSchema: ex.BaseSchema.ToProto(),

0 commit comments

Comments
 (0)