Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions _docs/generator/template-funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ import (
"strings"
"unicode"

"golang.org/x/tools/go/packages"

"github.com/DataDog/orchestrion/internal/injector/aspect/advice"
"github.com/DataDog/orchestrion/internal/injector/aspect/advice/code"
"github.com/DataDog/orchestrion/internal/injector/aspect/join"
"golang.org/x/tools/go/packages"
"github.com/DataDog/orchestrion/internal/injector/typed"
)

var (
Expand Down Expand Up @@ -65,7 +67,7 @@ func render(val any) (template.HTML, error) {

templateName := "doc."
switch val := val.(type) {
case join.Point, join.TypeName, join.FunctionOption:
case join.Point, typed.TypeName, join.FunctionOption:
templateName += "join"
case advice.Advice:
templateName += "advice"
Expand Down
12 changes: 6 additions & 6 deletions internal/injector/aspect/advice/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,20 @@ import (
"github.com/DataDog/orchestrion/internal/fingerprint"
"github.com/DataDog/orchestrion/internal/injector/aspect/advice/code"
"github.com/DataDog/orchestrion/internal/injector/aspect/context"
"github.com/DataDog/orchestrion/internal/injector/aspect/join"
"github.com/DataDog/orchestrion/internal/injector/typed"
"github.com/DataDog/orchestrion/internal/yaml"
"github.com/dave/dst"
"github.com/goccy/go-yaml/ast"
)

type appendArgs struct {
TypeName join.TypeName
TypeName typed.TypeName
Templates []*code.Template
}

// AppendArgs appends arguments of a given type to the end of a function call. All arguments must be
// of the same type, as they may be appended at the tail end of a variadic call.
func AppendArgs(typeName join.TypeName, templates ...*code.Template) *appendArgs {
func AppendArgs(typeName typed.TypeName, templates ...*code.Template) *appendArgs {
return &appendArgs{typeName, templates}
}

Expand Down Expand Up @@ -92,7 +92,7 @@ func (a *appendArgs) Apply(ctx context.AdviceContext) (bool, error) {
Ellipsis: true,
}

if importPath := a.TypeName.ImportPath(); importPath != "" {
if importPath := a.TypeName.ImportPath; importPath != "" {
ctx.AddImport(importPath, inferPkgName(importPath))
}

Expand All @@ -101,7 +101,7 @@ func (a *appendArgs) Apply(ctx context.AdviceContext) (bool, error) {

func (a *appendArgs) AddedImports() []string {
imports := make([]string, 0, len(a.Templates)+1)
if argTypeImportPath := a.TypeName.ImportPath(); argTypeImportPath != "" {
if argTypeImportPath := a.TypeName.ImportPath; argTypeImportPath != "" {
imports = append(imports, argTypeImportPath)
}
for _, t := range a.Templates {
Expand Down Expand Up @@ -168,7 +168,7 @@ func init() {
return nil, err
}

tn, err := join.NewTypeName(args.TypeName)
tn, err := typed.NewTypeName(args.TypeName)
if err != nil {
return nil, err
}
Expand Down
10 changes: 5 additions & 5 deletions internal/injector/aspect/advice/call_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,31 @@ import (
"github.com/DataDog/orchestrion/internal/injector/aspect/advice"
"github.com/DataDog/orchestrion/internal/injector/aspect/advice/code"
"github.com/DataDog/orchestrion/internal/injector/aspect/context"
"github.com/DataDog/orchestrion/internal/injector/aspect/join"
"github.com/DataDog/orchestrion/internal/injector/typed"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestAppendArgs(t *testing.T) {
t.Run("AddedImports", func(t *testing.T) {
type testCase struct {
argType join.TypeName
argType typed.TypeName
args []*code.Template
expectedImports []string
}

testCases := map[string]testCase{
"imports-none": {
argType: join.MustTypeName("any"),
argType: typed.Any,
args: []*code.Template{code.MustTemplate("true", nil, context.GoLangVersion{})},
},
"imports-from-arg-type": {
argType: join.MustTypeName("*net/http.Request"),
argType: typed.MustTypeName("*net/http.Request"),
args: []*code.Template{code.MustTemplate("true", nil, context.GoLangVersion{})},
expectedImports: []string{"net/http"},
},
"imports-from-templates": {
argType: join.MustTypeName("any"),
argType: typed.Any,
args: []*code.Template{
code.MustTemplate("imp.Value", map[string]string{"imp": "github.com/namespace/foo"}, context.GoLangVersion{}),
code.MustTemplate("imp.Value", map[string]string{"imp": "github.com/namespace/bar"}, context.GoLangVersion{}),
Expand Down
159 changes: 44 additions & 115 deletions internal/injector/aspect/advice/code/dot_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,8 @@
import (
"errors"
"fmt"
"go/importer"
"go/types"
"strings"

"github.com/DataDog/orchestrion/internal/injector/aspect/context"
"github.com/DataDog/orchestrion/internal/injector/aspect/join"
"github.com/DataDog/orchestrion/internal/injector/typed"
"github.com/dave/dst"
)
Expand Down Expand Up @@ -162,6 +158,11 @@
return "", nil
}

// Optimization: First, check for an exact match using the helper.
if index, found := typed.FindMatchingTypeName(s.Results, name); found {
return fieldAt(s.Results, index, "result")
} // If not found, fall through to type resolution.

// Resolve the interface type.
iface, err := typed.ResolveInterfaceTypeByName(name)
if err != nil {
Expand Down Expand Up @@ -192,18 +193,38 @@
return "", nil
}

// Optimization: First, check for an exact match using TypeName parsing, finding the last one.
lastMatchIndex := -1
if tn, err := typed.NewTypeName(name); err == nil {
currentIndex := 0
for _, field := range s.Results.List {
if tn.Matches(field.Type) {
lastMatchIndex = currentIndex // Update last found index
}
// Increment index by the number of names in the field (or 1 if unnamed).
count := len(field.Names)
if count == 0 {
count = 1
}
currentIndex += count
}
}
// If we found a match via TypeName, return it.
if lastMatchIndex != -1 {
return fieldAt(s.Results, lastMatchIndex, "result")
} // If parsing failed or no match, fall through to type resolution.

// Resolve the interface type.
iface, err := typed.ResolveInterfaceTypeByName(name)
if err != nil {
// Propagate error if interface resolution fails

Check warning on line 220 in internal/injector/aspect/advice/code/dot_function.go

View check run for this annotation

Codecov / codecov/patch

internal/injector/aspect/advice/code/dot_function.go#L220

Added line #L220 was not covered by tests
return "", fmt.Errorf("resolving interface type %q: %w", name, err)
}

// First, we need to build a map of result fields to their indices
// that takes into account named and unnamed parameters.
var (
fieldIndices = make(map[*dst.Field]int)
index = 0
)
// Fallback: Check using ExprImplements, iterating backward.
// Need field indices map again for this path.
fieldIndices := make(map[*dst.Field]int)
index := 0

Check warning on line 227 in internal/injector/aspect/advice/code/dot_function.go

View check run for this annotation

Codecov / codecov/patch

internal/injector/aspect/advice/code/dot_function.go#L226-L227

Added lines #L226 - L227 were not covered by tests
for _, field := range s.Results.List {
fieldIndices[field] = index
count := len(field.Names)
Expand All @@ -213,7 +234,6 @@
index += count
}

// Loop backward through the results list.
for i := len(s.Results.List) - 1; i >= 0; i-- {
field := s.Results.List[i]
if typed.ExprImplements(s.context, field.Type, iface) {
Expand Down Expand Up @@ -265,7 +285,7 @@
}

func fieldOfType(fields *dst.FieldList, typeName string, use string) (string, error) {
tn, err := join.NewTypeName(typeName)
tn, err := typed.NewTypeName(typeName)
if err != nil {
return "", err
}
Expand All @@ -292,118 +312,27 @@
return "", nil
}

// exprImplements checks if an expression's type implements an interface.
func exprImplements(ctx context.AdviceContext, expr dst.Expr, iface *types.Interface) bool {
actualType := ctx.ResolveType(expr)
if actualType == nil {
return false
}

return typeImplements(actualType, iface)
}

// typeImplements checks if a type implements an interface (including pointer receivers).
func typeImplements(t types.Type, iface *types.Interface) bool {
if t == nil || iface == nil {
return false
}

// Direct implementation check.
if types.Implements(t, iface) {
return true
}

return false
}

// resolveInterfaceTypeByName takes an interface name as a string and resolves it to an interface type.
// It supports built-in interfaces (e.g. "error"), package qualified interfaces (e.g. "io.Reader"),
// and third-party package interfaces (e.g. "example.com/pkg.Interface").
func resolveInterfaceTypeByName(name string) (*types.Interface, error) {
// Handle built-in types.
if obj := types.Universe.Lookup(name); obj != nil {
typeObj, ok := obj.(*types.TypeName)
if !ok {
return nil, fmt.Errorf("object %s is not a type name but a %T", name, obj)
}

typ := typeObj.Type()
if !types.IsInterface(typ) {
return nil, fmt.Errorf("type %s is not an interface", name)
}

t, ok := typ.Underlying().(*types.Interface)
if !ok {
return nil, fmt.Errorf("type %s is not an interface", name)
}

return t, nil
}

// Handle package-qualified types (e.g., "io.Writer").
pkgName, typeName := splitPackageAndName(name)
if pkgName == "" {
return nil, fmt.Errorf("invalid type name: %s", name)
}

// Import the package
imp := importer.Default()
pkg, err := imp.Import(pkgName)
if err != nil {
return nil, fmt.Errorf("failed to import package %q: %w", pkgName, err)
}

// Look up the type in the package's scope
obj := pkg.Scope().Lookup(typeName)
if obj == nil {
return nil, fmt.Errorf("type %q not found in package %q", typeName, pkgName)
}

typeObj, ok := obj.(*types.TypeName)
if !ok {
return nil, fmt.Errorf("object %s is not a type name but a %T", name, obj)
}

typ := typeObj.Type()
if !types.IsInterface(typ) {
return nil, fmt.Errorf("type %s is not an interface", name)
}

t, ok := typ.Underlying().(*types.Interface)
if !ok {
return nil, fmt.Errorf("type %s is not an interface", name)
}

return t, nil
}

// splitPackageAndName splits a fully qualified type name like "io.Reader" or "example.com/pkg.Type"
// into its package path and local name.
// Returns ("", "error") for built-in "error".
// Returns ("", "MyType") for unqualified "MyType".
func splitPackageAndName(fullName string) (pkgPath string, localName string) {
if !strings.Contains(fullName, ".") {
// Assume built-in type (like "error") or unqualified local type.
return "", fullName
}
lastDot := strings.LastIndex(fullName, ".")
pkgPath = fullName[:lastDot]
localName = fullName[lastDot+1:]
return pkgPath, localName
}

// FinalResultImplements returns whether the final result implements the provided interface type.
func (s signature) FinalResultImplements(interfaceName string) (bool, error) {
if s.Results == nil || len(s.Results.List) == 0 {
return false, nil
}

iface, err := resolveInterfaceTypeByName(interfaceName)
lastField := s.Results.List[len(s.Results.List)-1]

// Optimization: First, check for an exact match using TypeName parsing.
// Note: Not using FindMatchingTypeName as we only need to check the last field.
if tn, err := typed.NewTypeName(interfaceName); err == nil {
if tn.Matches(lastField.Type) {
return true, nil
}
} // If parsing failed or no match, fall through to type resolution.

iface, err := typed.ResolveInterfaceTypeByName(interfaceName)
if err != nil {
return false, fmt.Errorf("resolving interface type %q: %w", interfaceName, err)
}

// Check if the last field type implements the interface.
lastField := s.Results.List[len(s.Results.List)-1]
return exprImplements(s.context, lastField.Type, iface), nil
return typed.ExprImplements(s.context, lastField.Type, iface), nil
}
12 changes: 6 additions & 6 deletions internal/injector/aspect/advice/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,19 @@ import (

"github.com/DataDog/orchestrion/internal/fingerprint"
"github.com/DataDog/orchestrion/internal/injector/aspect/context"
"github.com/DataDog/orchestrion/internal/injector/aspect/join"
"github.com/DataDog/orchestrion/internal/injector/typed"
"github.com/DataDog/orchestrion/internal/yaml"
"github.com/dave/dst"
"github.com/goccy/go-yaml/ast"
)

type addStructField struct {
Name string
TypeName join.TypeName
TypeName typed.TypeName
}

// AddStructField adds a new synthetic field at the tail end of a struct declaration.
func AddStructField(fieldName string, fieldType join.TypeName) *addStructField {
func AddStructField(fieldName string, fieldType typed.TypeName) *addStructField {
return &addStructField{fieldName, fieldType}
}

Expand All @@ -47,7 +47,7 @@ func (a *addStructField) Apply(ctx context.AdviceContext) (bool, error) {
Type: a.TypeName.AsNode(),
})

if importPath := a.TypeName.ImportPath(); importPath != "" {
if importPath := a.TypeName.ImportPath; importPath != "" {
// If the type name is qualified, we may need to import the package, too.
_ = ctx.AddImport(importPath, inferPkgName(importPath))
}
Expand All @@ -60,7 +60,7 @@ func (a *addStructField) Hash(h *fingerprint.Hasher) error {
}

func (a *addStructField) AddedImports() []string {
if path := a.TypeName.ImportPath(); path != "" {
if path := a.TypeName.ImportPath; path != "" {
return []string{path}
}
return nil
Expand All @@ -76,7 +76,7 @@ func init() {
if err := yaml.NodeToValueContext(ctx, node, &spec); err != nil {
return nil, err
}
tn, err := join.NewTypeName(spec.Type)
tn, err := typed.NewTypeName(spec.Type)
if err != nil {
return nil, err
}
Expand Down
Loading
Loading