Skip to content

Commit 9ddb690

Browse files
committed
allow per field flags
1 parent feb57ae commit 9ddb690

5 files changed

Lines changed: 139 additions & 32 deletions

File tree

cmd/gtrace/main.go

Lines changed: 57 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,8 @@ func main() {
258258
for i, c := range v.List {
259259
logf("#%d comment %q", i, c.Text)
260260

261-
text := strings.TrimPrefix(c.Text, "//gtrace:")
262-
if c.Text != text {
261+
text, ok := TrimConfigComment(c.Text)
262+
if ok {
263263
if item == nil {
264264
item = &GenItem{
265265
File: pkgFiles[i],
@@ -284,6 +284,7 @@ func main() {
284284
}
285285
return false
286286
}
287+
287288
return true
288289
})
289290
}
@@ -298,20 +299,39 @@ func main() {
298299
}
299300
for _, field := range item.StructType.Fields.List {
300301
name := field.Names[0].Name
301-
if fn, ok := field.Type.(*ast.FuncType); ok {
302-
f, err := buildFunc(info, fn)
303-
if err != nil {
304-
log.Printf(
305-
"skipping hook %s due to error: %v",
306-
name, err,
307-
)
308-
continue
302+
fn, ok := field.Type.(*ast.FuncType)
303+
if !ok {
304+
continue
305+
}
306+
f, err := buildFunc(info, fn)
307+
if err != nil {
308+
log.Printf(
309+
"skipping hook %s due to error: %v",
310+
name, err,
311+
)
312+
continue
313+
}
314+
var config GenConfig
315+
if doc := field.Doc; doc != nil {
316+
for _, line := range doc.List {
317+
text, ok := TrimConfigComment(line.Text)
318+
if !ok {
319+
continue
320+
}
321+
err := config.ParseComment(text)
322+
if err != nil {
323+
log.Fatalf(
324+
"malformed comment string: %q: %v",
325+
text, err,
326+
)
327+
}
309328
}
310-
t.Hooks = append(t.Hooks, Hook{
311-
Name: name,
312-
Func: f,
313-
})
314329
}
330+
t.Hooks = append(t.Hooks, Hook{
331+
Name: name,
332+
Func: f,
333+
Flag: item.GenConfig.Flag | config.Flag,
334+
})
315335
}
316336
p.Traces = append(p.Traces, t)
317337
}
@@ -412,6 +432,7 @@ type Trace struct {
412432
type Hook struct {
413433
Name string
414434
Func Func
435+
Flag GenFlag
415436
}
416437

417438
type Param struct {
@@ -442,44 +463,55 @@ const (
442463
GenAll = ^GenFlag(0)
443464
)
444465

445-
type GenItem struct {
446-
File *os.File
447-
Ident *ast.Ident
448-
TypeSpec *ast.TypeSpec
449-
StructType *ast.StructType
450-
466+
type GenConfig struct {
451467
Flag GenFlag
452468
}
453469

454-
func (x *GenItem) ParseComment(text string) (err error) {
470+
func TrimConfigComment(text string) (string, bool) {
471+
s := strings.TrimPrefix(text, "//gtrace:")
472+
if text != s {
473+
return s, true
474+
}
475+
return "", false
476+
}
477+
478+
func (g *GenConfig) ParseComment(text string) (err error) {
455479
prefix, text := split(text, ' ')
456480
switch prefix {
457481
case "gen":
458482
case "set":
459-
return x.ParseParameter(text)
483+
return g.ParseParameter(text)
460484
default:
461485
return fmt.Errorf("unknown prefix: %q", prefix)
462486
}
463487
return nil
464488
}
465489

466-
func (x *GenItem) ParseParameter(text string) (err error) {
490+
func (g *GenConfig) ParseParameter(text string) (err error) {
467491
text = strings.TrimSpace(text)
468492
param, _ := split(text, '=')
469493
if param == "" {
470494
return nil
471495
}
472496
switch param {
473497
case "shortcut":
474-
x.Flag |= GenShortcut
498+
g.Flag |= GenShortcut
475499
case "context":
476-
x.Flag |= GenContext
500+
g.Flag |= GenContext
477501
default:
478502
return fmt.Errorf("unexpected parameter: %q", param)
479503
}
480504
return nil
481505
}
482506

507+
type GenItem struct {
508+
GenConfig
509+
File *os.File
510+
Ident *ast.Ident
511+
TypeSpec *ast.TypeSpec
512+
StructType *ast.StructType
513+
}
514+
483515
func split(s string, c byte) (s1, s2 string) {
484516
i := strings.IndexByte(s, c)
485517
if i == -1 {

cmd/gtrace/writer.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,10 @@ func (w *Writer) Write(p Package) error {
8787
}
8888
}
8989
for _, trace := range p.Traces {
90-
if !trace.Flag.Has(GenShortcut) {
91-
continue
92-
}
9390
for _, hook := range trace.Hooks {
91+
if !hook.Flag.Has(GenShortcut) {
92+
continue
93+
}
9494
if w.Stub {
9595
w.stubHookShortcut(trace, hook)
9696
} else {
@@ -203,10 +203,10 @@ func unwrapStruct(t types.Type) (n *types.Named, s *types.Struct) {
203203
return
204204
}
205205

206-
func (w *Writer) traceFuncImports(dst []dep, trace Trace, fn Func) []dep {
206+
func (w *Writer) funcImports(dst []dep, flag GenFlag, fn Func) []dep {
207207
for _, p := range fn.Params {
208208
dst = w.typeImports(dst, p.Type)
209-
if !trace.Flag.Has(GenShortcut) {
209+
if !flag.Has(GenShortcut) {
210210
continue
211211
}
212212
if _, s := unwrapStruct(p.Type); s != nil {
@@ -218,7 +218,7 @@ func (w *Writer) traceFuncImports(dst []dep, trace Trace, fn Func) []dep {
218218
}
219219
}
220220
for _, fn := range fn.Result {
221-
dst = w.traceFuncImports(dst, trace, fn)
221+
dst = w.funcImports(dst, flag, fn)
222222
}
223223
return dst
224224
}
@@ -232,7 +232,7 @@ func (w *Writer) traceImports(dst []dep, t Trace) []dep {
232232
})
233233
}
234234
for _, h := range t.Hooks {
235-
dst = w.traceFuncImports(dst, t, h.Func)
235+
dst = w.funcImports(dst, h.Flag, h.Func)
236236
}
237237
return dst
238238
}

test/test_shortcut.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package test
2+
3+
//go:generate gtrace -v
4+
5+
//gtrace:gen
6+
type ShortcutPerFieldTrace struct {
7+
//gtrace:set shortcut
8+
OnFoo func()
9+
OnBar func()
10+
}

test/test_shortcut_gtrace.go

Lines changed: 52 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

test/test_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,19 @@ func TestCompose(t *testing.T) {
186186
}
187187
}
188188

189+
func TestShortcutPerFieldTrace(t *testing.T) {
190+
var called bool
191+
t0 := ShortcutPerFieldTrace{
192+
OnFoo: func() {
193+
called = true
194+
},
195+
}
196+
shortcutPerFieldTraceOnFoo(t0)
197+
if !called {
198+
t.Fatalf("hook wasn't called")
199+
}
200+
}
201+
189202
func TestBuildTagTrace(t *testing.T) {
190203
t0 := BuildTagTrace{
191204
OnSomethingA: func() func() {

0 commit comments

Comments
 (0)