Skip to content

Commit 574acbc

Browse files
committed
Add flag to cmg to use testify/assert in mocks
- Add a `-testify` flag which replaces the usage of `testing.T`'s `t.Helper()` and `t.Fail()` in mocks generated by `cmg` with `assert.Fail()`. When used, this gives the benefit of printing out a stack trace to identify why an unexpected call is happening and will be familiar to those already using testify/assert for tests.
1 parent 06e0199 commit 574acbc

File tree

7 files changed

+139
-18
lines changed

7 files changed

+139
-18
lines changed

mock/cmd/cmg/main.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@ import (
1212

1313
func main() {
1414
var (
15-
gSet = flag.NewFlagSet("global", flag.ExitOnError)
16-
debug, help, h *bool
17-
addGlobals = func(set *flag.FlagSet) {
15+
gSet = flag.NewFlagSet("global", flag.ExitOnError)
16+
debug, testify, help, h *bool
17+
addGlobals = func(set *flag.FlagSet) {
1818
debug = set.Bool("debug", false, "Print debug output")
19+
testify = set.Bool("testify", false, "Use github.com/stretchr/testify for assertions")
1920
help = set.Bool("help", false, "Print help information")
2021
h = set.Bool("h", false, "Print help information")
2122
}
@@ -67,7 +68,7 @@ func main() {
6768
} else {
6869
ctx = log.Context(ctx)
6970
}
70-
err := cluemockgen.Generate(ctx, args, "")
71+
err := cluemockgen.Generate(ctx, args, "", *testify)
7172
if err != nil {
7273
os.Exit(1)
7374
}

mock/cmd/cmg/pkg/generate.go

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ import (
1414
"goa.design/clue/mock/cmd/cmg/pkg/parse"
1515
)
1616

17-
func Generate(ctx context.Context, patterns []string, dir string) error {
17+
// Generate generates the mocks for the given patterns and directory.
18+
// If testify is true, it uses github.com/stretchr/testify for assertions.
19+
func Generate(ctx context.Context, patterns []string, dir string, testify bool) error {
1820
ps, err := parse.LoadPackages(patterns, dir)
1921
if err != nil {
2022
log.Error(ctx, err)
@@ -24,7 +26,7 @@ func Generate(ctx context.Context, patterns []string, dir string) error {
2426
var errs []error
2527

2628
for _, p := range ps {
27-
err = generatePackage(ctx, p)
29+
err = generatePackage(ctx, p, testify)
2830
if err != nil {
2931
errs = append(errs, err)
3032
}
@@ -37,9 +39,11 @@ func Generate(ctx context.Context, patterns []string, dir string) error {
3739
return nil
3840
}
3941

40-
func generatePackage(ctx context.Context, p parse.Package) error {
42+
// generatePackage generates the mocks for the given package.
43+
// If testify is true, it uses github.com/stretchr/testify for assertions.
44+
func generatePackage(ctx context.Context, p parse.Package, testify bool) error {
4145
ctx = log.With(ctx, log.KV{K: "pkg name", V: p.Name()})
42-
log.Print(ctx, log.KV{K: "pkg path", V: p.PkgPath()})
46+
log.Print(ctx, log.KV{K: "pkg path", V: p.PkgPath()}, log.KV{K: "testify", V: testify})
4347

4448
is, err := p.Interfaces()
4549
if err != nil {
@@ -75,7 +79,7 @@ func generatePackage(ctx context.Context, p parse.Package) error {
7579
}
7680
}
7781
for file, interfaces := range interfacesByFile {
78-
err = generateFile(ctx, p, file, interfaces)
82+
err = generateFile(ctx, p, file, interfaces, testify)
7983
if err != nil {
8084
return err
8185
}
@@ -84,13 +88,15 @@ func generatePackage(ctx context.Context, p parse.Package) error {
8488
return nil
8589
}
8690

87-
func generateFile(ctx context.Context, p parse.Package, file string, interfaces []parse.Interface) error {
91+
// generateFile generates the mocks for the given file.
92+
// If testify is true, it uses github.com/stretchr/testify for assertions.
93+
func generateFile(ctx context.Context, p parse.Package, file string, interfaces []parse.Interface, testify bool) error {
8894
ctx = log.With(ctx, log.KV{K: "file", V: file})
8995
interfaceNames := make([]string, len(interfaces))
9096
for j, i := range interfaces {
9197
interfaceNames[j] = i.Name()
9298
}
93-
log.Print(ctx, log.KV{K: "interface names", V: interfaceNames})
99+
log.Print(ctx, log.KV{K: "interface names", V: interfaceNames}, log.KV{K: "testify", V: testify})
94100

95101
dir, baseFile := filepath.Split(file)
96102
mocksDir := filepath.Join(dir, "mocks")
@@ -118,7 +124,7 @@ func generateFile(ctx context.Context, p parse.Package, file string, interfaces
118124
}
119125
}()
120126

121-
mocks := generate.NewMocks("mock", p, interfaces, Version)
127+
mocks := generate.NewMocks("mock", p, interfaces, Version, testify)
122128
if err := mocks.Render(f); err != nil {
123129
log.Error(ctx, err)
124130
return err

mock/cmd/cmg/pkg/generate/_tests/testify/mocks/testify.go

Lines changed: 52 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package testify
2+
3+
type (
4+
Testify interface {
5+
Simple(a, b int) bool
6+
}
7+
)

mock/cmd/cmg/pkg/generate/mocks.go

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,47 +10,72 @@ import (
1010
)
1111

1212
type (
13+
// Mocks is the interface for the mocks.
1314
Mocks interface {
15+
// PkgName returns the package name.
1416
PkgName() string
17+
// PkgImport returns the package import.
1518
PkgImport() Import
19+
// StdImports returns the standard imports.
1620
StdImports() []Import
21+
// ExtImports returns the external imports.
1722
ExtImports() []Import
23+
// IntImports returns the internal imports.
1824
IntImports() []Import
25+
// Interfaces returns the interfaces.
1926
Interfaces() []Interface
27+
// ToolVersion returns the tool version.
2028
ToolVersion() string
29+
// ToolCommandLine returns the tool command line.
2130
ToolCommandLine() string
31+
// Testify returns true if testify should be used.
32+
Testify() bool
33+
// Render renders the mocks to the given writer.
2234
Render(w io.Writer) error
2335
}
2436

37+
// ToolVersionFunc is the function type for the tool version.
2538
ToolVersionFunc func() string
2639

40+
// mocks is the implementation of the Mocks interface.
2741
mocks struct {
2842
pkgName, pkgPath string
2943
pkgImport Import
3044
stdImports, extImports, intImports []Import
3145
interfaces []Interface
3246
toolVersionFunc ToolVersionFunc
47+
testify bool
3348
}
3449
)
3550

3651
//go:embed mocks.go.tmpl
3752
var mocksStr string
3853

3954
var (
55+
// mocksTmpl is the template for the mocks.
4056
mocksTmpl = template.Must(template.New("mocks").Parse(mocksStr))
4157
)
4258

43-
func NewMocks(prefix string, p parse.Package, interfaces []parse.Interface, toolVersionFunc ToolVersionFunc) Mocks {
59+
// NewMocks creates a new Mocks instance.
60+
// If testify is true, it uses github.com/stretchr/testify for assertions.
61+
func NewMocks(prefix string, p parse.Package, interfaces []parse.Interface, toolVersionFunc ToolVersionFunc, testify bool) Mocks {
4462
var (
4563
stdImports = importMap{"testing": newImport("testing")}
4664
extImports = importMap{"mock": newImport("goa.design/clue/mock")}
4765
intImports = make(importMap)
48-
modPath = p.ModPath()
49-
m = &mocks{
66+
)
67+
if testify {
68+
extImports["assert"] = newImport("github.com/stretchr/testify/assert")
69+
}
70+
71+
var (
72+
modPath = p.ModPath()
73+
m = &mocks{
5074
pkgName: prefix + p.Name(),
5175
pkgPath: p.PkgPath(),
5276
pkgImport: addImport(newImport(p.PkgPath(), p.Name()), stdImports, extImports, intImports, modPath),
5377
toolVersionFunc: toolVersionFunc,
78+
testify: testify,
5479
}
5580
typeNames = make(typeMap)
5681
typeZeros = make(typeMap)
@@ -78,38 +103,52 @@ func NewMocks(prefix string, p parse.Package, interfaces []parse.Interface, tool
78103
return m
79104
}
80105

106+
// PkgName returns the package name.
81107
func (m *mocks) PkgName() string {
82108
return m.pkgName
83109
}
84110

111+
// PkgImport returns the package import.
85112
func (m *mocks) PkgImport() Import {
86113
return m.pkgImport
87114
}
88115

116+
// StdImports returns the standard imports.
89117
func (m *mocks) StdImports() []Import {
90118
return m.stdImports
91119
}
92120

121+
// ExtImports returns the external imports.
93122
func (m *mocks) ExtImports() []Import {
94123
return m.extImports
95124
}
96125

126+
// IntImports returns the internal imports.
97127
func (m *mocks) IntImports() []Import {
98128
return m.intImports
99129
}
100130

131+
// Interfaces returns the interfaces.
101132
func (m *mocks) Interfaces() []Interface {
102133
return m.interfaces
103134
}
104135

136+
// ToolVersion returns the tool version.
105137
func (m *mocks) ToolVersion() string {
106138
return m.toolVersionFunc()
107139
}
108140

141+
// ToolCommandLine returns the tool command line.
109142
func (m *mocks) ToolCommandLine() string {
110143
return "$ cmg gen " + m.pkgPath
111144
}
112145

146+
// Testify returns true if testify should be used.
147+
func (m *mocks) Testify() bool {
148+
return m.testify
149+
}
150+
151+
// Render renders the mocks to the given writer.
113152
func (m *mocks) Render(w io.Writer) error {
114153
return mocksTmpl.Execute(w, m)
115154
}

mock/cmd/cmg/pkg/generate/mocks.go.tmpl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,18 @@ import (
1717
{{- end }}
1818
)
1919

20+
{{ $testify := .Testify -}}
2021
type (
2122
{{- range $index, $interface := .Interfaces }}
2223
{{- if ge $index 1 }}
2324
{{ end }}
2425
{{ .Name }}{{ .TypeParameters }} struct {
2526
m *mock.Mock
27+
{{- if $testify }}
28+
assert *assert.Assertions
29+
{{- else }}
2630
t *testing.T
31+
{{- end }}
2732
}
2833
{{ range .Methods }}
2934
{{ printf "%v%v" .Func $interface.TypeParameters | printf $interface.MaxFuncLenFmt }} func({{ .Parameters }}){{ if .Results }} {{ .Results }}{{ end }}
@@ -36,7 +41,7 @@ type (
3641
{{ end }}
3742
func {{ .Constructor }}{{ .TypeParameters }}(t *testing.T) *{{ .Name }}{{ .TypeParameterVars }} {
3843
var (
39-
{{ .Var | printf ($import.AliasOrPkgName | .ConstructorFmt) }} = &{{ .Name }}{{ .TypeParameterVars }}{mock.New(), t}
44+
{{ .Var | printf ($import.AliasOrPkgName | .ConstructorFmt) }} = &{{ .Name }}{{ .TypeParameterVars }}{mock.New(), {{ if $testify }}assert.New(t){{ else }}t{{ end }}}
4045
_ {{ $import.AliasOrPkgName }}.{{ .Name }}{{ .TypeParameterVars }} = m
4146
)
4247
return {{ .Var }}
@@ -59,8 +64,12 @@ func ({{ .InterfaceVar }} *{{ $interface.Name }}{{ $interface.TypeParameterVars
5964
return
6065
{{- end }}
6166
}
67+
{{- if $testify }}
68+
{{ .InterfaceVar }}.assert.Fail("unexpected {{ .Name }} call")
69+
{{- else }}
6270
{{ .InterfaceVar }}.t.Helper()
6371
{{ .InterfaceVar }}.t.Error("unexpected {{ .Name }} call")
72+
{{- end }}
6473
{{- if .ZeroResults }}
6574
return {{ .ZeroResults }}
6675
{{- end }}

mock/cmd/cmg/pkg/generate/mocks_test.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ func TestMocks_Render(t *testing.T) {
2323
cases := []struct {
2424
Name, Pattern string
2525
ExpectedFiles []string
26+
Testify bool
2627
}{
2728
{
2829
Name: "extensive",
@@ -34,6 +35,12 @@ func TestMocks_Render(t *testing.T) {
3435
Pattern: "./conflicts",
3536
ExpectedFiles: []string{"conflicts.go"},
3637
},
38+
{
39+
Name: "testify",
40+
Pattern: "./testify",
41+
ExpectedFiles: []string{"testify.go"},
42+
Testify: true,
43+
},
3744
}
3845

3946
for _, tc := range cases {
@@ -67,7 +74,7 @@ func TestMocks_Render(t *testing.T) {
6774
require.NoError(err)
6875
t.Cleanup(func() { assert.NoError(of.Close()) })
6976

70-
m := NewMocks("mock", p, is, toolVersion)
77+
m := NewMocks("mock", p, is, toolVersion, tc.Testify)
7178
require.NoError(m.Render(of))
7279
}
7380

@@ -82,7 +89,7 @@ func TestMocks_Render(t *testing.T) {
8289

8390
for f, is := range interfacesByFile {
8491
f := filepath.Join(mocksDir, f)
85-
m := NewMocks("mock", p, is, toolVersion)
92+
m := NewMocks("mock", p, is, toolVersion, tc.Testify)
8693
b := &bytes.Buffer{}
8794

8895
err := m.Render(b)

0 commit comments

Comments
 (0)