Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions dbscan/dbscan.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ type API struct {
allowUnknownColumns bool
// columnToIndexFieldMapCache stores a map of reflect.Type -> map[string][]int
columnToIndexFieldMapCache sync.Map
maxStructRecursionLevel int
}

// APIOption is a function type that changes API configuration.
Expand All @@ -73,10 +74,11 @@ type APIOption func(api *API)
// NewAPI creates a new API object with provided list of options.
func NewAPI(opts ...APIOption) (*API, error) {
api := &API{
structTagKey: "db",
columnSeparator: ".",
fieldMapperFn: SnakeCaseMapper,
allowUnknownColumns: false,
structTagKey: "db",
columnSeparator: ".",
fieldMapperFn: SnakeCaseMapper,
allowUnknownColumns: false,
maxStructRecursionLevel: 2,
}
for _, o := range opts {
o(api)
Expand Down Expand Up @@ -152,6 +154,14 @@ func WithAllowUnknownColumns(allowUnknownColumns bool) APIOption {
}
}

// WithMaxStructRecursionLevel limits recursion depth when traversing scanned types.
// The default is 2.
func WithMaxStructRecursionLevel(maxStructRecursionLevel int) APIOption {
return func(api *API) {
api.maxStructRecursionLevel = maxStructRecursionLevel
}
}

// StructTagKey returns the struct tag key used by the API.
func (api *API) StructTagKey() string {
return api.structTagKey
Expand Down
59 changes: 55 additions & 4 deletions dbscan/internal_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package dbscan

import (
"reflect"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
Expand All @@ -11,10 +13,10 @@ type queryRowsFunc func(t *testing.T, query string) Rows

func DoTestRowScannerStartCalledExactlyOnce(t *testing.T, api *API, queryRows queryRowsFunc) {
query := `
SELECT *
FROM (
VALUES ('foo val', 'bar val'), ('foo val 2', 'bar val 2'), ('foo val 3', 'bar val 3')
) AS t (foo, bar)
select *
from (
values ('foo val', 'bar val'), ('foo val 2', 'bar val 2'), ('foo val 3', 'bar val 3')
) as t (foo, bar)
`
rows := queryRows(t, query)
defer rows.Close() //nolint: errcheck
Expand Down Expand Up @@ -42,3 +44,52 @@ func DoTestRowScannerStartCalledExactlyOnce(t *testing.T, api *API, queryRows qu

mockStart.AssertNumberOfCalls(t, "Execute", 1)
}

func TestColumnToFieldIndexMap(t *testing.T) {
type Node struct {
ID string
Parent *Node
}

type User struct {
ID int
Name string
}

type UserNode struct {
User
*Node
CreatedBy *string
}

testAPILevel1, err := NewAPI(WithMaxStructRecursionLevel(1))

assert.NoError(t, err)

testAPILevel2, err := NewAPI(WithMaxStructRecursionLevel(2))

assert.NoError(t, err)

type testCase struct {
api *API
structType reflect.Type
expectedCols []string
}

testCases := []testCase{
{testAPILevel1, reflect.TypeOf(Node{}), []string{"id", "parent"}},
{testAPILevel1, reflect.TypeOf(User{}), []string{"id", "name"}},
{testAPILevel1, reflect.TypeOf(UserNode{}), []string{"id", "name", "created_by", "parent"}},
{testAPILevel2, reflect.TypeOf(Node{}), []string{"id", "parent", "parent.id", "parent.parent"}},
{testAPILevel2, reflect.TypeOf(User{}), []string{"id", "name"}},
{testAPILevel2, reflect.TypeOf(UserNode{}), []string{"id", "name", "created_by", "parent", "parent.id", "parent.parent"}},
}
for _, tc := range testCases {
colIdxMap := tc.api.buildColumnToFieldIndexMap(tc.structType)
assert.Len(t, colIdxMap, len(tc.expectedCols))
for _, col := range tc.expectedCols {
_, exist := colIdxMap[col]
assert.True(t, exist)
}
}
}
36 changes: 34 additions & 2 deletions dbscan/structref.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
)

type toTraverse struct {
Parents map[reflect.Type]int
Type reflect.Type
IndexPrefix []int
ColumnPrefix string
Expand All @@ -26,11 +27,19 @@ func (api *API) getColumnToFieldIndexMap(structType reflect.Type) map[string][]i
func (api *API) buildColumnToFieldIndexMap(structType reflect.Type) map[string][]int {
result := make(map[string][]int, structType.NumField())
var queue []*toTraverse
queue = append(queue, &toTraverse{Type: structType, IndexPrefix: nil, ColumnPrefix: ""})
queue = append(
queue,
&toTraverse{
Type: structType,
IndexPrefix: nil,
ColumnPrefix: "",
Parents: map[reflect.Type]int{structType: 0},
},
)
for len(queue) > 0 {
traversal := queue[0]
queue = queue[1:]
structType := traversal.Type
structType = traversal.Type
for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)

Expand Down Expand Up @@ -76,7 +85,30 @@ func (api *API) buildColumnToFieldIndexMap(structType reflect.Type) map[string][
columnPart = dbTag
}
columnPrefix := api.buildColumn(traversal.ColumnPrefix, columnPart)
parents := map[reflect.Type]int{childType: 0}

var abort bool

for k, v := range traversal.Parents {
if k == childType {
v++
}

if v >= api.maxStructRecursionLevel {
abort = true

break
}

parents[k] = v
}

if abort {
continue
}

queue = append(queue, &toTraverse{
Parents: parents,
Type: childType,
IndexPrefix: index,
ColumnPrefix: columnPrefix,
Expand Down