Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,22 @@ func Doc(node ast.Node) string {
return ""
}

// ResolutionInfo describes where a node resolved from an included file was
// found: the name of that file and its parsed program.
type ResolutionInfo struct {
	Filename string
	Program  *ast.Program
}

// Resolve resolves a named reference to its target node.
// For `include`d names, it also returns the filename and *ast.Program
// where the node was found.
//
// The target can either be in the current program's scope or it can refer to
// an included file using dot notation. Included files must exist in one of the
// given search directories.
func Resolve(name string, program *ast.Program, dirs []string) (ast.Node, error) {
func Resolve(name string, program *ast.Program, dirs []string, parseCache map[string]*ParseRes) (ast.Node, *ResolutionInfo, error) {
defs := program.Definitions

var rInfo *ResolutionInfo
if strings.Contains(name, ".") {
parts := strings.SplitN(name, ".", 2)
fname := parts[0] + ".thrift"
Expand All @@ -70,12 +78,13 @@ func Resolve(name string, program *ast.Program, dirs []string) (ast.Node, error)
}
}
if ipath == "" {
return nil, fmt.Errorf("missing \"include\" for type reference %q", name)
return nil, nil, fmt.Errorf("missing \"include\" for type reference %q", name)
}

program, _, err := ParseFile(ipath, dirs)
program, _, f, err := ParseFile(ipath, dirs, parseCache)
rInfo = &ResolutionInfo{Filename: f, Program: program}
if err != nil {
return nil, err
return nil, nil, err
}

defs = program.Definitions
Expand All @@ -84,11 +93,11 @@ func Resolve(name string, program *ast.Program, dirs []string) (ast.Node, error)

for _, def := range defs {
if def.Info().Name == name {
return def, nil
return def, rInfo, nil
}
}

return nil, fmt.Errorf("%q could not be resolved", name)
return nil, nil, fmt.Errorf("%q could not be resolved", name)
}

// ResolveConstant resolves an [ast.ConstantReference] to its target node.
Expand All @@ -98,12 +107,12 @@ func Resolve(name string, program *ast.Program, dirs []string) (ast.Node, error)
// - "Enum.Value" (ast.EnumItem)
// - "include.Constant" (ast.Constant)
// - "include.Enum.Value" (ast.EnumItem)
func ResolveConstant(ref ast.ConstantReference, program *ast.Program, dirs []string) (ast.Node, error) {
func ResolveConstant(ref ast.ConstantReference, program *ast.Program, dirs []string, parseCache map[string]*ParseRes) (ast.Node, error) {
parts := strings.SplitN(ref.Name, ".", 3)

n, err := Resolve(parts[0], program, dirs)
n, _, err := Resolve(parts[0], program, dirs, parseCache)
if err != nil && len(parts) > 1 {
n, err = Resolve(parts[0]+"."+parts[1], program, dirs)
n, _, err = Resolve(parts[0]+"."+parts[1], program, dirs, parseCache)
}
if err != nil {
return n, fmt.Errorf("%q could not be resolved", ref.Name)
Expand All @@ -125,8 +134,8 @@ func ResolveConstant(ref ast.ConstantReference, program *ast.Program, dirs []str
// resolve the target node's own type. This is useful when the reference
// points to an [ast.Typedef] or [ast.Constant], for example, and the caller
// is primarily interested in the target's ast.Type.
func ResolveType(ref ast.TypeReference, program *ast.Program, dirs []string) (ast.Node, error) {
n, err := Resolve(ref.Name, program, dirs)
func ResolveType(ref ast.TypeReference, program *ast.Program, dirs []string, parseCache map[string]*ParseRes) (ast.Node, error) {
n, _, err := Resolve(ref.Name, program, dirs, parseCache)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func TestResolveConstant(t *testing.T) {
}

for _, tt := range tests {
n, err := ResolveConstant(tt.ref, tt.prog, nil)
n, err := ResolveConstant(tt.ref, tt.prog, nil, make(map[string]*ParseRes))
if tt.err {
if err == nil {
t.Errorf("expected an error, got %s", n)
Expand Down
21 changes: 11 additions & 10 deletions check.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,14 @@ next:

// C is a type passed to all check functions to provide context.
type C struct {
Filename string
Dirs []string
Program *ast.Program
Check string
Messages Messages
logger *log.Logger
parseInfo *idl.Info
Filename string
Dirs []string
Program *ast.Program
Check string
Messages Messages
logger *log.Logger
parseInfo *idl.Info
ParseCache map[string]*ParseRes
}

func (c *C) pos(n ast.Node) ast.Position {
Expand Down Expand Up @@ -203,23 +204,23 @@ func (c *C) Errorf(node ast.Node, message string, args ...any) {

// Resolve resolves a name.
func (c *C) Resolve(name string) ast.Node {
if n, err := Resolve(name, c.Program, c.Dirs); err == nil {
if n, _, err := Resolve(name, c.Program, c.Dirs, c.ParseCache); err == nil {
return n
}
return nil
}

// ResolveConstant resolves a constant reference to its target.
func (c *C) ResolveConstant(ref ast.ConstantReference) ast.Node {
if n, err := ResolveConstant(ref, c.Program, c.Dirs); err == nil {
if n, err := ResolveConstant(ref, c.Program, c.Dirs, c.ParseCache); err == nil {
return n
}
return nil
}

// ResolveType resolves a type reference to its target type.
func (c *C) ResolveType(ref ast.TypeReference) ast.Node {
if n, err := ResolveType(ref, c.Program, c.Dirs); err == nil {
if n, err := ResolveType(ref, c.Program, c.Dirs, c.ParseCache); err == nil {
return n
}
return nil
Expand Down
7 changes: 4 additions & 3 deletions checks/checks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ func RunTests(t *testing.T, check *thriftcheck.Check, tests []Test) {

for _, tt := range tests {
c := &thriftcheck.C{
Filename: tt.name,
Program: tt.prog,
Check: check.Name,
Filename: tt.name,
Program: tt.prog,
Check: check.Name,
ParseCache: make(map[string]*thriftcheck.ParseRes),
}
if c.Filename == "" {
c.Filename = "t.thrift"
Expand Down
220 changes: 220 additions & 0 deletions checks/depth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
// Copyright 2025 Pinterest
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package checks

import (
"fmt"
"maps"
"path/filepath"
"slices"
"strconv"
"strings"

"github.com/pinterest/thriftcheck"
"go.uber.org/thriftrw/ast"
)

// config carries the effective settings of a single depth check run.
type config struct {
	// maxDepth is the deepest nesting allowed before an error is reported.
	maxDepth int
	// maxDepthUnset is true when maxDepth is zero, i.e. no limit applies.
	maxDepthUnset bool
	// allowCycles disables the reporting of cycles between structs.
	allowCycles bool
}

// structNode is a struct as seen by the depth traversal.
type structNode struct {
	// id identifies the struct; it is the filename followed by the struct name.
	id string
	// fields are the struct's fields.
	fields []*ast.Field
	// source is the file and program used to resolve the fields' types.
	source *source
}

// source pairs a Thrift filename with its parsed program.
type source struct {
	filename string
	program  *ast.Program
}

// typeNode is a base type or struct reachable from a parent struct's fields.
type typeNode struct {
	// isBaseType is set for entries created from an ast.BaseType.
	isBaseType bool
	// sourceNode is the referenced struct; it is left as the zero value for
	// base types.
	sourceNode structNode
	// ref is the deepest reference to this type found in the parent struct.
	ref typeRef
}

// typeRef records where a type is referenced and how deeply it is nested.
type typeRef struct {
	// filename is the file that contains the reference.
	filename string
	// name is the type name as written at the reference.
	name string
	// line and col give the position of the reference.
	line int
	col  int
	// depth is the nesting depth of the reference relative to its parent struct.
	depth int
}

// NewStructNode wraps s in a structNode. The node's id is f followed by the
// struct's name, and its source records f and p so that the struct's field
// types can later be resolved against the program they were declared in.
func NewStructNode(s *ast.Struct, f string, p *ast.Program) structNode {
	origin := &source{filename: f, program: p}
	return structNode{
		id:     f + s.Name,
		fields: s.Fields,
		source: origin,
	}
}

// CheckDepth returns a thriftcheck.Check that reports an error
// if a Struct, Union, or Exception exceeds a specified depth.
//
// maxAllowedDepth is the default limit; zero means no limit is enforced.
// A struct can override the limit with a "maxDepth" annotation whose value
// must be a positive integer. Unless allowCycles is true, a struct whose
// fields lead back to a struct already on the traversal path is reported too.
//
// The reported message lists every reference on the offending path as
// "file:line:col (name) +depth (accumulated depth)".
func CheckDepth(maxAllowedDepth int, allowCycles bool) thriftcheck.Check {
	// Shared across all invocations of the check so that the fields of a
	// given struct are expanded only once.
	structIdToTypes := make(map[string]map[string]*typeNode)

	return thriftcheck.NewCheck("depth", func(c *thriftcheck.C, s *ast.Struct) {
		maxDepth := maxAllowedDepth

		// Only the first "maxDepth" annotation is considered.
		for _, a := range s.Annotations {
			if a.Name == "maxDepth" {
				i, err := strconv.Atoi(a.Value)
				if err != nil {
					c.Errorf(s, `value of %q for "maxDepth" annotation could not be parsed into an integer`, a.Value)
					return
				}
				if i < 1 {
					// %d, not %q: %q would render the integer as a quoted
					// character literal instead of the number itself.
					c.Errorf(s, `"maxDepth" annotations should be positive, but got %d`, i)
					return
				}
				maxDepth = i
				break
			}
		}

		maxDepthUnset := maxDepth == 0

		// With no limit and cycles allowed there is nothing to report.
		if maxDepthUnset && allowCycles {
			return
		}

		depth, cycle, path := getDepth(
			NewStructNode(s, c.Filename, c.Program), 1, 1,
			make(map[string]bool), []*typeNode{}, structIdToTypes,
			config{maxDepth: maxDepth, maxDepthUnset: maxDepthUnset, allowCycles: allowCycles}, c)

		if (!maxDepthUnset && depth > maxDepth) || (cycle && !allowCycles) {
			pathDetails := make([]string, 0, len(path))
			// The struct itself counts as depth 1; every hop adds its own depth.
			accD := 1
			for _, e := range path {
				accD += e.ref.depth
				pathDetails = append(
					pathDetails,
					fmt.Sprintf("\t%s:%d:%d (%s) +%d (%d)", e.ref.filename, e.ref.line, e.ref.col, e.ref.name, e.ref.depth, accD))
			}

			m := fmt.Sprintf("exceeded maximum depth of %d", maxDepth)
			if cycle && !allowCycles {
				m = "led to a cycle"
			}
			c.Errorf(s, "%s %s\n%s", s.Name, m, strings.Join(pathDetails, "\n"))
		}
	})
}

// getDepth walks the types reachable from s depth-first and returns the
// deepest level seen, whether a cycle was hit, and the chain of references
// leading to the offending type.
//
// curD is the depth of s itself and maxD the deepest level seen so far.
// vis holds the ids of the structs on the current path, and path the
// references followed to reach s. The walk stops as soon as the configured
// maximum depth is exceeded or a disallowed cycle is found; if neither
// happens below s, an empty path is returned.
func getDepth(
	s structNode, curD, maxD int, vis map[string]bool, path []*typeNode,
	structIdToTypes map[string]map[string]*typeNode, cfg config, c *thriftcheck.C,
) (int, bool, []*typeNode) {
	// Reaching a struct that is already on the current path is a cycle.
	if vis[s.id] {
		return curD, true, path
	}
	vis[s.id] = true

	if curD > maxD {
		maxD = curD
	}
	if !cfg.maxDepthUnset && maxD > cfg.maxDepth {
		return maxD, false, path
	}

	expandStructFields(s, structIdToTypes, c)
	children := structIdToTypes[s.id]

	// The keys are visited in sorted order so that, when several paths
	// violate the limit, the same one is always reported. This keeps the
	// output (and the tests) deterministic at the cost of some speed.
	sawCycle := false
	for _, key := range slices.Sorted(maps.Keys(children)) {
		child := children[key]

		childDepth, childCycle, childPath := getDepth(
			child.sourceNode, curD+child.ref.depth, maxD, vis,
			append(path, child), structIdToTypes, cfg, c)

		if childCycle {
			sawCycle = true
		}
		if childDepth > maxD {
			maxD = childDepth
		}

		tooDeep := !cfg.maxDepthUnset && maxD > cfg.maxDepth
		badCycle := sawCycle && !cfg.allowCycles
		if tooDeep || badCycle {
			return maxD, sawCycle, childPath
		}
	}

	// s leaves the current path again; nothing below it was reportable.
	vis[s.id] = false
	return maxD, false, []*typeNode{}
}

// expandStructFields fills structIdToTypes[s.id] with the deepest base types
// and structs referenced by the fields of s. It does nothing when an entry
// for s already exists.
func expandStructFields(s structNode, structIdToTypes map[string]map[string]*typeNode, c *thriftcheck.C) {
	if structIdToTypes[s.id] != nil {
		return
	}

	deepest := make(map[string]*typeNode)
	structIdToTypes[s.id] = deepest
	for i := range s.fields {
		// Each field starts at depth 1 with its own typedef-visit set.
		expandType(s.fields[i].Type, 1, s.source, make(map[string]bool), deepest, c)
	}
}

// expandType traverses an ast.Type, resolving references when needed,
// to store the deepest ast.BaseType and *ast.Struct types
// relative to the parent struct.
//
// depth is the nesting level of t within the parent struct's field, src is
// the file and program in which t is written, vis tracks the typedefs
// currently being expanded, and deepestTypes receives the results. A
// reference that cannot be resolved is skipped silently.
func expandType(t ast.Type, depth int, src *source, vis map[string]bool, deepestTypes map[string]*typeNode, c *thriftcheck.C) {
	switch v := t.(type) {
	case ast.BaseType:
		// Base types are recorded one level shallower than their position.
		updateTypeIfDeepest(v.String(), v.String(), src, v.Line, v.Column, depth-1,
			&typeNode{isBaseType: true}, deepestTypes)
	case ast.TypeReference:
		name := v.String()
		n, rInfo, err := thriftcheck.Resolve(name, src.program, []string{filepath.Dir(src.filename)}, c.ParseCache)
		if err != nil {
			return
		}
		// Resolve reports resolution info only for names found in an
		// included file; everything else lives in the current source.
		newSrc := src
		if rInfo != nil {
			newSrc = &source{filename: rInfo.Filename, program: rInfo.Program}
		}

		switch n := n.(type) {
		case *ast.Constant:
			expandType(n.Type, depth, newSrc, vis, deepestTypes, c)
		case *ast.Typedef:
			// Key the typedef by the file it is declared in, not the file
			// that references it, so that same-named typedefs in different
			// files are not mistaken for a cycle.
			key := newSrc.filename + n.Name
			if vis[key] {
				c.Warningf(t, "found a cycle resolving typedef %q", n.Name)
				return
			}
			vis[key] = true
			expandType(n.Type, depth, newSrc, vis, deepestTypes, c)
			vis[key] = false
		case *ast.Struct:
			// The reference is located in src, the struct itself in newSrc.
			updateTypeIfDeepest(newSrc.filename+n.Name, name, src, v.Line, v.Column, depth,
				&typeNode{sourceNode: NewStructNode(n, newSrc.filename, newSrc.program)}, deepestTypes)
		}
	case ast.MapType:
		// Containers add one level of nesting for their element types.
		expandType(v.KeyType, depth+1, src, vis, deepestTypes, c)
		expandType(v.ValueType, depth+1, src, vis, deepestTypes, c)
	case ast.ListType:
		expandType(v.ValueType, depth+1, src, vis, deepestTypes, c)
	case ast.SetType:
		expandType(v.ValueType, depth+1, src, vis, deepestTypes, c)
	}
}

// updateTypeIfDeepest makes sure deepestTypes has an entry for key, using
// baseT when there is none yet, and overwrites that entry's reference with
// the given name, position and depth when depth is greater than the depth
// currently stored.
func updateTypeIfDeepest(key, name string, src *source, line, col, depth int, baseT *typeNode, deepestTypes map[string]*typeNode) {
	entry := deepestTypes[key]
	if entry == nil {
		entry = baseT
		deepestTypes[key] = entry
	}

	if entry.ref.depth >= depth {
		return
	}
	entry.ref = typeRef{
		name:     name,
		filename: src.filename,
		line:     line,
		col:      col,
		depth:    depth,
	}
}
Loading
Loading