diff --git a/_docs/generator/template-funcs.go b/_docs/generator/template-funcs.go index 78c5234b8..1daca6375 100644 --- a/_docs/generator/template-funcs.go +++ b/_docs/generator/template-funcs.go @@ -17,10 +17,12 @@ import ( "strings" "unicode" + "golang.org/x/tools/go/packages" + "github.com/DataDog/orchestrion/internal/injector/aspect/advice" "github.com/DataDog/orchestrion/internal/injector/aspect/advice/code" "github.com/DataDog/orchestrion/internal/injector/aspect/join" - "golang.org/x/tools/go/packages" + "github.com/DataDog/orchestrion/internal/injector/typed" ) var ( @@ -65,7 +67,7 @@ func render(val any) (template.HTML, error) { templateName := "doc." switch val := val.(type) { - case join.Point, join.TypeName, join.FunctionOption: + case join.Point, typed.TypeName, join.FunctionOption: templateName += "join" case advice.Advice: templateName += "advice" diff --git a/internal/injector/aspect/advice/call.go b/internal/injector/aspect/advice/call.go index fbee8dcba..0693494bb 100644 --- a/internal/injector/aspect/advice/call.go +++ b/internal/injector/aspect/advice/call.go @@ -14,20 +14,20 @@ import ( "github.com/DataDog/orchestrion/internal/fingerprint" "github.com/DataDog/orchestrion/internal/injector/aspect/advice/code" "github.com/DataDog/orchestrion/internal/injector/aspect/context" - "github.com/DataDog/orchestrion/internal/injector/aspect/join" + "github.com/DataDog/orchestrion/internal/injector/typed" "github.com/DataDog/orchestrion/internal/yaml" "github.com/dave/dst" "github.com/goccy/go-yaml/ast" ) type appendArgs struct { - TypeName join.TypeName + TypeName typed.TypeName Templates []*code.Template } // AppendArgs appends arguments of a given type to the end of a function call. All arguments must be // of the same type, as they may be appended at the tail end of a variadic call. -func AppendArgs(typeName join.TypeName, templates ...*code.Template) *appendArgs { +func AppendArgs(typeName typed.TypeName, templates ...*code.Template) *appendArgs { return &appendArgs{typeName, templates} } @@ -92,7 +92,7 @@ func (a *appendArgs) Apply(ctx context.AdviceContext) (bool, error) { Ellipsis: true, } - if importPath := a.TypeName.ImportPath(); importPath != "" { + if importPath := a.TypeName.ImportPath; importPath != "" { ctx.AddImport(importPath, inferPkgName(importPath)) } @@ -101,7 +101,7 @@ func (a *appendArgs) Apply(ctx context.AdviceContext) (bool, error) { func (a *appendArgs) AddedImports() []string { imports := make([]string, 0, len(a.Templates)+1) - if argTypeImportPath := a.TypeName.ImportPath(); argTypeImportPath != "" { + if argTypeImportPath := a.TypeName.ImportPath; argTypeImportPath != "" { imports = append(imports, argTypeImportPath) } for _, t := range a.Templates { @@ -168,7 +168,7 @@ func init() { return nil, err } - tn, err := join.NewTypeName(args.TypeName) + tn, err := typed.NewTypeName(args.TypeName) if err != nil { return nil, err } diff --git a/internal/injector/aspect/advice/call_test.go b/internal/injector/aspect/advice/call_test.go index 91fc4f91f..6839912e2 100644 --- a/internal/injector/aspect/advice/call_test.go +++ b/internal/injector/aspect/advice/call_test.go @@ -11,7 +11,7 @@ import ( "github.com/DataDog/orchestrion/internal/injector/aspect/advice" "github.com/DataDog/orchestrion/internal/injector/aspect/advice/code" "github.com/DataDog/orchestrion/internal/injector/aspect/context" - "github.com/DataDog/orchestrion/internal/injector/aspect/join" + "github.com/DataDog/orchestrion/internal/injector/typed" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -19,23 +19,23 @@ import ( func TestAppendArgs(t *testing.T) { t.Run("AddedImports", func(t *testing.T) { type testCase struct { - argType join.TypeName + argType typed.TypeName args []*code.Template expectedImports []string } testCases := map[string]testCase{ "imports-none": { - argType: join.MustTypeName("any"), + argType: typed.Any, args: []*code.Template{code.MustTemplate("true", nil, context.GoLangVersion{})}, }, "imports-from-arg-type": { - argType: join.MustTypeName("*net/http.Request"), + argType: typed.MustTypeName("*net/http.Request"), args: []*code.Template{code.MustTemplate("true", nil, context.GoLangVersion{})}, expectedImports: []string{"net/http"}, }, "imports-from-templates": { - argType: join.MustTypeName("any"), + argType: typed.Any, args: []*code.Template{ code.MustTemplate("imp.Value", map[string]string{"imp": "github.com/namespace/foo"}, context.GoLangVersion{}), code.MustTemplate("imp.Value", map[string]string{"imp": "github.com/namespace/bar"}, context.GoLangVersion{}), diff --git a/internal/injector/aspect/advice/code/dot_function.go b/internal/injector/aspect/advice/code/dot_function.go index 088440941..e29b3e3d0 100644 --- a/internal/injector/aspect/advice/code/dot_function.go +++ b/internal/injector/aspect/advice/code/dot_function.go @@ -8,12 +8,8 @@ package code import ( "errors" "fmt" - "go/importer" - "go/types" - "strings" "github.com/DataDog/orchestrion/internal/injector/aspect/context" - "github.com/DataDog/orchestrion/internal/injector/aspect/join" "github.com/DataDog/orchestrion/internal/injector/typed" "github.com/dave/dst" ) @@ -162,6 +158,11 @@ func (s signature) ResultThatImplements(name string) (string, error) { return "", nil } + // Optimization: First, check for an exact match using the helper. + if index, found := typed.FindMatchingTypeName(s.Results, name); found { + return fieldAt(s.Results, index, "result") + } // If not found, fall through to type resolution. + // Resolve the interface type. iface, err := typed.ResolveInterfaceTypeByName(name) if err != nil { @@ -192,18 +193,38 @@ func (s signature) LastResultThatImplements(name string) (string, error) { return "", nil } + // Optimization: First, check for an exact match using TypeName parsing, finding the last one. + lastMatchIndex := -1 + if tn, err := typed.NewTypeName(name); err == nil { + currentIndex := 0 + for _, field := range s.Results.List { + if tn.Matches(field.Type) { + lastMatchIndex = currentIndex // Update last found index + } + // Increment index by the number of names in the field (or 1 if unnamed). + count := len(field.Names) + if count == 0 { + count = 1 + } + currentIndex += count + } + } + // If we found a match via TypeName, return it. + if lastMatchIndex != -1 { + return fieldAt(s.Results, lastMatchIndex, "result") + } // If parsing failed or no match, fall through to type resolution. + // Resolve the interface type. iface, err := typed.ResolveInterfaceTypeByName(name) if err != nil { + // Propagate error if interface resolution fails return "", fmt.Errorf("resolving interface type %q: %w", name, err) } - // First, we need to build a map of result fields to their indices - // that takes into account named and unnamed parameters. - var ( - fieldIndices = make(map[*dst.Field]int) - index = 0 - ) + // Fallback: Check using ExprImplements, iterating backward. + // Need field indices map again for this path. + fieldIndices := make(map[*dst.Field]int) + index := 0 for _, field := range s.Results.List { fieldIndices[field] = index count := len(field.Names) @@ -213,7 +234,6 @@ func (s signature) LastResultThatImplements(name string) (string, error) { index += count } - // Loop backward through the results list. for i := len(s.Results.List) - 1; i >= 0; i-- { field := s.Results.List[i] if typed.ExprImplements(s.context, field.Type, iface) { @@ -265,7 +285,7 @@ func fieldAt(fields *dst.FieldList, index int, use string) (string, error) { } func fieldOfType(fields *dst.FieldList, typeName string, use string) (string, error) { - tn, err := join.NewTypeName(typeName) + tn, err := typed.NewTypeName(typeName) if err != nil { return "", err } @@ -292,118 +312,27 @@ func fieldOfType(fields *dst.FieldList, typeName string, use string) (string, er return "", nil } -// exprImplements checks if an expression's type implements an interface. -func exprImplements(ctx context.AdviceContext, expr dst.Expr, iface *types.Interface) bool { - actualType := ctx.ResolveType(expr) - if actualType == nil { - return false - } - - return typeImplements(actualType, iface) -} - -// typeImplements checks if a type implements an interface (including pointer receivers). -func typeImplements(t types.Type, iface *types.Interface) bool { - if t == nil || iface == nil { - return false - } - - // Direct implementation check. - if types.Implements(t, iface) { - return true - } - - return false -} - -// resolveInterfaceTypeByName takes an interface name as a string and resolves it to an interface type. -// It supports built-in interfaces (e.g. "error"), package qualified interfaces (e.g. "io.Reader"), -// and third-party package interfaces (e.g. "example.com/pkg.Interface"). -func resolveInterfaceTypeByName(name string) (*types.Interface, error) { - // Handle built-in types. - if obj := types.Universe.Lookup(name); obj != nil { - typeObj, ok := obj.(*types.TypeName) - if !ok { - return nil, fmt.Errorf("object %s is not a type name but a %T", name, obj) - } - - typ := typeObj.Type() - if !types.IsInterface(typ) { - return nil, fmt.Errorf("type %s is not an interface", name) - } - - t, ok := typ.Underlying().(*types.Interface) - if !ok { - return nil, fmt.Errorf("type %s is not an interface", name) - } - - return t, nil - } - - // Handle package-qualified types (e.g., "io.Writer"). - pkgName, typeName := splitPackageAndName(name) - if pkgName == "" { - return nil, fmt.Errorf("invalid type name: %s", name) - } - - // Import the package - imp := importer.Default() - pkg, err := imp.Import(pkgName) - if err != nil { - return nil, fmt.Errorf("failed to import package %q: %w", pkgName, err) - } - - // Look up the type in the package's scope - obj := pkg.Scope().Lookup(typeName) - if obj == nil { - return nil, fmt.Errorf("type %q not found in package %q", typeName, pkgName) - } - - typeObj, ok := obj.(*types.TypeName) - if !ok { - return nil, fmt.Errorf("object %s is not a type name but a %T", name, obj) - } - - typ := typeObj.Type() - if !types.IsInterface(typ) { - return nil, fmt.Errorf("type %s is not an interface", name) - } - - t, ok := typ.Underlying().(*types.Interface) - if !ok { - return nil, fmt.Errorf("type %s is not an interface", name) - } - - return t, nil -} - -// splitPackageAndName splits a fully qualified type name like "io.Reader" or "example.com/pkg.Type" -// into its package path and local name. -// Returns ("", "error") for built-in "error". -// Returns ("", "MyType") for unqualified "MyType". -func splitPackageAndName(fullName string) (pkgPath string, localName string) { - if !strings.Contains(fullName, ".") { - // Assume built-in type (like "error") or unqualified local type. - return "", fullName - } - lastDot := strings.LastIndex(fullName, ".") - pkgPath = fullName[:lastDot] - localName = fullName[lastDot+1:] - return pkgPath, localName -} - // FinalResultImplements returns whether the final result implements the provided interface type. func (s signature) FinalResultImplements(interfaceName string) (bool, error) { if s.Results == nil || len(s.Results.List) == 0 { return false, nil } - iface, err := resolveInterfaceTypeByName(interfaceName) + lastField := s.Results.List[len(s.Results.List)-1] + + // Optimization: First, check for an exact match using TypeName parsing. + // Note: Not using FindMatchingTypeName as we only need to check the last field. + if tn, err := typed.NewTypeName(interfaceName); err == nil { + if tn.Matches(lastField.Type) { + return true, nil + } + } // If parsing failed or no match, fall through to type resolution. + + iface, err := typed.ResolveInterfaceTypeByName(interfaceName) if err != nil { return false, fmt.Errorf("resolving interface type %q: %w", interfaceName, err) } // Check if the last field type implements the interface. - lastField := s.Results.List[len(s.Results.List)-1] - return exprImplements(s.context, lastField.Type, iface), nil + return typed.ExprImplements(s.context, lastField.Type, iface), nil } diff --git a/internal/injector/aspect/advice/struct.go b/internal/injector/aspect/advice/struct.go index bd62c5c5e..5cc199ba0 100644 --- a/internal/injector/aspect/advice/struct.go +++ b/internal/injector/aspect/advice/struct.go @@ -11,7 +11,7 @@ import ( "github.com/DataDog/orchestrion/internal/fingerprint" "github.com/DataDog/orchestrion/internal/injector/aspect/context" - "github.com/DataDog/orchestrion/internal/injector/aspect/join" + "github.com/DataDog/orchestrion/internal/injector/typed" "github.com/DataDog/orchestrion/internal/yaml" "github.com/dave/dst" "github.com/goccy/go-yaml/ast" @@ -19,11 +19,11 @@ import ( type addStructField struct { Name string - TypeName join.TypeName + TypeName typed.TypeName } // AddStructField adds a new synthetic field at the tail end of a struct declaration. -func AddStructField(fieldName string, fieldType join.TypeName) *addStructField { +func AddStructField(fieldName string, fieldType typed.TypeName) *addStructField { return &addStructField{fieldName, fieldType} } @@ -47,7 +47,7 @@ func (a *addStructField) Apply(ctx context.AdviceContext) (bool, error) { Type: a.TypeName.AsNode(), }) - if importPath := a.TypeName.ImportPath(); importPath != "" { + if importPath := a.TypeName.ImportPath; importPath != "" { // If the type name is qualified, we may need to import the package, too. _ = ctx.AddImport(importPath, inferPkgName(importPath)) } @@ -60,7 +60,7 @@ func (a *addStructField) Hash(h *fingerprint.Hasher) error { } func (a *addStructField) AddedImports() []string { - if path := a.TypeName.ImportPath(); path != "" { + if path := a.TypeName.ImportPath; path != "" { return []string{path} } return nil @@ -76,7 +76,7 @@ func init() { if err := yaml.NodeToValueContext(ctx, node, &spec); err != nil { return nil, err } - tn, err := join.NewTypeName(spec.Type) + tn, err := typed.NewTypeName(spec.Type) if err != nil { return nil, err } diff --git a/internal/injector/aspect/join/declaration.go b/internal/injector/aspect/join/declaration.go index 3d49c8fc8..42ef9bbe8 100644 --- a/internal/injector/aspect/join/declaration.go +++ b/internal/injector/aspect/join/declaration.go @@ -13,6 +13,7 @@ import ( "github.com/DataDog/orchestrion/internal/fingerprint" "github.com/DataDog/orchestrion/internal/injector/aspect/context" "github.com/DataDog/orchestrion/internal/injector/aspect/may" + "github.com/DataDog/orchestrion/internal/injector/typed" "github.com/DataDog/orchestrion/internal/yaml" "github.com/dave/dst" "github.com/goccy/go-yaml/ast" @@ -72,15 +73,15 @@ func (i *declarationOf) Hash(h *fingerprint.Hasher) error { } type valueDeclaration struct { - TypeName TypeName + TypeName typed.TypeName } -func ValueDeclaration(typeName TypeName) *valueDeclaration { +func ValueDeclaration(typeName typed.TypeName) *valueDeclaration { return &valueDeclaration{typeName} } func (i *valueDeclaration) PackageMayMatch(ctx *may.PackageContext) may.MatchType { - return ctx.PackageImports(i.TypeName.ImportPath()) + return ctx.PackageImports(i.TypeName.ImportPath) } func (*valueDeclaration) FileMayMatch(_ *may.FileContext) may.MatchType { @@ -106,7 +107,7 @@ func (i *valueDeclaration) Matches(ctx context.AspectContext) bool { } func (i *valueDeclaration) ImpliesImported() []string { - if path := i.TypeName.ImportPath(); path != "" { + if path := i.TypeName.ImportPath; path != "" { return []string{path} } return nil @@ -140,7 +141,7 @@ func init() { return nil, err } - tn, err := NewTypeName(typeName) + tn, err := typed.NewTypeName(typeName) if err != nil { return nil, err } diff --git a/internal/injector/aspect/join/function.go b/internal/injector/aspect/join/function.go index 66043a864..8cc204ac9 100644 --- a/internal/injector/aspect/join/function.go +++ b/internal/injector/aspect/join/function.go @@ -145,26 +145,26 @@ func (fo functionName) Hash(h *fingerprint.Hasher) error { } type signature struct { - Arguments []TypeName - Results []TypeName + Arguments []typed.TypeName + Results []typed.TypeName } // Signature matches function declarations based on their arguments and return // value types. -func Signature(args []TypeName, ret []TypeName) FunctionOption { +func Signature(args []typed.TypeName, ret []typed.TypeName) FunctionOption { return &signature{Arguments: args, Results: ret} } func (fo *signature) packageMayMatch(ctx *may.PackageContext) may.MatchType { sum := may.Match for _, candidate := range fo.Arguments { - sum = sum.And(ctx.PackageImports(candidate.ImportPath())) + sum = sum.And(ctx.PackageImports(candidate.ImportPath)) if sum == may.NeverMatch { return may.NeverMatch } } for _, candidate := range fo.Results { - sum = sum.And(ctx.PackageImports(candidate.ImportPath())) + sum = sum.And(ctx.PackageImports(candidate.ImportPath)) if sum == may.NeverMatch { return may.NeverMatch } @@ -178,12 +178,12 @@ func (*signature) fileMayMatch(_ *may.FileContext) may.MatchType { func (fo *signature) impliesImported() (list []string) { for _, tn := range fo.Arguments { - if path := tn.ImportPath(); path != "" { + if path := tn.ImportPath; path != "" { list = append(list, path) } } for _, tn := range fo.Results { - if path := tn.ImportPath(); path != "" { + if path := tn.ImportPath; path != "" { list = append(list, path) } } @@ -225,8 +225,8 @@ func (fo *signature) evaluate(info functionInformation) bool { func (fo *signature) Hash(h *fingerprint.Hasher) error { return h.Named( "signature", - fingerprint.List[TypeName](fo.Arguments), - fingerprint.List[TypeName](fo.Results), + fingerprint.List[typed.TypeName](fo.Arguments), + fingerprint.List[typed.TypeName](fo.Results), ) } @@ -236,15 +236,15 @@ type signatureContains struct { // SignatureContains matches function declarations based on their arguments and // return value types in any order and does not require all arguments or return values to be present. -func SignatureContains(args []TypeName, ret []TypeName) FunctionOption { +func SignatureContains(args []typed.TypeName, ret []typed.TypeName) FunctionOption { return &signatureContains{signature{Arguments: args, Results: ret}} } func (fo *signatureContains) Hash(h *fingerprint.Hasher) error { return h.Named( "signature-contains", - fingerprint.List[TypeName](fo.Arguments), - fingerprint.List[TypeName](fo.Results), + fingerprint.List[typed.TypeName](fo.Arguments), + fingerprint.List[typed.TypeName](fo.Results), ) } @@ -262,7 +262,7 @@ func (fo *signatureContains) evaluate(info functionInformation) bool { // containsAnyType checks if any of the expected types match any of the actual types in the field list. // Returns false if either slice is empty or nil. -func containsAnyType(expectedTypes []TypeName, fieldList *dst.FieldList) bool { +func containsAnyType(expectedTypes []typed.TypeName, fieldList *dst.FieldList) bool { // Quick return if either side is empty. if len(expectedTypes) == 0 || fieldList == nil || len(fieldList.List) == 0 { return false @@ -281,15 +281,15 @@ func containsAnyType(expectedTypes []TypeName, fieldList *dst.FieldList) bool { } type receiver struct { - TypeName TypeName + TypeName typed.TypeName } -func Receiver(typeName TypeName) FunctionOption { +func Receiver(typeName typed.TypeName) FunctionOption { return &receiver{typeName} } func (fo *receiver) packageMayMatch(ctx *may.PackageContext) may.MatchType { - if ctx.ImportPath == fo.TypeName.ImportPath() { + if ctx.ImportPath == fo.TypeName.ImportPath { return may.Match } @@ -297,7 +297,7 @@ func (fo *receiver) packageMayMatch(ctx *may.PackageContext) may.MatchType { } func (fo *receiver) fileMayMatch(ctx *may.FileContext) may.MatchType { - return ctx.FileContains(fo.TypeName.Name()) + return ctx.FileContains(fo.TypeName.Name) } func (fo *receiver) evaluate(info functionInformation) bool { @@ -305,7 +305,7 @@ func (fo *receiver) evaluate(info functionInformation) bool { } func (fo *receiver) impliesImported() []string { - return []string{fo.TypeName.ImportPath()} + return []string{fo.TypeName.ImportPath} } func (fo *receiver) Hash(h *fingerprint.Hasher) error { @@ -397,6 +397,11 @@ func (fo *resultImplements) evaluate(info functionInformation) bool { return false } + // Optimization: First, check for an exact match using the helper. + if _, found := typed.FindMatchingTypeName(info.Type.Results, fo.InterfaceName); found { + return true // Found direct match + } // If not found, fall through to type resolution. + // Ensure the type resolver is available. if info.typeResolver == nil { return false @@ -462,6 +467,14 @@ func (fo *finalResultImplements) evaluate(info functionInformation) bool { return false } + // Optimization: First, check for an exact match using TypeName parsing. + if tn, err := typed.NewTypeName(fo.InterfaceName); err == nil { + lastField := info.Type.Results.List[len(info.Type.Results.List)-1] + if tn.Matches(lastField.Type) { + return true // Found direct match + } + } // If parsing failed or no match, fall through to type resolution. + // Ensure the type resolver is available. if info.typeResolver == nil { return false @@ -536,7 +549,7 @@ func (o *unmarshalFuncDeclOption) UnmarshalYAML(ctx gocontext.Context, node ast. if err := yaml.NodeToValueContext(ctx, mapping.Values[0].Value, &arg); err != nil { return err } - tn, err := NewTypeName(arg) + tn, err := typed.NewTypeName(arg) if err != nil { return err } @@ -560,23 +573,23 @@ func (o *unmarshalFuncDeclOption) UnmarshalYAML(ctx gocontext.Context, node ast. return fmt.Errorf("unexpected keys: %s", strings.Join(keys, ", ")) } - var args []TypeName + var args []typed.TypeName if len(sig.Args) > 0 { - args = make([]TypeName, len(sig.Args)) + args = make([]typed.TypeName, len(sig.Args)) for i, a := range sig.Args { var err error - if args[i], err = NewTypeName(a); err != nil { + if args[i], err = typed.NewTypeName(a); err != nil { return err } } } - var ret []TypeName + var ret []typed.TypeName if len(sig.Ret) > 0 { - ret = make([]TypeName, len(sig.Ret)) + ret = make([]typed.TypeName, len(sig.Ret)) for i, r := range sig.Ret { var err error - if ret[i], err = NewTypeName(r); err != nil { + if ret[i], err = typed.NewTypeName(r); err != nil { return err } } diff --git a/internal/injector/aspect/join/function_test.go b/internal/injector/aspect/join/function_test.go index 5d0f88fd0..bf055a28b 100644 --- a/internal/injector/aspect/join/function_test.go +++ b/internal/injector/aspect/join/function_test.go @@ -14,22 +14,23 @@ import ( "github.com/stretchr/testify/require" "github.com/DataDog/orchestrion/internal/fingerprint" + "github.com/DataDog/orchestrion/internal/injector/typed" ) func TestSignatureContains(t *testing.T) { tests := []struct { name string - args []TypeName - ret []TypeName + args []typed.TypeName + ret []typed.TypeName funcInfo functionInformation want bool }{ { name: "single argument matches", - args: []TypeName{ - {name: "string"}, + args: []typed.TypeName{ + {Name: "string"}, }, - ret: make([]TypeName, 0), + ret: make([]typed.TypeName, 0), funcInfo: functionInformation{ Type: &dst.FuncType{ Params: &dst.FieldList{ @@ -47,9 +48,9 @@ func TestSignatureContains(t *testing.T) { }, { name: "single return matches", - args: make([]TypeName, 0), - ret: []TypeName{ - {name: "error"}, + args: make([]typed.TypeName, 0), + ret: []typed.TypeName{ + {Name: "error"}, }, funcInfo: functionInformation{ Type: &dst.FuncType{ @@ -67,10 +68,10 @@ func TestSignatureContains(t *testing.T) { }, { name: "argument in any position matches", - args: []TypeName{ - {name: "string"}, + args: []typed.TypeName{ + {Name: "string"}, }, - ret: make([]TypeName, 0), + ret: make([]typed.TypeName, 0), funcInfo: functionInformation{ Type: &dst.FuncType{ Params: &dst.FieldList{ @@ -88,9 +89,9 @@ func TestSignatureContains(t *testing.T) { }, { name: "return in any position matches", - args: make([]TypeName, 0), - ret: []TypeName{ - {name: "error"}, + args: make([]typed.TypeName, 0), + ret: []typed.TypeName{ + {Name: "error"}, }, funcInfo: functionInformation{ Type: &dst.FuncType{ @@ -109,10 +110,10 @@ func TestSignatureContains(t *testing.T) { }, { name: "no match for empty fields", - args: []TypeName{ - {name: "string"}, + args: []typed.TypeName{ + {Name: "string"}, }, - ret: make([]TypeName, 0), + ret: make([]typed.TypeName, 0), funcInfo: functionInformation{ Type: &dst.FuncType{ Params: nil, @@ -123,11 +124,11 @@ func TestSignatureContains(t *testing.T) { }, { name: "no match for different type", - args: []TypeName{ - {name: "float64"}, + args: []typed.TypeName{ + {Name: "float64"}, }, - ret: []TypeName{ - {name: "byte"}, + ret: []typed.TypeName{ + {Name: "byte"}, }, funcInfo: functionInformation{ Type: &dst.FuncType{ @@ -147,10 +148,10 @@ func TestSignatureContains(t *testing.T) { }, { name: "complex type match", - args: []TypeName{ - {name: "CustomType", path: "pkg"}, + args: []typed.TypeName{ + {Name: "CustomType", ImportPath: "pkg"}, }, - ret: make([]TypeName, 0), + ret: make([]typed.TypeName, 0), funcInfo: functionInformation{ Type: &dst.FuncType{ Params: &dst.FieldList{ @@ -182,8 +183,8 @@ func TestSignatureContains(t *testing.T) { } func TestSignatureContainsHash(t *testing.T) { - args := []TypeName{{name: "string"}, {name: "int"}} - ret := []TypeName{{name: "error"}} + args := []typed.TypeName{{Name: "string"}, {Name: "int"}} + ret := []typed.TypeName{{Name: "error"}} fo := SignatureContains(args, ret) @@ -202,7 +203,7 @@ func TestSignatureContainsHash(t *testing.T) { assert.Equal(t, fp1, fp2, "Hash() gave different results for identical signatures") - fo3 := SignatureContains([]TypeName{{name: "float64"}}, ret) + fo3 := SignatureContains([]typed.TypeName{{Name: "float64"}}, ret) h3 := fingerprint.New() err = fo3.Hash(h3) require.NoError(t, err, "Hash failed") @@ -227,9 +228,9 @@ signature-contains: require.True(t, ok, "Expected *signatureContains, got %T", option.FunctionOption) require.Len(t, signatureContains.Arguments, 2, "Expected 2 arguments") - assert.Equal(t, "string", signatureContains.Arguments[0].Name(), "First argument should be string") - assert.Equal(t, "error", signatureContains.Arguments[1].Name(), "Second argument should be error") + assert.Equal(t, "string", signatureContains.Arguments[0].Name, "First argument should be string") + assert.Equal(t, "error", signatureContains.Arguments[1].Name, "Second argument should be error") require.Len(t, signatureContains.Results, 1, "Expected 1 result") - assert.Equal(t, "bool", signatureContains.Results[0].Name(), "Result should be bool") + assert.Equal(t, "bool", signatureContains.Results[0].Name, "Result should be bool") } diff --git a/internal/injector/aspect/join/join.go b/internal/injector/aspect/join/join.go index 66c120ee9..e759453e7 100644 --- a/internal/injector/aspect/join/join.go +++ b/internal/injector/aspect/join/join.go @@ -8,13 +8,9 @@ package join import ( - "fmt" - "regexp" - "github.com/DataDog/orchestrion/internal/fingerprint" "github.com/DataDog/orchestrion/internal/injector/aspect/context" "github.com/DataDog/orchestrion/internal/injector/aspect/may" - "github.com/dave/dst" ) // Point is the interface that abstracts selection of nodes where to inject @@ -37,114 +33,3 @@ type Point interface { fingerprint.Hashable } - -type TypeName struct { - // path is the import path that provides the type, or an empty string if the - // type is local. - path string - // name is the leaf (un-qualified) name of the type. - name string - // pointer determines whether the specified type is a pointer or not. - pointer bool -} - -// FIXME: this does not support all the type syntax, like: "chan Event" -var typeNameRe = regexp.MustCompile(`\A(\*)?\s*(?:([A-Za-z_][A-Za-z0-9_.-]+(?:/[A-Za-z_.-][A-Za-z0-9_.-]+)*)\.)?([A-Za-z_][A-Za-z0-9_]*)\z`) - -func NewTypeName(n string) (tn TypeName, err error) { - matches := typeNameRe.FindStringSubmatch(n) - if matches == nil { - err = fmt.Errorf("invalid TypeName syntax: %q", n) - return - } - - tn.pointer = matches[1] == "*" - tn.path = matches[2] - tn.name = matches[3] - return -} - -// MustTypeName is the same as NewTypeName, except it panics in case of an error. -func MustTypeName(n string) (tn TypeName) { - var err error - if tn, err = NewTypeName(n); err != nil { - panic(err) - } - return -} - -// ImportPath returns the import path for this type name, or a blank string if -// this refers to a local or built-in type. -func (n TypeName) ImportPath() string { - return n.path -} - -// Name returns the unqualified name of this type. -func (n TypeName) Name() string { - return n.name -} - -// Pointer returns whether this is a pointer type. -func (n TypeName) Pointer() bool { - return n.pointer -} - -// Matches determines whether the provided node represents the same type as this -// TypeName. -func (n TypeName) Matches(node dst.Expr) bool { - switch node := node.(type) { - case *dst.Ident: - return !n.pointer && n.path == node.Path && n.name == node.Name - - case *dst.SelectorExpr: - var path string - if ident, ok := node.X.(*dst.Ident); ok && ident.Path == "" { - path = ident.Name - } else { - return false - } - return !n.pointer && n.path == path && n.name == node.Sel.Name - - case *dst.StarExpr: - return n.pointer && (&TypeName{path: n.path, name: n.name}).Matches(node.X) - - case *dst.IndexExpr: - return !n.pointer && n.Matches(node.X) - - case *dst.IndexListExpr: - return !n.pointer && n.Matches(node.X) - - case *dst.InterfaceType: - // We only match the empty interface (as "any") - if len(node.Methods.List) != 0 { - return false - } - return n.path == "" && n.name == "any" - - default: - return false - } -} - -// MatchesDefinition determines whether the provided node matches the definition -// of this TypeName. The `importPath` argument determines the context in which -// the assertion is made. -func (n TypeName) MatchesDefinition(node dst.Expr, importPath string) bool { - if n.path != importPath { - return false - } - return (&TypeName{name: n.name, pointer: n.pointer}).Matches(node) -} - -func (n *TypeName) AsNode() dst.Expr { - ident := dst.NewIdent(n.name) - ident.Path = n.path - if n.pointer { - return &dst.StarExpr{X: ident} - } - return ident -} - -func (n TypeName) Hash(h *fingerprint.Hasher) error { - return h.Named("type-name", fingerprint.String(n.name), fingerprint.String(n.path), fingerprint.Bool(n.pointer)) -} diff --git a/internal/injector/aspect/join/struct.go b/internal/injector/aspect/join/struct.go index 22f83553e..304dbc0b4 100644 --- a/internal/injector/aspect/join/struct.go +++ b/internal/injector/aspect/join/struct.go @@ -13,31 +13,32 @@ import ( "github.com/DataDog/orchestrion/internal/fingerprint" "github.com/DataDog/orchestrion/internal/injector/aspect/context" "github.com/DataDog/orchestrion/internal/injector/aspect/may" + "github.com/DataDog/orchestrion/internal/injector/typed" "github.com/DataDog/orchestrion/internal/yaml" "github.com/dave/dst" "github.com/goccy/go-yaml/ast" ) type structDefinition struct { - TypeName TypeName + TypeName typed.TypeName } // StructDefinition matches the definition of a particular struct given its fully qualified name. -func StructDefinition(typeName TypeName) *structDefinition { +func StructDefinition(typeName typed.TypeName) *structDefinition { return &structDefinition{ TypeName: typeName, } } func (s *structDefinition) ImpliesImported() []string { - if path := s.TypeName.ImportPath(); path != "" { + if path := s.TypeName.ImportPath; path != "" { return []string{path} } return nil } func (s *structDefinition) PackageMayMatch(ctx *may.PackageContext) may.MatchType { - if ctx.ImportPath == s.TypeName.ImportPath() { + if ctx.ImportPath == s.TypeName.ImportPath { return may.Match } @@ -49,13 +50,13 @@ func (*structDefinition) FileMayMatch(ctx *may.FileContext) may.MatchType { } func (s *structDefinition) Matches(ctx context.AspectContext) bool { - if s.TypeName.pointer { + if s.TypeName.Pointer { // We can't ever match a pointer definition return false } spec, ok := ctx.Node().(*dst.TypeSpec) - if !ok || spec.Name == nil || spec.Name.Name != s.TypeName.name { + if !ok || spec.Name == nil || spec.Name.Name != s.TypeName.Name { return false } @@ -63,7 +64,7 @@ func (s *structDefinition) Matches(ctx context.AspectContext) bool { return false } - return ctx.ImportPath() == s.TypeName.path + return ctx.ImportPath() == s.TypeName.ImportPath } func (s *structDefinition) Hash(h *fingerprint.Hasher) error { @@ -73,7 +74,7 @@ func (s *structDefinition) Hash(h *fingerprint.Hasher) error { type ( StructLiteralMatch int structLiteral struct { - TypeName TypeName + TypeName typed.TypeName Field string Match StructLiteralMatch } @@ -93,7 +94,7 @@ const ( ) // StructLiteralField matches a specific field in struct literals of the designated type. -func StructLiteralField(typeName TypeName, field string) *structLiteral { +func StructLiteralField(typeName typed.TypeName, field string) *structLiteral { return &structLiteral{ TypeName: typeName, Field: field, @@ -102,7 +103,7 @@ func StructLiteralField(typeName TypeName, field string) *structLiteral { // StructLiteral matches struct literal expressions of the designated type, filtered by the // specified match type. -func StructLiteral(typeName TypeName, match StructLiteralMatch) *structLiteral { +func StructLiteral(typeName typed.TypeName, match StructLiteralMatch) *structLiteral { return &structLiteral{ TypeName: typeName, Match: match, @@ -110,14 +111,14 @@ func StructLiteral(typeName TypeName, match StructLiteralMatch) *structLiteral { } func (s *structLiteral) ImpliesImported() []string { - if path := s.TypeName.ImportPath(); path != "" { + if path := s.TypeName.ImportPath; path != "" { return []string{path} } return nil } func (s *structLiteral) PackageMayMatch(ctx *may.PackageContext) may.MatchType { - return ctx.PackageImports(s.TypeName.ImportPath()) + return ctx.PackageImports(s.TypeName.ImportPath) } func (*structLiteral) FileMayMatch(_ *may.FileContext) may.MatchType { @@ -185,11 +186,11 @@ func init() { return nil, err } - tn, err := NewTypeName(spec) + tn, err := typed.NewTypeName(spec) if err != nil { return nil, err } - if tn.pointer { + if tn.Pointer { return nil, fmt.Errorf("struct-definition type must not be a pointer (got %q)", spec) } @@ -205,7 +206,7 @@ func init() { return nil, err } - tn, err := NewTypeName(spec.Type) + tn, err := typed.NewTypeName(spec.Type) if err != nil { return nil, err } diff --git a/internal/injector/config/builtin.go b/internal/injector/config/builtin.go index b5edf3a8f..73dcb318e 100644 --- a/internal/injector/config/builtin.go +++ b/internal/injector/config/builtin.go @@ -11,6 +11,7 @@ import ( "github.com/DataDog/orchestrion/internal/injector/aspect/advice/code" "github.com/DataDog/orchestrion/internal/injector/aspect/context" "github.com/DataDog/orchestrion/internal/injector/aspect/join" + "github.com/DataDog/orchestrion/internal/injector/typed" ) var builtIn = configGo{ @@ -21,7 +22,7 @@ var builtIn = configGo{ ID: "built.WithOrchestrion", TracerInternal: true, // This is safe to apply in the tracer itself JoinPoint: join.AllOf( - join.ValueDeclaration(join.MustTypeName("bool")), + join.ValueDeclaration(typed.Bool), join.OneOf( join.DeclarationOf("github.com/DataDog/orchestrion/runtime/built", "WithOrchestrion"), join.Directive("orchestrion:enabled"), @@ -38,7 +39,7 @@ var builtIn = configGo{ ID: "built.WithOrchestrionVersion", TracerInternal: true, // This is safe to apply in the tracer itself JoinPoint: join.AllOf( - join.ValueDeclaration(join.MustTypeName("string")), + join.ValueDeclaration(typed.String), join.OneOf( join.DeclarationOf("github.com/DataDog/orchestrion/runtime/built", "WithOrchestrionVersion"), join.Directive("orchestrion:version"), diff --git a/internal/injector/typed/typename.go b/internal/injector/typed/typename.go new file mode 100644 index 000000000..56dbf2eb4 --- /dev/null +++ b/internal/injector/typed/typename.go @@ -0,0 +1,170 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2023-present Datadog, Inc. + +package typed + +import ( + "fmt" + "regexp" + + "github.com/dave/dst" + + "github.com/DataDog/orchestrion/internal/fingerprint" +) + +// Common built-in type definitions for convenience. +// These pre-defined TypeName instances help avoid repeated string literals +// and potential typos when referring to common Go built-in types. +var ( + // Basic types currently used in the codebase + Any = MustTypeName("any") + Bool = MustTypeName("bool") + String = MustTypeName("string") + // Uncomment these when we used. + // Byte = MustTypeName("byte") + // Int = MustTypeName("int") + // Error = MustTypeName("error") +) + +// TypeName represents a parsed Go type name, potentially including a package path and pointer indicator. +type TypeName struct { + // ImportPath is the import Path that provides the type, or an empty string if the + // type is local or built-in (like "error" or "any"). + ImportPath string + // Name is the leaf (un-qualified) name of the type. + Name string + // Pointer determines whether the specified type is a pointer or not. + Pointer bool +} + +// FIXME: this does not support all the type syntax, like: "chan Event" +// It primarily handles identifiers, qualified identifiers, and pointers to those. +var typeNameRe = regexp.MustCompile(`\A(\*)?\s*(?:([A-Za-z_][A-Za-z0-9_.-]+(?:/[A-Za-z_.-][A-Za-z0-9_.-]+)*)\.)?([A-Za-z_][A-Za-z0-9_]*)\z`) + +// NewTypeName parses a string representation of a type name into a TypeName struct. +// It returns an error if the syntax is invalid according to its limited regular expression. +func NewTypeName(n string) (tn TypeName, err error) { + matches := typeNameRe.FindStringSubmatch(n) + if matches == nil { + err = fmt.Errorf("invalid TypeName syntax: %q", n) + return tn, err + } + + tn.Pointer = matches[1] == "*" + tn.ImportPath = matches[2] + tn.Name = matches[3] + return tn, nil +} + +// MustTypeName is the same as NewTypeName, except it panics in case of an error. +func MustTypeName(n string) (tn TypeName) { + var err error + if tn, err = NewTypeName(n); err != nil { + panic(err) + } + return tn +} + +// Matches determines whether the provided AST expression node represents the same type +// as this TypeName. This performs a structural comparison based on the limited types +// supported by the parsing regex (identifiers, selectors, pointers, empty interface). +func (n TypeName) Matches(node dst.Expr) bool { + switch node := node.(type) { + case *dst.Ident: + return !n.Pointer && n.ImportPath == node.Path && n.Name == node.Name + + case *dst.SelectorExpr: + var path string + if ident, ok := node.X.(*dst.Ident); ok && ident.Path == "" { + path = ident.Name + } else { + return false + } + return !n.Pointer && n.ImportPath == path && n.Name == node.Sel.Name + + case *dst.StarExpr: + return n.Pointer && (&TypeName{ImportPath: n.ImportPath, Name: n.Name}).Matches(node.X) + + case *dst.IndexExpr: + // Handle generic types with single type parameter (e.g., MyType[T]) + return !n.Pointer && n.Matches(node.X) + + case *dst.IndexListExpr: + // Handle generic types with multiple type parameters (e.g., MyType[T, U]) + return !n.Pointer && n.Matches(node.X) + + case *dst.InterfaceType: + // We only match the empty interface (as "any") + if len(node.Methods.List) != 0 { + return false + } + return n.ImportPath == "" && n.Name == "any" + + default: + return false + } +} + +// MatchesDefinition determines whether the provided node matches the definition +// of this TypeName. The `importPath` argument determines the context in which +// the assertion is made. +func (n TypeName) MatchesDefinition(node dst.Expr, importPath string) bool { + if n.ImportPath != importPath { + return false + } + return (&TypeName{Name: n.Name, Pointer: n.Pointer}).Matches(node) +} + +// AsNode converts the TypeName back into a dst.Expr AST node. +// Useful for generating code that refers to this type. +func (n *TypeName) AsNode() dst.Expr { + ident := dst.NewIdent(n.Name) + ident.Path = n.ImportPath + if n.Pointer { + return &dst.StarExpr{X: ident} + } + return ident +} + +// Hash contributes the TypeName's properties to a fingerprint hasher. +func (n TypeName) Hash(h *fingerprint.Hasher) error { + return h.Named( + "type-name", + fingerprint.String(n.Name), + fingerprint.String(n.ImportPath), + fingerprint.Bool(n.Pointer), + ) +} + +// FindMatchingTypeName parses a type name string and searches a field list for the first field whose type matches. +// It returns the index of the matching field and whether a match was found. +// The index accounts for fields with multiple names. +func FindMatchingTypeName(fields *dst.FieldList, typeNameStr string) (index int, found bool) { + if fields == nil || len(fields.List) == 0 { + return -1, false + } + + tn, err := NewTypeName(typeNameStr) + if err != nil { + // If the type name string is invalid, we can't match it. + return -1, false + } + + currentIndex := 0 + for _, field := range fields.List { + if tn.Matches(field.Type) { + return currentIndex, true // Found a match. + } + + // Increment index by the number of names in the field (or 1 if unnamed). + count := len(field.Names) + if count == 0 { + count = 1 + } + currentIndex += count + } + + return -1, false // No match found +} diff --git a/internal/injector/aspect/join/join_test.go b/internal/injector/typed/typename_test.go similarity index 99% rename from internal/injector/aspect/join/join_test.go rename to internal/injector/typed/typename_test.go index 14a5bf619..9a849bcda 100644 --- a/internal/injector/aspect/join/join_test.go +++ b/internal/injector/typed/typename_test.go @@ -3,7 +3,7 @@ // This product includes software developed at Datadog (https://www.datadoghq.com/). // Copyright 2023-present Datadog, Inc. -package join +package typed import ( "errors"