Skip to content

Commit 90f7ec4

Browse files
committed
handle vendored pkgs
1 parent fe0ebb8 commit 90f7ec4

File tree

3 files changed

+56
-154
lines changed

3 files changed

+56
-154
lines changed

cmd/mocker/main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ func main() {
3333

3434
m, err := mocker.New(src, pkg, iface, prefix, suffix, w)
3535
if err != nil {
36-
log.Fatal("failed to instantiate mocker")
36+
log.Fatal("mocker: failed to instantiate")
3737
}
3838

3939
if err = m.Mock(); err != nil {
40-
log.Fatalf("failed to mock: %v", err)
40+
log.Fatalf("mocker: failed to mock: %v", err)
4141
}
4242

4343
if out != nil {

pkg/mocker/importer.go

Lines changed: 0 additions & 101 deletions
This file was deleted.

pkg/mocker/mocker.go

Lines changed: 54 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,18 @@ package mocker
33
import (
44
"bytes"
55
"fmt"
6-
"go/ast"
76
"go/format"
8-
goimporter "go/importer"
97
"go/parser"
108
"go/token"
119
"go/types"
1210
"io"
1311
"os"
12+
"path/filepath"
1413
"strings"
1514
"text/template"
1615

1716
"github.com/pkg/errors"
17+
"golang.org/x/tools/go/loader"
1818
)
1919

2020
type mocker struct {
@@ -51,57 +51,31 @@ func (m *mocker) Mock() error {
5151
}
5252
tmpl, err := template.New("mocker").Funcs(tmplFns).Parse(tmpl)
5353
if err != nil {
54-
return errors.Wrap(err, "mocker: failed to parse template")
54+
return errors.Wrap(err, "failed to parse template")
5555
}
5656
f := file{Pkg: *m.pkg, Imports: []iimport{{Path: "sync"}}}
57-
for _, pkg := range pkgs {
58-
i := 0
59-
files := make([]*ast.File, len(pkg.Files))
60-
for _, f := range pkg.Files {
61-
files[i] = f
62-
i++
63-
}
64-
cfg := types.Config{Importer: &importer{src: *m.src, pkgs: make(map[string]*types.Package), base: goimporter.Default()}}
65-
tpkg, err := cfg.Check(*m.src, fset, files, nil)
66-
if err != nil {
67-
return errors.Wrap(err, "mocker: failed to type check pkg")
57+
58+
pkgInfo, err := m.pkgInfo(*m.src)
59+
if err != nil {
60+
return errors.Wrap(err, "failed to get pkg info")
61+
}
62+
for _, n := range *m.iface {
63+
ifaceobj := pkgInfo.Pkg.Scope().Lookup(n)
64+
if ifaceobj == nil {
65+
return fmt.Errorf("failed to find interface: %s", n)
6866
}
69-
for _, f := range files {
70-
for _, d := range f.Decls {
71-
gd, ok := d.(*ast.GenDecl)
72-
if !ok {
73-
continue
74-
}
75-
for _, s := range gd.Specs {
76-
is, ok := s.(*ast.ImportSpec)
77-
if !ok {
78-
continue
79-
}
80-
if is.Name != nil {
81-
i := iimport{Name: is.Name.Name, Path: strings.Replace(is.Path.Value, `"`, "", -1)}
82-
m.imports.named[i.Path] = i
83-
}
84-
}
85-
}
67+
if !types.IsInterface(ifaceobj.Type()) {
68+
return errors.Wrap(err, fmt.Sprintf("%s (%s) is not an interface", n, ifaceobj.Type().String()))
8669
}
87-
for _, i := range *m.iface {
88-
ifaceobj := tpkg.Scope().Lookup(i)
89-
if ifaceobj == nil {
90-
return fmt.Errorf("mocker: failed to find interface %s", i)
91-
}
92-
if !types.IsInterface(ifaceobj.Type()) {
93-
return fmt.Errorf("mocker: not an interface %s", i)
94-
}
95-
tiface := ifaceobj.Type().Underlying().(*types.Interface).Complete()
96-
iface := iface{Name: i, Suffix: *m.suffix, Prefix: *m.prefix}
97-
for i := 0; i < tiface.NumMethods(); i++ {
98-
met := tiface.Method(i)
99-
sig := met.Type().(*types.Signature)
100-
m := method{Name: met.Name(), Params: m.params(sig, sig.Params(), "in%d"), Returns: m.params(sig, sig.Results(), "out%d")}
101-
iface.Methods = append(iface.Methods, m)
102-
}
103-
f.Ifaces = append(f.Ifaces, iface)
70+
iiface := ifaceobj.Type().Underlying().(*types.Interface).Complete()
71+
iface := iface{Name: n, Suffix: *m.suffix, Prefix: *m.prefix}
72+
for i := 0; i < iiface.NumMethods(); i++ {
73+
met := iiface.Method(i)
74+
sig := met.Type().(*types.Signature)
75+
m := method{Name: met.Name(), Params: m.params(sig, sig.Params(), "in%d"), Returns: m.params(sig, sig.Results(), "out%d")}
76+
iface.Methods = append(iface.Methods, m)
10477
}
78+
f.Ifaces = append(f.Ifaces, iface)
10579
}
10680
for p, n := range m.imports.named {
10781
if _, ok := m.imports.all[p]; ok {
@@ -113,14 +87,14 @@ func (m *mocker) Mock() error {
11387
}
11488
var buf bytes.Buffer
11589
if err := tmpl.Execute(&buf, f); err != nil {
116-
return errors.Wrap(err, "mocker: failed to execute template")
90+
return errors.Wrap(err, "failed to execute template")
11791
}
11892
fmted, err := format.Source(buf.Bytes())
11993
if err != nil {
120-
return errors.Wrap(err, "mocker: failed to format file")
94+
return errors.Wrap(err, "failed to format file")
12195
}
12296
if _, err := m.w.Write(fmted); err != nil {
123-
return errors.Wrap(err, "mocker: failed to write file")
97+
return errors.Wrap(err, "failed to write file")
12498
}
12599
return nil
126100
}
@@ -247,3 +221,32 @@ func (m *mocker) params(sig *types.Signature, tuple *types.Tuple, format string)
247221
}
248222
return params
249223
}
224+
225+
func (m *mocker) pkgInfo(src string) (*loader.PackageInfo, error) {
226+
abs, err := filepath.Abs(src)
227+
if err != nil {
228+
return nil, errors.Wrap(err, "faild to get abs src path")
229+
}
230+
pkgPath := m.strip(abs)
231+
conf := loader.Config{
232+
ParserMode: parser.SpuriousErrors,
233+
Cwd: src,
234+
}
235+
conf.Import(pkgPath)
236+
loader, err := conf.Load()
237+
if err != nil {
238+
return nil, errors.Wrap(err, "failed to load program")
239+
}
240+
pkgInfo := loader.Package(pkgPath)
241+
if pkgInfo == nil {
242+
return nil, errors.New("unable to load package")
243+
}
244+
return pkgInfo, nil
245+
}
246+
247+
func (m *mocker) strip(pkg string) string {
248+
for _, path := range strings.Split(os.Getenv("GOPATH"), string(filepath.ListSeparator)) {
249+
pkg = strings.TrimPrefix(pkg, filepath.Join(path, "src")+"/")
250+
}
251+
return pkg
252+
}

0 commit comments

Comments
 (0)