diff --git a/cmd/root.go b/cmd/root.go index 3728d48d..c74be95b 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -95,7 +95,7 @@ func initConfig() { viper.SetEnvPrefix("ssosync") viper.AutomaticEnv() - for _, e := range []string{"google_admin", "google_credentials", "scim_access_token", "scim_endpoint", "log_level", "log_format", "ignore_users", "ignore_groups"} { + for _, e := range []string{"google_admin", "google_credentials", "scim_access_token", "scim_endpoint", "log_level", "log_format", "ignore_users", "ignore_groups", "allow_groups", "allow_pattern"} { if err := viper.BindEnv(e); err != nil { log.Fatalf(errors.Wrap(err, "cannot bind environment variable").Error()) } @@ -154,6 +154,8 @@ func addFlags(cmd *cobra.Command, cfg *config.Config) { rootCmd.Flags().StringVarP(&cfg.GoogleAdmin, "google-admin", "u", "", "Google Admin Email") rootCmd.Flags().StringSliceVar(&cfg.IgnoreUsers, "ignore-users", []string{}, "ignores these users") rootCmd.Flags().StringSliceVar(&cfg.IgnoreGroups, "ignore-groups", []string{}, "ignores these groups") + rootCmd.Flags().StringSliceVar(&cfg.AllowGroups, "allow-groups", []string{}, "allows only these groups (prefixed with this)") + rootCmd.Flags().StringSliceVar(&cfg.AllowPattern, "allow-pattern", []string{}, "pattern necessary for a user email to be allowed") } func logConfig(cfg *config.Config) { diff --git a/internal/config/config.go b/internal/config/config.go index e464435c..a5cdd4b1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -22,6 +22,10 @@ type Config struct { IgnoreUsers []string `mapstructure:"ignore_users"` // Ignore groups ... IgnoreGroups []string `mapstructure:"ignore_groups"` + // Allow Groups + AllowGroups []string `mapstructure:"allow_groups"` + // Allow Pattern + AllowPattern []string `mapstructure:"allow_pattern"` } const ( diff --git a/internal/config/secrets.go b/internal/config/secrets.go index ccc2b6c2..2f906970 100644 --- a/internal/config/secrets.go +++ b/internal/config/secrets.go @@ -2,7 +2,8 @@ package config import ( "encoding/base64" - + "os" + log "github.com/sirupsen/logrus" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/secretsmanager" ) @@ -19,24 +20,33 @@ func NewSecrets(svc *secretsmanager.SecretsManager) *Secrets { } } +func (s *Secrets) getSecretByNameOrEnv(name string, envName string) (string, error) { + secretName := name + if len(os.Getenv(envName)) > 0 { + secretName = os.Getenv(envName) + } + log.Debug("Getting Secret Name: ", secretName) + return s.getSecret(secretName) +} + // GoogleAdminEmail ... func (s *Secrets) GoogleAdminEmail() (string, error) { - return s.getSecret("SSOSyncGoogleAdminEmail") + return s.getSecretByNameOrEnv("SSOSyncGoogleAdminEmail", "SSOSYNC_SECRETS_GOOGLE_EMAIL") } // SCIMAccessToken ... func (s *Secrets) SCIMAccessToken() (string, error) { - return s.getSecret("SSOSyncSCIMAccessToken") + return s.getSecretByNameOrEnv("SSOSyncSCIMAccessToken", "SSOSYNC_SECRETS_SCIM_TOKEN") } // SCIMEndpointUrl ... func (s *Secrets) SCIMEndpointUrl() (string, error) { - return s.getSecret("SSOSyncSCIMEndpointUrl") + return s.getSecretByNameOrEnv("SSOSyncSCIMEndpointUrl", "SSOSYNC_SECRETS_SCIM_URL") } // GoogleCredentials ... func (s *Secrets) GoogleCredentials() (string, error) { - return s.getSecret("SSOSyncGoogleCredentials") + return s.getSecretByNameOrEnv("SSOSyncGoogleCredentials", "SSOSYNC_SECRETS_GOOGLE_CREDENTIALS") } func (s *Secrets) getSecret(secretKey string) (string, error) { diff --git a/internal/sync.go b/internal/sync.go index 38483186..d73f5c9c 100644 --- a/internal/sync.go +++ b/internal/sync.go @@ -16,7 +16,9 @@ package internal import ( "context" + "strings" "io/ioutil" + "regexp" "github.com/awslabs/ssosync/internal/aws" "github.com/awslabs/ssosync/internal/config" @@ -100,6 +102,11 @@ func (s *syncGSuite) SyncUsers() error { continue } + if !s.allowPattern(u.PrimaryEmail) { + log.Debug("Filtered out a user") + continue + } + ll := log.WithFields(log.Fields{ "email": u.PrimaryEmail, }) @@ -156,6 +163,10 @@ func (s *syncGSuite) SyncGroups() error { continue } + if ! s.allowGroup(g.Email) { + continue + } + log := log.WithFields(log.Fields{ "group": g.Email, }) @@ -296,3 +307,34 @@ func (s *syncGSuite) ignoreGroup(name string) bool { return false } + +func (s *syncGSuite) allowGroup(name string) bool { + if len(s.cfg.AllowGroups) == 0 { + return true + } + + for _, g := range s.cfg.AllowGroups { + if strings.HasPrefix(name, g) { + return true + } + } + + return false +} + +func (s *syncGSuite) allowPattern(name string) bool { + if len(s.cfg.AllowGroups) == 0 { + return true + } + for _, p := range s.cfg.AllowPattern { + if p == "" { + return true + } + + re := regexp.MustCompile(p) + if re.FindStringIndex(name) != nil { + return true + } + } + return false +}