@@ -8,14 +8,10 @@ package code
88import (
99 "errors"
1010 "fmt"
11- "go/importer"
12- "go/types"
13- "strings"
1411
1512 "github.com/dave/dst"
1613
1714 "github.com/DataDog/orchestrion/internal/injector/aspect/context"
18- "github.com/DataDog/orchestrion/internal/injector/aspect/join"
1915 "github.com/DataDog/orchestrion/internal/injector/typed"
2016)
2117
@@ -162,6 +158,11 @@ func (s signature) ResultThatImplements(name string) (string, error) {
162158 return "" , nil
163159 }
164160
161+ // Optimization: First, check for an exact match using the helper.
162+ if index , found := typed .FindMatchingTypeName (s .Results , name ); found {
163+ return fieldAt (s .Results , index , "result" )
164+ } // If not found, fall through to type resolution.
165+
165166 // Resolve the interface type.
166167 iface , err := typed .ResolveInterfaceTypeByName (name )
167168 if err != nil {
@@ -192,18 +193,38 @@ func (s signature) LastResultThatImplements(name string) (string, error) {
192193 return "" , nil
193194 }
194195
196+ // Optimization: First, check for an exact match using TypeName parsing, finding the last one.
197+ lastMatchIndex := - 1
198+ if tn , err := typed .NewTypeName (name ); err == nil {
199+ currentIndex := 0
200+ for _ , field := range s .Results .List {
201+ if tn .Matches (field .Type ) {
202+ lastMatchIndex = currentIndex // Update last found index
203+ }
204+ // Increment index by the number of names in the field (or 1 if unnamed).
205+ count := len (field .Names )
206+ if count == 0 {
207+ count = 1
208+ }
209+ currentIndex += count
210+ }
211+ }
212+ // If we found a match via TypeName, return it.
213+ if lastMatchIndex != - 1 {
214+ return fieldAt (s .Results , lastMatchIndex , "result" )
215+ } // If parsing failed or no match, fall through to type resolution.
216+
195217 // Resolve the interface type.
196218 iface , err := typed .ResolveInterfaceTypeByName (name )
197219 if err != nil {
220+ // Propagate error if interface resolution fails
198221 return "" , fmt .Errorf ("resolving interface type %q: %w" , name , err )
199222 }
200223
201- // First, we need to build a map of result fields to their indices
202- // that takes into account named and unnamed parameters.
203- var (
204- fieldIndices = make (map [* dst.Field ]int )
205- index = 0
206- )
224+ // Fallback: Check using ExprImplements, iterating backward.
225+ // Need field indices map again for this path.
226+ fieldIndices := make (map [* dst.Field ]int )
227+ index := 0
207228 for _ , field := range s .Results .List {
208229 fieldIndices [field ] = index
209230 count := len (field .Names )
@@ -213,7 +234,6 @@ func (s signature) LastResultThatImplements(name string) (string, error) {
213234 index += count
214235 }
215236
216- // Loop backward through the results list.
217237 for i := len (s .Results .List ) - 1 ; i >= 0 ; i -- {
218238 field := s .Results .List [i ]
219239 if typed .ExprImplements (s .context , field .Type , iface ) {
@@ -265,7 +285,7 @@ func fieldAt(fields *dst.FieldList, index int, use string) (string, error) {
265285}
266286
267287func fieldOfType (fields * dst.FieldList , typeName string , use string ) (string , error ) {
268- tn , err := join .NewTypeName (typeName )
288+ tn , err := typed .NewTypeName (typeName )
269289 if err != nil {
270290 return "" , err
271291 }
@@ -292,118 +312,27 @@ func fieldOfType(fields *dst.FieldList, typeName string, use string) (string, er
292312 return "" , nil
293313}
294314
295- // exprImplements checks if an expression's type implements an interface.
296- func exprImplements (ctx context.AdviceContext , expr dst.Expr , iface * types.Interface ) bool {
297- actualType := ctx .ResolveType (expr )
298- if actualType == nil {
299- return false
300- }
301-
302- return typeImplements (actualType , iface )
303- }
304-
305- // typeImplements checks if a type implements an interface (including pointer receivers).
306- func typeImplements (t types.Type , iface * types.Interface ) bool {
307- if t == nil || iface == nil {
308- return false
309- }
310-
311- // Direct implementation check.
312- if types .Implements (t , iface ) {
313- return true
314- }
315-
316- return false
317- }
318-
319- // resolveInterfaceTypeByName takes an interface name as a string and resolves it to an interface type.
320- // It supports built-in interfaces (e.g. "error"), package qualified interfaces (e.g. "io.Reader"),
321- // and third-party package interfaces (e.g. "example.com/pkg.Interface").
322- func resolveInterfaceTypeByName (name string ) (* types.Interface , error ) {
323- // Handle built-in types.
324- if obj := types .Universe .Lookup (name ); obj != nil {
325- typeObj , ok := obj .(* types.TypeName )
326- if ! ok {
327- return nil , fmt .Errorf ("object %s is not a type name but a %T" , name , obj )
328- }
329-
330- typ := typeObj .Type ()
331- if ! types .IsInterface (typ ) {
332- return nil , fmt .Errorf ("type %s is not an interface" , name )
333- }
334-
335- t , ok := typ .Underlying ().(* types.Interface )
336- if ! ok {
337- return nil , fmt .Errorf ("type %s is not an interface" , name )
338- }
339-
340- return t , nil
341- }
342-
343- // Handle package-qualified types (e.g., "io.Writer").
344- pkgName , typeName := splitPackageAndName (name )
345- if pkgName == "" {
346- return nil , fmt .Errorf ("invalid type name: %s" , name )
347- }
348-
349- // Import the package
350- imp := importer .Default ()
351- pkg , err := imp .Import (pkgName )
352- if err != nil {
353- return nil , fmt .Errorf ("failed to import package %q: %w" , pkgName , err )
354- }
355-
356- // Look up the type in the package's scope
357- obj := pkg .Scope ().Lookup (typeName )
358- if obj == nil {
359- return nil , fmt .Errorf ("type %q not found in package %q" , typeName , pkgName )
360- }
361-
362- typeObj , ok := obj .(* types.TypeName )
363- if ! ok {
364- return nil , fmt .Errorf ("object %s is not a type name but a %T" , name , obj )
365- }
366-
367- typ := typeObj .Type ()
368- if ! types .IsInterface (typ ) {
369- return nil , fmt .Errorf ("type %s is not an interface" , name )
370- }
371-
372- t , ok := typ .Underlying ().(* types.Interface )
373- if ! ok {
374- return nil , fmt .Errorf ("type %s is not an interface" , name )
375- }
376-
377- return t , nil
378- }
379-
380- // splitPackageAndName splits a fully qualified type name like "io.Reader" or "example.com/pkg.Type"
381- // into its package path and local name.
382- // Returns ("", "error") for built-in "error".
383- // Returns ("", "MyType") for unqualified "MyType".
384- func splitPackageAndName (fullName string ) (pkgPath string , localName string ) {
385- if ! strings .Contains (fullName , "." ) {
386- // Assume built-in type (like "error") or unqualified local type.
387- return "" , fullName
388- }
389- lastDot := strings .LastIndex (fullName , "." )
390- pkgPath = fullName [:lastDot ]
391- localName = fullName [lastDot + 1 :]
392- return pkgPath , localName
393- }
394-
395315// FinalResultImplements returns whether the final result implements the provided interface type.
396316func (s signature ) FinalResultImplements (interfaceName string ) (bool , error ) {
397317 if s .Results == nil || len (s .Results .List ) == 0 {
398318 return false , nil
399319 }
400320
401- iface , err := resolveInterfaceTypeByName (interfaceName )
321+ lastField := s .Results .List [len (s .Results .List )- 1 ]
322+
323+ // Optimization: First, check for an exact match using TypeName parsing.
324+ // Note: Not using FindMatchingTypeName as we only need to check the last field.
325+ if tn , err := typed .NewTypeName (interfaceName ); err == nil {
326+ if tn .Matches (lastField .Type ) {
327+ return true , nil
328+ }
329+ } // If parsing failed or no match, fall through to type resolution.
330+
331+ iface , err := typed .ResolveInterfaceTypeByName (interfaceName )
402332 if err != nil {
403333 return false , fmt .Errorf ("resolving interface type %q: %w" , interfaceName , err )
404334 }
405335
406336 // Check if the last field type implements the interface.
407- lastField := s .Results .List [len (s .Results .List )- 1 ]
408- return exprImplements (s .context , lastField .Type , iface ), nil
337+ return typed .ExprImplements (s .context , lastField .Type , iface ), nil
409338}
0 commit comments