diff --git a/cmd/wire/main.go b/cmd/wire/main.go index fa37e51d..30433710 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -101,6 +101,7 @@ type genCmd struct { headerFile string prefixFileName string tags string + recursion bool } func (*genCmd) Name() string { return "gen" } @@ -119,6 +120,47 @@ func (cmd *genCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.headerFile, "header_file", "", "path to file to insert as a header in wire_gen.go") f.StringVar(&cmd.prefixFileName, "output_file_prefix", "", "string to prepend to output file names.") f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") + f.BoolVar(&cmd.recursion, "recursion", false, "gen creates the wire_gen.go file for all sub [packages].") +} + +func (cmd *genCmd) findAllSubDir(dirs []string) ([]string, error) { + retDirs := make([]string, len(dirs), len(dirs)*1024) + copy(retDirs, dirs) + realDirs := make([]string, 0, len(dirs)*1024) + for index := 0; index < len(retDirs); index++ { + dir, err := os.Open(retDirs[index]) + if err != nil { + return nil, err + } + stat, err := dir.Stat() + if err != nil { + return nil, err + } + if !stat.IsDir() { + realDirs = append(realDirs, retDirs[index]) + continue + } + + subdirs, err := dir.ReadDir(0) + if err != nil { + return nil, err + } + + dirName := strings.TrimSuffix(retDirs[index], "/") + "/" + hasFile := false + for _, v := range subdirs { + if !v.IsDir() { + hasFile = true + continue + } + retDirs = append(retDirs, dirName+v.Name()) + } + + if hasFile { + realDirs = append(realDirs, retDirs[index]) + } + } + return realDirs, nil } func (cmd *genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { @@ -136,7 +178,18 @@ func (cmd *genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa opts.PrefixOutputFile = cmd.prefixFileName opts.Tags = cmd.tags - outs, errs := wire.Generate(ctx, wd, os.Environ(), packages(f), opts) + // 递归所有目录 + patterns, err := packages(f), (error)(nil) + if cmd.recursion { + patterns, err = cmd.findAllSubDir(patterns) + } + + if err != nil { + log.Println(err) + return subcommands.ExitFailure + } + + outs, errs := wire.Generate(ctx, wd, os.Environ(), patterns, opts) if len(errs) > 0 { logErrors(errs) log.Println("generate failed")