Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
522383a
Bump substrait version
benbellick Sep 15, 2025
c08a435
parse urn from file and throw error if missing
benbellick Sep 17, 2025
4877706
validate URN format
benbellick Sep 17, 2025
81e3cbb
fix tests by adding urn to all simple extensions
benbellick Sep 17, 2025
1a07238
add test to ensure URNs correctly required and format validated
benbellick Sep 17, 2025
87bcf93
Full URI -> URN migration
benbellick Sep 17, 2025
a409b51
ensure no duplicate URI/URN + test
benbellick Sep 18, 2025
a1396c7
Use bimap to encapsulate mapping between uri and urn
benbellick Sep 18, 2025
09b5d9a
on parse, resolve by URN then URI
benbellick Sep 18, 2025
a4f1a1a
alter ToProto method to generate both URIs AND URNs
benbellick Sep 18, 2025
2e1f08b
Add test to ensure both URI and URN captured in ToProto
benbellick Sep 18, 2025
7ce2347
fix broken tests by adding URI info to produced plans
benbellick Sep 18, 2025
69b4f2b
add one more test to ensure that plans w/ only uri vs urn are equivalent
benbellick Sep 18, 2025
c9e23b4
small cleanup of error handling
benbellick Sep 18, 2025
9869bdc
Add in a few more tests just to be safe
benbellick Sep 18, 2025
b03fdb1
Fix golangci-lint issues
benbellick Sep 18, 2025
72716ca
fix pr comments
benbellick Sep 18, 2025
9bf79bf
remove public Collection() function on ExtensionRegistry
benbellick Sep 22, 2025
27a7cc5
Improve uri/urn resolution strategy
benbellick Sep 23, 2025
0d5610e
compare ToProto result directly in tests
benbellick Sep 23, 2025
ac9cf33
Add tests to test full conditions of uri/urn resolution
benbellick Sep 23, 2025
c6bdd57
fix typo in error message: uri -> urn
benbellick Sep 25, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@ package substraitgo
import "errors"

var (
ErrNotImplemented = errors.New("not implemented")
ErrInvalidType = errors.New("invalid type")
ErrInvalidExpr = errors.New("invalid expression")
ErrNotFound = errors.New("not found")
ErrKeyExists = errors.New("key already exists")
ErrInvalidRel = errors.New("invalid relation")
ErrInvalidArg = errors.New("invalid argument")
ErrInvalidInputCount = errors.New("invalid input count")
ErrInvalidDialect = errors.New("invalid dialect")
ErrNotImplemented = errors.New("not implemented")
ErrInvalidType = errors.New("invalid type")
ErrInvalidExpr = errors.New("invalid expression")
ErrNotFound = errors.New("not found")
ErrKeyExists = errors.New("key already exists")
ErrInvalidRel = errors.New("invalid relation")
ErrInvalidArg = errors.New("invalid argument")
ErrInvalidInputCount = errors.New("invalid input count")
ErrInvalidDialect = errors.New("invalid dialect")
ErrInvalidSimpleExtention = errors.New("invalid simple extension")
ErrInvalidPlan = errors.New("invalid plan")
ErrExtensionURINotResolvable = errors.New("extension URI not resolvable")
)
18 changes: 9 additions & 9 deletions expr/binding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,33 @@ import (
var (
extReg = NewEmptyExtensionRegistry(extensions.GetDefaultCollectionWithNoError())
uPointRef = extReg.GetTypeAnchor(extensions.ID{
URI: extensions.SubstraitDefaultURIPrefix + "extension_types.yaml",
URN: extensions.SubstraitDefaultURNPrefix + "extension_types",
Name: "point",
})

subID = extensions.ID{
URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml",
URN: extensions.SubstraitDefaultURNPrefix + "functions_arithmetic",
Name: "subtract"}
addID = extensions.ID{
URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml",
URN: extensions.SubstraitDefaultURNPrefix + "functions_arithmetic",
Name: "add"}
indexInID = extensions.ID{
URI: extensions.SubstraitDefaultURIPrefix + "functions_set.yaml",
URN: extensions.SubstraitDefaultURNPrefix + "functions_set",
Name: "index_in"}
rankID = extensions.ID{
URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml",
URN: extensions.SubstraitDefaultURNPrefix + "functions_arithmetic",
Name: "rank"}
firstValueID = extensions.ID{
URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml",
URN: extensions.SubstraitDefaultURNPrefix + "functions_arithmetic",
Name: "first_value"}
extractID = extensions.ID{
URI: extensions.SubstraitDefaultURIPrefix + "functions_datetime.yaml",
URN: extensions.SubstraitDefaultURNPrefix + "functions_datetime",
Name: "extract"}
ntileID = extensions.ID{
URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml",
URN: extensions.SubstraitDefaultURNPrefix + "functions_arithmetic",
Name: "ntile"}
sumID = extensions.ID{
URI: extensions.SubstraitDefaultURIPrefix + "functions_arithmetic.yaml",
URN: extensions.SubstraitDefaultURNPrefix + "functions_arithmetic",
Name: "sum"}

boringSchema = types.NamedStruct{
Expand Down
15 changes: 8 additions & 7 deletions expr/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func TestExprBuilder(t *testing.T) {
b.ScalarFunc(subID).Args(b.RootRef(expr.NewStructFieldRef(3)),
b.Wrap(expr.NewLiteral(int32(3), false))), ""},
{"window func", "",
b.WindowFunc(rankID), "invalid expression: non-decomposable window or agg function '{https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml rank}' must use InitialToResult phase"},
b.WindowFunc(rankID), "invalid expression: non-decomposable window or agg function '{extension:io.substrait:functions_arithmetic rank}' must use InitialToResult phase"},
{"window func", "rank(; phase: AGGREGATION_PHASE_INITIAL_TO_RESULT, invocation: AGGREGATION_INVOCATION_UNSPECIFIED) => i64?",
b.WindowFunc(rankID).Phase(types.AggPhaseInitialToResult), ""},
{"window func",
Expand Down Expand Up @@ -118,6 +118,7 @@ func TestExprBuilder(t *testing.T) {
func TestCustomTypesInFunctionOutput(t *testing.T) {
custom := `%YAML 1.2
---
urn: extension:test:custom
types:
- name: custom_type1
- name: custom_type2
Expand Down Expand Up @@ -159,9 +160,9 @@ window_functions:

planBuilder := plan.NewBuilder(&collection)

customType1 := planBuilder.UserDefinedType("custom", "custom_type1")
customType2 := planBuilder.UserDefinedType("custom", "custom_type2")
customType3 := planBuilder.UserDefinedType("custom", "custom_type3")
customType1 := planBuilder.UserDefinedType("extension:test:custom", "custom_type1")
customType2 := planBuilder.UserDefinedType("extension:test:custom", "custom_type2")
customType3 := planBuilder.UserDefinedType("extension:test:custom", "custom_type3")

anyVal, err := anypb.New(expr.NewPrimitiveLiteral("foo", false).ToProto())
require.NoError(t, err)
Expand All @@ -173,7 +174,7 @@ window_functions:

// check scalar function
scalar, err := planBuilder.GetExprBuilder().ScalarFunc(extensions.ID{
URI: "custom",
URN: "extension:test:custom",
Name: "custom_function",
}).Args(
customLiteral,
Expand All @@ -188,7 +189,7 @@ window_functions:

// check aggregate function
aggr, err := planBuilder.GetExprBuilder().AggFunc(extensions.ID{
URI: "custom",
URN: "extension:test:custom",
Name: "custom_aggr",
}).Args(
customLiteral,
Expand All @@ -202,7 +203,7 @@ window_functions:

// check window function
window, err := planBuilder.GetExprBuilder().WindowFunc(extensions.ID{
URI: "custom",
URN: "extension:test:custom",
Name: "custom_window",
}).Args(
customLiteral,
Expand Down
19 changes: 13 additions & 6 deletions expr/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ type Expression interface {
// // it's a pre-order traversal
// if f, ok := e.(*ScalarFunction); ok {
// return &ScalarFunction{
// ID: ExtID{URI: "some other uri", Name: "some other func"},
// ID: ExtID{URN: "some other urn", Name: "some other func"},
// Args: f.Args,
// Options: f.Options,
// OutputType: f.OutputType,
Expand Down Expand Up @@ -1565,14 +1565,19 @@ type Extended struct {
BaseSchema types.NamedStruct
AdvancedExts *extensions.AdvancedExtension
ExpectedTypeURLs []string

reg ExtensionRegistry
}

func ExtendedFromProto(ex *proto.ExtendedExpression, c *extensions.Collection) (*Extended, error) {
extSet, err := extensions.GetExtensionSet(ex, c)
if err != nil {
return nil, err
}
var (
base = types.NewNamedStructFromProto(ex.BaseSchema)
extSet = extensions.GetExtensionSet(ex)
reg = NewExtensionRegistry(extSet, c)
refs = make([]ExpressionReference, len(ex.ReferredExpr))
base = types.NewNamedStructFromProto(ex.BaseSchema)
reg = NewExtensionRegistry(extSet, c)
refs = make([]ExpressionReference, len(ex.ReferredExpr))
)

for i, r := range ex.ReferredExpr {
Expand Down Expand Up @@ -1602,18 +1607,20 @@ func ExtendedFromProto(ex *proto.ExtendedExpression, c *extensions.Collection) (
BaseSchema: base,
AdvancedExts: ex.AdvancedExtensions,
ExpectedTypeURLs: ex.ExpectedTypeUrls,
reg: reg,
}, nil
}

func (ex *Extended) ToProto() *proto.ExtendedExpression {
uris, decls := ex.Extensions.ToProto()
urns, uris, decls := ex.Extensions.ToProto(ex.reg.Collection())
refs := make([]*proto.ExpressionReference, len(ex.ReferredExpr))
for i, ref := range ex.ReferredExpr {
refs[i] = ref.ToProto()
}

return &proto.ExtendedExpression{
Version: ex.Version,
ExtensionUrns: urns,
ExtensionUris: uris,
Extensions: decls,
BaseSchema: ex.BaseSchema.ToProto(),
Expand Down
Loading
Loading