diff --git a/sql-migrate/config.go b/sql-migrate/config.go index 67906065..a845e4f8 100644 --- a/sql-migrate/config.go +++ b/sql-migrate/config.go @@ -8,6 +8,7 @@ import ( "io/ioutil" "os" + "github.com/kelseyhightower/envconfig" "github.com/rubenv/sql-migrate" "gopkg.in/gorp.v1" "gopkg.in/yaml.v2" @@ -32,30 +33,34 @@ func ConfigFlags(f *flag.FlagSet) { } type Environment struct { - Dialect string `yaml:"dialect"` - DataSource string `yaml:"datasource"` - Dir string `yaml:"dir"` - TableName string `yaml:"table"` - SchemaName string `yaml:"schema"` + Dialect string `yaml:"dialect" envconfig:"DIALECT"` + DataSource string `yaml:"datasource" envconfig:"DATASOURCE"` + Dir string `yaml:"dir" envconfig:"DIR"` + TableName string `yaml:"table" envconfig:"TABLE"` + SchemaName string `yaml:"schema" envconfig:"SCHEMA"` } -func ReadConfig() (map[string]*Environment, error) { - file, err := ioutil.ReadFile(ConfigFile) - if err != nil { +func ConfigPresent() bool { + _, err := os.Stat(ConfigFile) + return !os.IsNotExist(err) +} + +func ReadEnv() (*Environment, error) { + env := &Environment{} + if err := envconfig.Process("", env); err != nil { return nil, err } + return env, nil +} - config := make(map[string]*Environment) - err = yaml.Unmarshal(file, config) +func ReadEnvFromFile() (*Environment, error) { + file, err := ioutil.ReadFile(ConfigFile) if err != nil { return nil, err } - return config, nil -} - -func GetEnvironment() (*Environment, error) { - config, err := ReadConfig() + config := make(map[string]*Environment) + err = yaml.Unmarshal(file, config) if err != nil { return nil, err } @@ -65,6 +70,25 @@ func GetEnvironment() (*Environment, error) { return nil, errors.New("No environment: " + ConfigEnvironment) } + return env, nil +} + +func GetEnvironment() (*Environment, error) { + var env *Environment + if ConfigPresent() { + var err error + env, err = ReadEnvFromFile() + if err != nil { + return nil, err + } + } else { + var err error + env, err = ReadEnv() + if err != nil { + return nil, err + } + } + if env.Dialect == "" { return nil, errors.New("No dialect specified") }