Skip to content

Commit c56bf17

Browse files
committed
refactor(injector): Refactor TypeName handling and migrate to typed package
This commit refactors the handling of TypeName across the injector package, transitioning from the join package to a new typed package. The changes include: - Replacing instances of join.TypeName with typed.TypeName in various files, ensuring consistent type handling. - Updating related functions and methods to accommodate the new TypeName structure, including adjustments to import path retrieval and pointer handling. - Adding comprehensive tests for the new TypeName implementation to validate its functionality and error handling. This refactor enhances code clarity and maintainability by centralizing type name logic within the typed package. Signed-off-by: Kemal Akkoyun <kemal.akkoyun@datadoghq.com>
1 parent 5af6ebb commit c56bf17

12 files changed

Lines changed: 308 additions & 320 deletions

File tree

internal/injector/aspect/advice/call.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,20 @@ import (
1414
"github.com/DataDog/orchestrion/internal/fingerprint"
1515
"github.com/DataDog/orchestrion/internal/injector/aspect/advice/code"
1616
"github.com/DataDog/orchestrion/internal/injector/aspect/context"
17-
"github.com/DataDog/orchestrion/internal/injector/aspect/join"
17+
"github.com/DataDog/orchestrion/internal/injector/typed"
1818
"github.com/DataDog/orchestrion/internal/yaml"
1919
"github.com/dave/dst"
2020
"github.com/goccy/go-yaml/ast"
2121
)
2222

2323
type appendArgs struct {
24-
TypeName join.TypeName
24+
TypeName typed.TypeName
2525
Templates []*code.Template
2626
}
2727

2828
// AppendArgs appends arguments of a given type to the end of a function call. All arguments must be
2929
// of the same type, as they may be appended at the tail end of a variadic call.
30-
func AppendArgs(typeName join.TypeName, templates ...*code.Template) *appendArgs {
30+
func AppendArgs(typeName typed.TypeName, templates ...*code.Template) *appendArgs {
3131
return &appendArgs{typeName, templates}
3232
}
3333

@@ -92,7 +92,7 @@ func (a *appendArgs) Apply(ctx context.AdviceContext) (bool, error) {
9292
Ellipsis: true,
9393
}
9494

95-
if importPath := a.TypeName.ImportPath(); importPath != "" {
95+
if importPath := a.TypeName.ImportPath; importPath != "" {
9696
ctx.AddImport(importPath, inferPkgName(importPath))
9797
}
9898

@@ -101,7 +101,7 @@ func (a *appendArgs) Apply(ctx context.AdviceContext) (bool, error) {
101101

102102
func (a *appendArgs) AddedImports() []string {
103103
imports := make([]string, 0, len(a.Templates)+1)
104-
if argTypeImportPath := a.TypeName.ImportPath(); argTypeImportPath != "" {
104+
if argTypeImportPath := a.TypeName.ImportPath; argTypeImportPath != "" {
105105
imports = append(imports, argTypeImportPath)
106106
}
107107
for _, t := range a.Templates {
@@ -168,7 +168,7 @@ func init() {
168168
return nil, err
169169
}
170170

171-
tn, err := join.NewTypeName(args.TypeName)
171+
tn, err := typed.NewTypeName(args.TypeName)
172172
if err != nil {
173173
return nil, err
174174
}

internal/injector/aspect/advice/call_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,31 +11,31 @@ import (
1111
"github.com/DataDog/orchestrion/internal/injector/aspect/advice"
1212
"github.com/DataDog/orchestrion/internal/injector/aspect/advice/code"
1313
"github.com/DataDog/orchestrion/internal/injector/aspect/context"
14-
"github.com/DataDog/orchestrion/internal/injector/aspect/join"
14+
"github.com/DataDog/orchestrion/internal/injector/typed"
1515
"github.com/stretchr/testify/assert"
1616
"github.com/stretchr/testify/require"
1717
)
1818

1919
func TestAppendArgs(t *testing.T) {
2020
t.Run("AddedImports", func(t *testing.T) {
2121
type testCase struct {
22-
argType join.TypeName
22+
argType typed.TypeName
2323
args []*code.Template
2424
expectedImports []string
2525
}
2626

2727
testCases := map[string]testCase{
2828
"imports-none": {
29-
argType: join.MustTypeName("any"),
29+
argType: typed.MustTypeName("any"),
3030
args: []*code.Template{code.MustTemplate("true", nil, context.GoLangVersion{})},
3131
},
3232
"imports-from-arg-type": {
33-
argType: join.MustTypeName("*net/http.Request"),
33+
argType: typed.MustTypeName("*net/http.Request"),
3434
args: []*code.Template{code.MustTemplate("true", nil, context.GoLangVersion{})},
3535
expectedImports: []string{"net/http"},
3636
},
3737
"imports-from-templates": {
38-
argType: join.MustTypeName("any"),
38+
argType: typed.MustTypeName("any"),
3939
args: []*code.Template{
4040
code.MustTemplate("imp.Value", map[string]string{"imp": "github.com/namespace/foo"}, context.GoLangVersion{}),
4141
code.MustTemplate("imp.Value", map[string]string{"imp": "github.com/namespace/bar"}, context.GoLangVersion{}),

internal/injector/aspect/advice/code/dot_function.go

Lines changed: 44 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,10 @@ package code
88
import (
99
"errors"
1010
"fmt"
11-
"go/importer"
12-
"go/types"
13-
"strings"
1411

1512
"github.com/dave/dst"
1613

1714
"github.com/DataDog/orchestrion/internal/injector/aspect/context"
18-
"github.com/DataDog/orchestrion/internal/injector/aspect/join"
1915
"github.com/DataDog/orchestrion/internal/injector/typed"
2016
)
2117

@@ -162,6 +158,11 @@ func (s signature) ResultThatImplements(name string) (string, error) {
162158
return "", nil
163159
}
164160

161+
// Optimization: First, check for an exact match using the helper.
162+
if index, found := typed.FindMatchingTypeName(s.Results, name); found {
163+
return fieldAt(s.Results, index, "result")
164+
} // If not found, fall through to type resolution.
165+
165166
// Resolve the interface type.
166167
iface, err := typed.ResolveInterfaceTypeByName(name)
167168
if err != nil {
@@ -192,18 +193,38 @@ func (s signature) LastResultThatImplements(name string) (string, error) {
192193
return "", nil
193194
}
194195

196+
// Optimization: First, check for an exact match using TypeName parsing, finding the last one.
197+
lastMatchIndex := -1
198+
if tn, err := typed.NewTypeName(name); err == nil {
199+
currentIndex := 0
200+
for _, field := range s.Results.List {
201+
if tn.Matches(field.Type) {
202+
lastMatchIndex = currentIndex // Update last found index
203+
}
204+
// Increment index by the number of names in the field (or 1 if unnamed).
205+
count := len(field.Names)
206+
if count == 0 {
207+
count = 1
208+
}
209+
currentIndex += count
210+
}
211+
}
212+
// If we found a match via TypeName, return it.
213+
if lastMatchIndex != -1 {
214+
return fieldAt(s.Results, lastMatchIndex, "result")
215+
} // If parsing failed or no match, fall through to type resolution.
216+
195217
// Resolve the interface type.
196218
iface, err := typed.ResolveInterfaceTypeByName(name)
197219
if err != nil {
220+
// Propagate error if interface resolution fails
198221
return "", fmt.Errorf("resolving interface type %q: %w", name, err)
199222
}
200223

201-
// First, we need to build a map of result fields to their indices
202-
// that takes into account named and unnamed parameters.
203-
var (
204-
fieldIndices = make(map[*dst.Field]int)
205-
index = 0
206-
)
224+
// Fallback: Check using ExprImplements, iterating backward.
225+
// Need field indices map again for this path.
226+
fieldIndices := make(map[*dst.Field]int)
227+
index := 0
207228
for _, field := range s.Results.List {
208229
fieldIndices[field] = index
209230
count := len(field.Names)
@@ -213,7 +234,6 @@ func (s signature) LastResultThatImplements(name string) (string, error) {
213234
index += count
214235
}
215236

216-
// Loop backward through the results list.
217237
for i := len(s.Results.List) - 1; i >= 0; i-- {
218238
field := s.Results.List[i]
219239
if typed.ExprImplements(s.context, field.Type, iface) {
@@ -265,7 +285,7 @@ func fieldAt(fields *dst.FieldList, index int, use string) (string, error) {
265285
}
266286

267287
func fieldOfType(fields *dst.FieldList, typeName string, use string) (string, error) {
268-
tn, err := join.NewTypeName(typeName)
288+
tn, err := typed.NewTypeName(typeName)
269289
if err != nil {
270290
return "", err
271291
}
@@ -292,118 +312,27 @@ func fieldOfType(fields *dst.FieldList, typeName string, use string) (string, er
292312
return "", nil
293313
}
294314

295-
// exprImplements checks if an expression's type implements an interface.
296-
func exprImplements(ctx context.AdviceContext, expr dst.Expr, iface *types.Interface) bool {
297-
actualType := ctx.ResolveType(expr)
298-
if actualType == nil {
299-
return false
300-
}
301-
302-
return typeImplements(actualType, iface)
303-
}
304-
305-
// typeImplements checks if a type implements an interface (including pointer receivers).
306-
func typeImplements(t types.Type, iface *types.Interface) bool {
307-
if t == nil || iface == nil {
308-
return false
309-
}
310-
311-
// Direct implementation check.
312-
if types.Implements(t, iface) {
313-
return true
314-
}
315-
316-
return false
317-
}
318-
319-
// resolveInterfaceTypeByName takes an interface name as a string and resolves it to an interface type.
320-
// It supports built-in interfaces (e.g. "error"), package qualified interfaces (e.g. "io.Reader"),
321-
// and third-party package interfaces (e.g. "example.com/pkg.Interface").
322-
func resolveInterfaceTypeByName(name string) (*types.Interface, error) {
323-
// Handle built-in types.
324-
if obj := types.Universe.Lookup(name); obj != nil {
325-
typeObj, ok := obj.(*types.TypeName)
326-
if !ok {
327-
return nil, fmt.Errorf("object %s is not a type name but a %T", name, obj)
328-
}
329-
330-
typ := typeObj.Type()
331-
if !types.IsInterface(typ) {
332-
return nil, fmt.Errorf("type %s is not an interface", name)
333-
}
334-
335-
t, ok := typ.Underlying().(*types.Interface)
336-
if !ok {
337-
return nil, fmt.Errorf("type %s is not an interface", name)
338-
}
339-
340-
return t, nil
341-
}
342-
343-
// Handle package-qualified types (e.g., "io.Writer").
344-
pkgName, typeName := splitPackageAndName(name)
345-
if pkgName == "" {
346-
return nil, fmt.Errorf("invalid type name: %s", name)
347-
}
348-
349-
// Import the package
350-
imp := importer.Default()
351-
pkg, err := imp.Import(pkgName)
352-
if err != nil {
353-
return nil, fmt.Errorf("failed to import package %q: %w", pkgName, err)
354-
}
355-
356-
// Look up the type in the package's scope
357-
obj := pkg.Scope().Lookup(typeName)
358-
if obj == nil {
359-
return nil, fmt.Errorf("type %q not found in package %q", typeName, pkgName)
360-
}
361-
362-
typeObj, ok := obj.(*types.TypeName)
363-
if !ok {
364-
return nil, fmt.Errorf("object %s is not a type name but a %T", name, obj)
365-
}
366-
367-
typ := typeObj.Type()
368-
if !types.IsInterface(typ) {
369-
return nil, fmt.Errorf("type %s is not an interface", name)
370-
}
371-
372-
t, ok := typ.Underlying().(*types.Interface)
373-
if !ok {
374-
return nil, fmt.Errorf("type %s is not an interface", name)
375-
}
376-
377-
return t, nil
378-
}
379-
380-
// splitPackageAndName splits a fully qualified type name like "io.Reader" or "example.com/pkg.Type"
381-
// into its package path and local name.
382-
// Returns ("", "error") for built-in "error".
383-
// Returns ("", "MyType") for unqualified "MyType".
384-
func splitPackageAndName(fullName string) (pkgPath string, localName string) {
385-
if !strings.Contains(fullName, ".") {
386-
// Assume built-in type (like "error") or unqualified local type.
387-
return "", fullName
388-
}
389-
lastDot := strings.LastIndex(fullName, ".")
390-
pkgPath = fullName[:lastDot]
391-
localName = fullName[lastDot+1:]
392-
return pkgPath, localName
393-
}
394-
395315
// FinalResultImplements returns whether the final result implements the provided interface type.
396316
func (s signature) FinalResultImplements(interfaceName string) (bool, error) {
397317
if s.Results == nil || len(s.Results.List) == 0 {
398318
return false, nil
399319
}
400320

401-
iface, err := resolveInterfaceTypeByName(interfaceName)
321+
lastField := s.Results.List[len(s.Results.List)-1]
322+
323+
// Optimization: First, check for an exact match using TypeName parsing.
324+
// Note: Not using FindMatchingTypeName as we only need to check the last field.
325+
if tn, err := typed.NewTypeName(interfaceName); err == nil {
326+
if tn.Matches(lastField.Type) {
327+
return true, nil
328+
}
329+
} // If parsing failed or no match, fall through to type resolution.
330+
331+
iface, err := typed.ResolveInterfaceTypeByName(interfaceName)
402332
if err != nil {
403333
return false, fmt.Errorf("resolving interface type %q: %w", interfaceName, err)
404334
}
405335

406336
// Check if the last field type implements the interface.
407-
lastField := s.Results.List[len(s.Results.List)-1]
408-
return exprImplements(s.context, lastField.Type, iface), nil
337+
return typed.ExprImplements(s.context, lastField.Type, iface), nil
409338
}

internal/injector/aspect/advice/struct.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,19 @@ import (
1111

1212
"github.com/DataDog/orchestrion/internal/fingerprint"
1313
"github.com/DataDog/orchestrion/internal/injector/aspect/context"
14-
"github.com/DataDog/orchestrion/internal/injector/aspect/join"
14+
"github.com/DataDog/orchestrion/internal/injector/typed"
1515
"github.com/DataDog/orchestrion/internal/yaml"
1616
"github.com/dave/dst"
1717
"github.com/goccy/go-yaml/ast"
1818
)
1919

2020
type addStructField struct {
2121
Name string
22-
TypeName join.TypeName
22+
TypeName typed.TypeName
2323
}
2424

2525
// AddStructField adds a new synthetic field at the tail end of a struct declaration.
26-
func AddStructField(fieldName string, fieldType join.TypeName) *addStructField {
26+
func AddStructField(fieldName string, fieldType typed.TypeName) *addStructField {
2727
return &addStructField{fieldName, fieldType}
2828
}
2929

@@ -47,7 +47,7 @@ func (a *addStructField) Apply(ctx context.AdviceContext) (bool, error) {
4747
Type: a.TypeName.AsNode(),
4848
})
4949

50-
if importPath := a.TypeName.ImportPath(); importPath != "" {
50+
if importPath := a.TypeName.ImportPath; importPath != "" {
5151
// If the type name is qualified, we may need to import the package, too.
5252
_ = ctx.AddImport(importPath, inferPkgName(importPath))
5353
}
@@ -60,7 +60,7 @@ func (a *addStructField) Hash(h *fingerprint.Hasher) error {
6060
}
6161

6262
func (a *addStructField) AddedImports() []string {
63-
if path := a.TypeName.ImportPath(); path != "" {
63+
if path := a.TypeName.ImportPath; path != "" {
6464
return []string{path}
6565
}
6666
return nil
@@ -76,7 +76,7 @@ func init() {
7676
if err := yaml.NodeToValueContext(ctx, node, &spec); err != nil {
7777
return nil, err
7878
}
79-
tn, err := join.NewTypeName(spec.Type)
79+
tn, err := typed.NewTypeName(spec.Type)
8080
if err != nil {
8181
return nil, err
8282
}

0 commit comments

Comments
 (0)