Skip to content

Commit 6e0a1a4

Browse files
authored
feat: support generic funcs (alibaba#628)
* feat: support generic funcs * fux * fix * fix
1 parent d6881ec commit 6e0a1a4

File tree

6 files changed

+407
-84
lines changed

6 files changed

+407
-84
lines changed

tool/ast/primitives.go

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"fmt"
1919
"go/token"
2020

21+
"github.com/alibaba/loongsuite-go-agent/tool/util"
2122
"github.com/dave/dst"
2223
)
2324

@@ -42,13 +43,34 @@ func AddressOf(name string) *dst.UnaryExpr {
4243
return &dst.UnaryExpr{Op: token.AND, X: Ident(name)}
4344
}
4445

45-
func CallTo(name string, args []dst.Expr) *dst.CallExpr {
46+
// CallTo creates a call expression to a function with optional type arguments for generics.
47+
// For non-generic functions (typeArgs is nil or empty), creates a simple call: Foo(args...)
48+
// For generic functions with type arguments, creates: Foo[T1, T2](args...)
49+
func CallTo(name string, typeArgs *dst.FieldList, args []dst.Expr) *dst.CallExpr {
50+
if typeArgs == nil || len(typeArgs.List) == 0 {
51+
return &dst.CallExpr{
52+
Fun: &dst.Ident{Name: name},
53+
Args: args,
54+
}
55+
}
56+
57+
var indices []dst.Expr
58+
for _, field := range typeArgs.List {
59+
for _, ident := range field.Names {
60+
indices = append(indices, Ident(ident.Name))
61+
}
62+
}
63+
var fun dst.Expr
64+
if len(indices) == 1 {
65+
fun = IndexExpr(Ident(name), indices[0])
66+
} else {
67+
fun = IndexListExpr(Ident(name), indices)
68+
}
4669
return &dst.CallExpr{
47-
Fun: &dst.Ident{Name: name},
70+
Fun: fun,
4871
Args: args,
4972
}
5073
}
51-
5274
func Ident(name string) *dst.Ident {
5375
return &dst.Ident{
5476
Name: name,
@@ -102,13 +124,27 @@ func SelectorExpr(x dst.Expr, sel string) *dst.SelectorExpr {
102124
}
103125
}
104126

127+
func Ellipsis(elt dst.Expr) *dst.Ellipsis {
128+
return &dst.Ellipsis{
129+
Elt: elt,
130+
}
131+
}
132+
105133
func IndexExpr(x dst.Expr, index dst.Expr) *dst.IndexExpr {
106134
return &dst.IndexExpr{
107135
X: dst.Clone(x).(dst.Expr),
108136
Index: dst.Clone(index).(dst.Expr),
109137
}
110138
}
111139

140+
func IndexListExpr(x dst.Expr, indices []dst.Expr) *dst.IndexListExpr {
141+
e := util.AssertType[dst.Expr](dst.Clone(x))
142+
return &dst.IndexListExpr{
143+
X: e,
144+
Indices: indices,
145+
}
146+
}
147+
112148
func TypeAssertExpr(x dst.Expr, typ dst.Expr) *dst.TypeAssertExpr {
113149
return &dst.TypeAssertExpr{
114150
X: x,

tool/ast/shared.go

Lines changed: 95 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,48 @@ func FindFuncDeclWithoutRecv(root *dst.File, funcName string) *dst.FuncDecl {
153153
return decls[0]
154154
}
155155

156+
// stripGenericTypes extracts the base type name from a receiver expression,
157+
// handling both generic and non-generic types.
158+
// For example:
159+
// - *MyStruct -> *MyStruct
160+
// - MyStruct -> MyStruct
161+
// - *GenStruct[T] -> *GenStruct
162+
// - GenStruct[T] -> GenStruct
163+
func stripGenericTypes(recvTypeExpr dst.Expr) string {
164+
switch expr := recvTypeExpr.(type) {
165+
case *dst.StarExpr: // func (*Recv)T or func (*Recv[T])T
166+
// Check if X is an Ident (non-generic) or IndexExpr/IndexListExpr (generic)
167+
switch x := expr.X.(type) {
168+
case *dst.Ident:
169+
// Non-generic pointer receiver: *MyStruct
170+
return "*" + x.Name
171+
case *dst.IndexExpr:
172+
// Generic pointer receiver with single type param: *GenStruct[T]
173+
if baseIdent, ok := x.X.(*dst.Ident); ok {
174+
return "*" + baseIdent.Name
175+
}
176+
case *dst.IndexListExpr:
177+
// Generic pointer receiver with multiple type params: *GenStruct[T, U]
178+
if baseIdent, ok := x.X.(*dst.Ident); ok {
179+
return "*" + baseIdent.Name
180+
}
181+
}
182+
case *dst.Ident: // func (Recv)T
183+
return expr.Name
184+
case *dst.IndexExpr:
185+
// Generic value receiver with single type param: GenStruct[T]
186+
if baseIdent, ok := expr.X.(*dst.Ident); ok {
187+
return baseIdent.Name
188+
}
189+
case *dst.IndexListExpr:
190+
// Generic value receiver with multiple type params: GenStruct[T, U]
191+
if baseIdent, ok := expr.X.(*dst.Ident); ok {
192+
return baseIdent.Name
193+
}
194+
}
195+
return ""
196+
}
197+
156198
func FindFuncDecl(root *dst.File, function string, receiverType string) []*dst.FuncDecl {
157199
decls := findFuncDecls(root, func(funcDecl *dst.FuncDecl) bool {
158200
return function == funcDecl.Name.Name
@@ -169,28 +211,19 @@ func FindFuncDecl(root *dst.File, function string, receiverType string) []*dst.F
169211
filtered = append(filtered, funcDecl)
170212
}
171213
}
172-
switch recvTypeExpr := funcDecl.Recv.List[0].Type.(type) {
173-
case *dst.StarExpr:
174-
if _, ok := recvTypeExpr.X.(*dst.Ident); !ok {
175-
// This is a generic type, we don't support it yet
176-
continue
177-
}
178-
t := "*" + recvTypeExpr.X.(*dst.Ident).Name
179-
if re.MatchString(t) {
180-
filtered = append(filtered, funcDecl)
181-
}
182-
case *dst.Ident:
183-
t := recvTypeExpr.Name
184-
if re.MatchString(t) {
185-
filtered = append(filtered, funcDecl)
186-
}
187-
case *dst.IndexExpr:
188-
// This is a generic type, we don't support it yet
189-
continue
190-
default:
214+
215+
// Receiver type is specified, and target function has receiver
216+
// Match both func name and receiver type
217+
recvTypeExpr := funcDecl.Recv.List[0].Type
218+
baseType := stripGenericTypes(recvTypeExpr)
219+
220+
if baseType == "" {
191221
msg := fmt.Sprintf("unexpected receiver type: %T", recvTypeExpr)
192222
util.Unimplemented(msg)
193223
}
224+
if re.MatchString(baseType) {
225+
filtered = append(filtered, funcDecl)
226+
}
194227
}
195228
return filtered
196229
}
@@ -229,3 +262,46 @@ func FindStructDecl(root *dst.File, structName string) *dst.GenDecl {
229262
}
230263
return nil
231264
}
265+
266+
// SplitMultiNameFields splits fields that have multiple names into separate fields.
267+
// For example, a field like "a, b int" becomes two fields: "a int" and "b int".
268+
func SplitMultiNameFields(fieldList *dst.FieldList) *dst.FieldList {
269+
if fieldList == nil {
270+
return nil
271+
}
272+
result := &dst.FieldList{List: []*dst.Field{}}
273+
for _, field := range fieldList.List {
274+
// Handle unnamed fields (e.g., embedded types) or fields with single/multiple names
275+
namesToProcess := field.Names
276+
if len(namesToProcess) == 0 {
277+
// For unnamed fields, create one field with no names
278+
namesToProcess = []*dst.Ident{nil}
279+
}
280+
281+
for _, name := range namesToProcess {
282+
clonedType := util.AssertType[dst.Expr](dst.Clone(field.Type))
283+
284+
var names []*dst.Ident
285+
if name != nil {
286+
clonedName := util.AssertType[*dst.Ident](dst.Clone(name))
287+
names = []*dst.Ident{clonedName}
288+
}
289+
290+
newField := &dst.Field{
291+
Names: names,
292+
Type: clonedType,
293+
}
294+
result.List = append(result.List, newField)
295+
}
296+
}
297+
return result
298+
}
299+
300+
// CloneTypeParams safely clones a type parameter field list for generic functions.
301+
// Returns nil if the input is nil.
302+
func CloneTypeParams(typeParams *dst.FieldList) *dst.FieldList {
303+
if typeParams == nil {
304+
return nil
305+
}
306+
return util.AssertType[*dst.FieldList](dst.Clone(typeParams))
307+
}

tool/ex/error.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,9 @@ func Wrapf(previousErr error, format string, args ...any) error {
133133
return wrapOrCreate(previousErr, format, args...)
134134
}
135135

136+
func New(message string) error {
137+
return wrapOrCreate(nil, "%s", message)
138+
}
136139
func Newf(format string, args ...any) error {
137140
return wrapOrCreate(nil, format, args...)
138141
}

tool/instrument/apply_func.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ func (rp *RuleProcessor) createTJumpIf(t *rules.InstFuncRule, funcDecl *dst.Func
169169
argsToOnExit := createHookArgs(retVals)
170170
argCallContext := ast.Ident(trampolineCallContextName + varSuffix)
171171
argsToOnExit = append([]dst.Expr{argCallContext}, argsToOnExit...)
172-
onEnterCall := ast.CallTo(makeName(t, funcDecl, true), argsToOnEnter)
173-
onExitCall := ast.CallTo(makeName(t, funcDecl, false), argsToOnExit)
172+
onEnterCall := ast.CallTo(makeName(t, funcDecl, true), nil, argsToOnEnter)
173+
onExitCall := ast.CallTo(makeName(t, funcDecl, false), nil, argsToOnExit)
174174
tjumpInit := ast.DefineStmts(
175175
ast.Exprs(
176176
ast.Ident(trampolineCallContextName+varSuffix),

0 commit comments

Comments
 (0)