Skip to content

Commit 920342e

Browse files
authored
feat(internal/injector): add method-call join point (#829)
Add new `method-call` join point that matches calls by the receiver's resolved type. Example: ```yaml join-point: method-call: receiver: "go.uber.org/zap.Logger" # fully qualified type name (package import path + type name) without a pointer sigil. name: Info # the method to match. match: any # possible values are: any (default), pointer-only, or value-only. ``` Needed for DataDog/dd-trace-go#4729
1 parent 68a4149 commit 920342e

8 files changed

Lines changed: 607 additions & 7 deletions

File tree

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
// Unless explicitly stated otherwise all files in this repository are licensed
2+
// under the Apache License Version 2.0.
3+
// This product includes software developed at Datadog (https://www.datadoghq.com/).
4+
// Copyright 2023-present Datadog, Inc.
5+
6+
package join
7+
8+
import (
9+
gocontext "context"
10+
"errors"
11+
"fmt"
12+
"go/types"
13+
14+
"github.com/DataDog/orchestrion/internal/fingerprint"
15+
"github.com/DataDog/orchestrion/internal/injector/aspect/context"
16+
"github.com/DataDog/orchestrion/internal/injector/aspect/may"
17+
"github.com/DataDog/orchestrion/internal/injector/typed"
18+
"github.com/DataDog/orchestrion/internal/yaml"
19+
"github.com/dave/dst"
20+
"github.com/goccy/go-yaml/ast"
21+
)
22+
23+
type (
24+
MethodCallMatch int
25+
26+
methodCall struct {
27+
Receiver typed.TypeName
28+
Name string
29+
Match MethodCallMatch
30+
}
31+
)
32+
33+
const (
34+
// MethodCallMatchAny matches calls regardless of whether the receiver is a pointer or value. This is the default.
35+
MethodCallMatchAny MethodCallMatch = iota
36+
// MethodCallMatchPointerOnly matches only calls where the receiver is a pointer type.
37+
MethodCallMatchPointerOnly
38+
// MethodCallMatchValueOnly matches only calls where the receiver is a value type.
39+
MethodCallMatchValueOnly
40+
)
41+
42+
func MethodCall(receiver typed.TypeName, name string, match MethodCallMatch) *methodCall {
43+
return &methodCall{Receiver: receiver, Name: name, Match: match}
44+
}
45+
46+
func (m *methodCall) ImpliesImported() []string {
47+
if path := m.Receiver.ImportPath; path != "" {
48+
return []string{path}
49+
}
50+
return nil
51+
}
52+
53+
func (m *methodCall) PackageMayMatch(ctx *may.PackageContext) may.MatchType {
54+
return ctx.PackageImports(m.Receiver.ImportPath)
55+
}
56+
57+
func (m *methodCall) FileMayMatch(ctx *may.FileContext) may.MatchType {
58+
return ctx.FileContains(m.Name)
59+
}
60+
61+
func (m *methodCall) Matches(ctx context.AspectContext) bool {
62+
call, ok := ctx.Node().(*dst.CallExpr)
63+
if !ok {
64+
return false
65+
}
66+
67+
selector, ok := call.Fun.(*dst.SelectorExpr)
68+
if !ok || selector.Sel.Name != m.Name {
69+
return false
70+
}
71+
72+
recvType := ctx.ResolveType(selector.X)
73+
return m.matchesType(recvType)
74+
}
75+
76+
func (m *methodCall) matchesType(t types.Type) bool {
77+
if t == nil {
78+
return false
79+
}
80+
81+
switch m.Match {
82+
case MethodCallMatchPointerOnly:
83+
ptr, ok := t.(*types.Pointer)
84+
if !ok {
85+
return false
86+
}
87+
return m.matchesNamed(ptr.Elem())
88+
case MethodCallMatchValueOnly:
89+
return m.matchesNamed(t)
90+
default: // MethodCallMatchAny
91+
if ptr, ok := t.(*types.Pointer); ok {
92+
t = ptr.Elem()
93+
}
94+
return m.matchesNamed(t)
95+
}
96+
}
97+
98+
func (m *methodCall) matchesNamed(t types.Type) bool {
99+
named, ok := t.(*types.Named)
100+
if !ok {
101+
return false
102+
}
103+
obj := named.Obj()
104+
return obj.Pkg() != nil &&
105+
obj.Pkg().Path() == m.Receiver.ImportPath &&
106+
obj.Name() == m.Receiver.Name
107+
}
108+
109+
func (m *methodCall) Hash(h *fingerprint.Hasher) error {
110+
return h.Named("method-call", m.Receiver, fingerprint.String(m.Name), m.Match)
111+
}
112+
113+
func init() {
114+
unmarshalers["method-call"] = func(ctx gocontext.Context, node ast.Node) (Point, error) {
115+
var spec struct {
116+
Receiver string `yaml:"receiver"`
117+
Name string `yaml:"name"`
118+
Match MethodCallMatch `yaml:"match"`
119+
}
120+
if err := yaml.NodeToValueContext(ctx, node, &spec); err != nil {
121+
return nil, err
122+
}
123+
124+
if spec.Receiver == "" {
125+
return nil, errors.New("method-call: missing required field 'receiver'")
126+
}
127+
if spec.Name == "" {
128+
return nil, errors.New("method-call: missing required field 'name'")
129+
}
130+
131+
tn, err := typed.NewTypeName(spec.Receiver)
132+
if err != nil {
133+
return nil, fmt.Errorf("method-call: invalid receiver type %q: %w", spec.Receiver, err)
134+
}
135+
if tn.Pointer {
136+
return nil, fmt.Errorf("method-call: receiver type must not include a pointer sigil (use match: pointer-only instead): %q", spec.Receiver)
137+
}
138+
139+
return MethodCall(tn, spec.Name, spec.Match), nil
140+
}
141+
}
142+
143+
var _ yaml.NodeUnmarshalerContext = (*MethodCallMatch)(nil)
144+
145+
func (m *MethodCallMatch) UnmarshalYAML(ctx gocontext.Context, node ast.Node) error {
146+
var name string
147+
if err := yaml.NodeToValueContext(ctx, node, &name); err != nil {
148+
return err
149+
}
150+
151+
switch name {
152+
case "any", "":
153+
*m = MethodCallMatchAny
154+
case "pointer-only":
155+
*m = MethodCallMatchPointerOnly
156+
case "value-only":
157+
*m = MethodCallMatchValueOnly
158+
default:
159+
return fmt.Errorf("invalid method-call.match value: %q", name)
160+
}
161+
162+
return nil
163+
}
164+
165+
func (m MethodCallMatch) String() string {
166+
switch m {
167+
case MethodCallMatchAny:
168+
return "any"
169+
case MethodCallMatchPointerOnly:
170+
return "pointer-only"
171+
case MethodCallMatchValueOnly:
172+
return "value-only"
173+
default:
174+
panic(fmt.Errorf("invalid MethodCallMatch(%d)", int(m)))
175+
}
176+
}
177+
178+
func (m MethodCallMatch) Hash(h *fingerprint.Hasher) error {
179+
return h.Named("method-call-match", fingerprint.Int(m))
180+
}
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
// Unless explicitly stated otherwise all files in this repository are licensed
2+
// under the Apache License Version 2.0.
3+
// This product includes software developed at Datadog (https://www.datadoghq.com/).
4+
// Copyright 2023-present Datadog, Inc.
5+
6+
package join
7+
8+
import (
9+
gocontext "context"
10+
"go/types"
11+
"testing"
12+
13+
"github.com/goccy/go-yaml"
14+
"github.com/stretchr/testify/assert"
15+
"github.com/stretchr/testify/require"
16+
17+
"github.com/DataDog/orchestrion/internal/fingerprint"
18+
"github.com/DataDog/orchestrion/internal/injector/aspect/may"
19+
"github.com/DataDog/orchestrion/internal/injector/typed"
20+
)
21+
22+
func newNamedType(pkgPath string, pkgName string, typeName string) *types.Named {
23+
pkg := types.NewPackage(pkgPath, pkgName)
24+
obj := types.NewTypeName(0, pkg, typeName, nil)
25+
return types.NewNamed(obj, types.NewStruct(nil, nil), nil)
26+
}
27+
28+
func TestMethodCallMatchesType(t *testing.T) {
29+
zapLogger := newNamedType("go.uber.org/zap", "zap", "Logger")
30+
zapLoggerPtr := types.NewPointer(zapLogger)
31+
otherLogger := newNamedType("example.com/other", "other", "Logger")
32+
otherLoggerPtr := types.NewPointer(otherLogger)
33+
34+
tn, err := typed.NewTypeName("go.uber.org/zap.Logger")
35+
require.NoError(t, err)
36+
37+
tests := []struct {
38+
name string
39+
match MethodCallMatch
40+
typ types.Type
41+
want bool
42+
}{
43+
{name: "any: matches pointer", match: MethodCallMatchAny, typ: zapLoggerPtr, want: true},
44+
{name: "any: matches value", match: MethodCallMatchAny, typ: zapLogger, want: true},
45+
{name: "pointer-only: matches pointer", match: MethodCallMatchPointerOnly, typ: zapLoggerPtr, want: true},
46+
{name: "pointer-only: rejects value", match: MethodCallMatchPointerOnly, typ: zapLogger, want: false},
47+
{name: "value-only: matches value", match: MethodCallMatchValueOnly, typ: zapLogger, want: true},
48+
{name: "value-only: rejects pointer", match: MethodCallMatchValueOnly, typ: zapLoggerPtr, want: false},
49+
{name: "any: rejects different package", match: MethodCallMatchAny, typ: otherLoggerPtr, want: false},
50+
{name: "any: rejects nil", match: MethodCallMatchAny, typ: nil, want: false},
51+
}
52+
53+
for _, tt := range tests {
54+
t.Run(tt.name, func(t *testing.T) {
55+
m := MethodCall(tn, "Info", tt.match)
56+
assert.Equal(t, tt.want, m.matchesType(tt.typ))
57+
})
58+
}
59+
}
60+
61+
func TestMethodCallPackageMayMatch(t *testing.T) {
62+
tn, err := typed.NewTypeName("go.uber.org/zap.Logger")
63+
require.NoError(t, err)
64+
m := MethodCall(tn, "Info", MethodCallMatchAny)
65+
66+
importing := &may.PackageContext{ImportMap: map[string]string{"go.uber.org/zap": "zap.a"}}
67+
notImporting := &may.PackageContext{ImportMap: map[string]string{"example.com/other": "other.a"}}
68+
69+
assert.Equal(t, may.Match, m.PackageMayMatch(importing))
70+
assert.Equal(t, may.NeverMatch, m.PackageMayMatch(notImporting))
71+
}
72+
73+
func TestMethodCallFileMayMatch(t *testing.T) {
74+
tn, err := typed.NewTypeName("go.uber.org/zap.Logger")
75+
require.NoError(t, err)
76+
m := MethodCall(tn, "Info", MethodCallMatchAny)
77+
78+
withMethod := &may.FileContext{FileContent: []byte(`package main; func f() { logger.Info("hi") }`)}
79+
withoutMethod := &may.FileContext{FileContent: []byte(`package main; func f() { logger.Debug("hi") }`)}
80+
81+
assert.Equal(t, may.Match, m.FileMayMatch(withMethod))
82+
assert.Equal(t, may.NeverMatch, m.FileMayMatch(withoutMethod))
83+
}
84+
85+
func TestMethodCallImpliesImported(t *testing.T) {
86+
tn, err := typed.NewTypeName("go.uber.org/zap.Logger")
87+
require.NoError(t, err)
88+
m := MethodCall(tn, "Info", MethodCallMatchAny)
89+
assert.Equal(t, []string{"go.uber.org/zap"}, m.ImpliesImported())
90+
}
91+
92+
func TestMethodCallHash(t *testing.T) {
93+
tn1, _ := typed.NewTypeName("go.uber.org/zap.Logger")
94+
tn2, _ := typed.NewTypeName("go.uber.org/zap.Logger")
95+
tn3, _ := typed.NewTypeName("example.com/other.Logger")
96+
97+
m1 := MethodCall(tn1, "Info", MethodCallMatchAny)
98+
m2 := MethodCall(tn2, "Info", MethodCallMatchAny)
99+
m3 := MethodCall(tn3, "Info", MethodCallMatchAny)
100+
m4 := MethodCall(tn1, "Debug", MethodCallMatchAny)
101+
m5 := MethodCall(tn1, "Info", MethodCallMatchPointerOnly)
102+
103+
hash := func(m *methodCall) string {
104+
h := fingerprint.New()
105+
require.NoError(t, m.Hash(h))
106+
return h.Finish()
107+
}
108+
109+
assert.Equal(t, hash(m1), hash(m2), "identical method-calls must hash equally")
110+
assert.NotEqual(t, hash(m1), hash(m3), "different receiver packages must hash differently")
111+
assert.NotEqual(t, hash(m1), hash(m4), "different method names must hash differently")
112+
assert.NotEqual(t, hash(m1), hash(m5), "different match modes must hash differently")
113+
}
114+
115+
func TestMethodCallUnmarshalYAML(t *testing.T) {
116+
tests := []struct {
117+
name string
118+
yaml string
119+
wantImport string
120+
wantType string
121+
wantMethod string
122+
wantMatch MethodCallMatch
123+
wantErr bool
124+
}{
125+
{
126+
name: "defaults to any",
127+
yaml: `method-call:
128+
receiver: "go.uber.org/zap.Logger"
129+
name: Info`,
130+
wantImport: "go.uber.org/zap",
131+
wantType: "Logger",
132+
wantMethod: "Info",
133+
wantMatch: MethodCallMatchAny,
134+
},
135+
{
136+
name: "pointer-only",
137+
yaml: `method-call:
138+
receiver: "go.uber.org/zap.Logger"
139+
name: Info
140+
match: pointer-only`,
141+
wantImport: "go.uber.org/zap",
142+
wantType: "Logger",
143+
wantMethod: "Info",
144+
wantMatch: MethodCallMatchPointerOnly,
145+
},
146+
{
147+
name: "value-only",
148+
yaml: `method-call:
149+
receiver: "go.uber.org/zap.SugaredLogger"
150+
name: Debugw
151+
match: value-only`,
152+
wantImport: "go.uber.org/zap",
153+
wantType: "SugaredLogger",
154+
wantMethod: "Debugw",
155+
wantMatch: MethodCallMatchValueOnly,
156+
},
157+
{
158+
name: "pointer sigil in receiver is rejected",
159+
yaml: `method-call:
160+
receiver: "*go.uber.org/zap.Logger"
161+
name: Info`,
162+
wantErr: true,
163+
},
164+
{
165+
name: "missing receiver is rejected",
166+
yaml: `method-call:
167+
name: Info`,
168+
wantErr: true,
169+
},
170+
{
171+
name: "missing name is rejected",
172+
yaml: `method-call:
173+
receiver: "go.uber.org/zap.Logger"`,
174+
wantErr: true,
175+
},
176+
}
177+
178+
fn, ok := unmarshalers["method-call"]
179+
require.True(t, ok, "method-call unmarshaler must be registered")
180+
181+
for _, tt := range tests {
182+
t.Run(tt.name, func(t *testing.T) {
183+
var data map[string]any
184+
err := yaml.Unmarshal([]byte(tt.yaml), &data)
185+
require.NoError(t, err)
186+
187+
node, err := yaml.ValueToNode(data["method-call"])
188+
require.NoError(t, err)
189+
190+
result, err := fn(gocontext.Background(), node)
191+
if tt.wantErr {
192+
assert.Error(t, err)
193+
return
194+
}
195+
require.NoError(t, err)
196+
197+
m, ok := result.(*methodCall)
198+
require.True(t, ok)
199+
assert.Equal(t, tt.wantImport, m.Receiver.ImportPath)
200+
assert.Equal(t, tt.wantType, m.Receiver.Name)
201+
assert.Equal(t, tt.wantMethod, m.Name)
202+
assert.Equal(t, tt.wantMatch, m.Match)
203+
})
204+
}
205+
}

0 commit comments

Comments
 (0)