Skip to content

Commit bd507d3

Browse files
authored
Expose the supported types through the local type registry. (#122)
1 parent d5d8a8f commit bd507d3

File tree

3 files changed

+56
-0
lines changed

3 files changed

+56
-0
lines changed

functions/registries.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ type LocalTypeRegistry interface {
2828

2929
//GetTypeClasses() []types.TypeClass // TODO
3030

31+
// GetSupportedTypes returns the types supported by this dialect.
32+
GetSupportedTypes() map[string]types.Type
33+
3134
// IsTypeSupportedInTables checks whether a particular type is supported in tables.
3235
// Some types (such as INTERVAL) may only be supported in literal contexts.
3336
IsTypeSupportedInTables(typ types.Type) bool
@@ -46,6 +49,10 @@ type Dialect interface {
4649
// the subset of types supported by this dialect. This will return an error if there are
4750
// types declared in the dialect that are not available within the provided registry.
4851
LocalizeTypeRegistry(registry TypeRegistry) (LocalTypeRegistry, error)
52+
53+
// GetLocalTypeRegistry returns the last created type registry using this dialect or constructs
54+
// one using LocalizeTypeRegistry and a default type registry if one hasn't yet been made.
55+
GetLocalTypeRegistry() (LocalTypeRegistry, error)
4956
}
5057

5158
type FunctionName interface {

functions/types.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,10 @@ func (t *localTypeRegistryImpl) GetLocalTypeFromSubstraitType(typ types.Type) (s
214214
return "", substraitgo.ErrNotFound
215215
}
216216

217+
func (t *localTypeRegistryImpl) GetSupportedTypes() map[string]types.Type {
218+
return t.localNameToType
219+
}
220+
217221
func (t *localTypeRegistryImpl) IsTypeSupportedInTables(typ types.Type) bool {
218222
if ti, ok := t.typeInfoMap[typ.ShortString()]; ok {
219223
return ti.supportedAsColumn

types/types_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@ package types_test
44

55
import (
66
"fmt"
7+
"strings"
78
"testing"
89
"time"
910

1011
"github.com/stretchr/testify/assert"
12+
"github.com/stretchr/testify/require"
13+
"github.com/substrait-io/substrait-go/v3/functions"
1114
. "github.com/substrait-io/substrait-go/v3/types"
1215
"github.com/substrait-io/substrait-go/v3/types/integer_parameters"
1316
)
@@ -497,3 +500,45 @@ func TestGetTimeValueByPrecision(t *testing.T) {
497500
})
498501
}
499502
}
503+
504+
func TestGetSupportedTypes(t *testing.T) {
505+
dialect, err := functions.LoadDialect("test_dialect",
506+
strings.NewReader(`
507+
name: test_dialect
508+
type: sql
509+
dependencies:
510+
arithmetic:
511+
https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml
512+
supported_types:
513+
fp64:
514+
sql_type_name: float
515+
bool:
516+
sql_type_name: boolean
517+
varchar:
518+
sql_type_name: varchar
519+
date:
520+
sql_type_name: date
521+
time:
522+
sql_type_name: time
523+
pts:
524+
sql_type_name: timestamp
525+
ptstz:
526+
sql_type_name: timestamptz
527+
dec:
528+
sql_type_name: numeric
529+
scalar_functions:
530+
- name: arithmetic.add
531+
local_name: '+'
532+
infix: true
533+
required_options:
534+
overflow: SILENT
535+
rounding: TIE_TO_EVEN
536+
supported_kernels:
537+
- fp64_fp64
538+
`))
539+
require.NoError(t, err)
540+
typeRegistry, err := dialect.GetLocalTypeRegistry()
541+
require.NoError(t, err)
542+
st := typeRegistry.GetSupportedTypes()
543+
assert.Len(t, st, 8)
544+
}

0 commit comments

Comments
 (0)