diff --git a/cli/cmd/ds-load/main.go b/cli/cmd/ds-load/main.go index ac7fca72..d976d1cd 100644 --- a/cli/cmd/ds-load/main.go +++ b/cli/cmd/ds-load/main.go @@ -15,6 +15,7 @@ import ( func main() { pluginEnum := "" + pluginFinder, err := plugin.NewHomeDirFinder(true) if err != nil { os.Stderr.WriteString(err.Error()) @@ -26,9 +27,11 @@ func main() { os.Stderr.WriteString(err.Error()) os.Exit(1) } + for _, p := range plugins { pluginEnum += (p.Name + "|") } + pluginEnum = strings.TrimSuffix(pluginEnum, "|") yamlLoader := kongyaml.NewYAMLResolver("") @@ -52,9 +55,12 @@ func main() { } kongCtx := kong.Parse(&cli, options...) + ctx := cc.NewCommonContext(cli.Verbosity, string(cli.Config)) + if err := kongCtx.Run(ctx); err != nil { kongCtx.FatalIfErrorf(err) } + os.Exit(common.GetExitCode()) } diff --git a/cli/pkg/app/cli.go b/cli/pkg/app/cli.go index 46eefbdd..486747ee 100644 --- a/cli/pkg/app/cli.go +++ b/cli/pkg/app/cli.go @@ -44,6 +44,7 @@ func (listPlugins *ListPluginsCmd) Run(c *cc.CommonCtx) error { if err != nil { return err } + plugins, err := find.Find() if err != nil { return err @@ -52,8 +53,8 @@ func (listPlugins *ListPluginsCmd) Run(c *cc.CommonCtx) error { for _, p := range plugins { os.Stdout.WriteString(p.Name + " " + p.Path + "\n") } - return nil + return nil } type VersionCmd struct{} @@ -63,5 +64,6 @@ func (cmd *VersionCmd) Run(c *cc.CommonCtx) error { constants.AppName, version.GetInfo().String(), ) + return nil } diff --git a/cli/pkg/app/exec.go b/cli/pkg/app/exec.go index 77c34976..202d8faa 100644 --- a/cli/pkg/app/exec.go +++ b/cli/pkg/app/exec.go @@ -32,8 +32,12 @@ type ExecCmd struct { func (e *ExecCmd) Run(c *cc.CommonCtx) error { defaultPrintCmd := []string{"fetch", "version", "export-transform"} - var err error - var find *plugin.Finder + + var ( + err error + find *plugin.Finder + ) + if e.PluginFolder != "" { find = plugin.NewFinder(true, e.PluginFolder) } else { @@ -42,24 +46,29 @@ func (e *ExecCmd) Run(c *cc.CommonCtx) error { return err } } + pl := e.CommandArgs[0] plugins, err := find.Find() if err != nil { return err } + for _, p := range plugins { if pl == p.Name { e.execPlugin = p break } } + if e.execPlugin == nil { return errors.Errorf("plugin [%s] not found", pl) } e.pluginArgs = e.CommandArgs[1:] + var pluginSubCommand string + if len(e.CommandArgs) > 1 { pluginSubCommand = e.CommandArgs[1] } @@ -73,6 +82,7 @@ func (e *ExecCmd) Run(c *cc.CommonCtx) error { if err != nil { return errors.Wrap(err, "Could not connect to the directory") } + e.publisher = publish.NewDirectoryPublisher(c, dirClient) } @@ -83,9 +93,13 @@ func (e *ExecCmd) LaunchPlugin(c *cc.CommonCtx) error { if (!slices.Contains(e.pluginArgs, "-c") || !slices.Contains(e.pluginArgs, "--config")) && c.ConfigPath != "" { e.pluginArgs = append(e.pluginArgs, "-c", c.ConfigPath) } + pluginCmd := exec.Command(e.execPlugin.Path, e.pluginArgs...) //nolint:gosec - var pStdout io.ReadCloser - var wg sync.WaitGroup + + var ( + pStdout io.ReadCloser + wg sync.WaitGroup + ) pStderr, err := pluginCmd.StderrPipe() if err != nil { @@ -94,6 +108,7 @@ func (e *ExecCmd) LaunchPlugin(c *cc.CommonCtx) error { defer pStderr.Close() wg.Add(1) + go listenOnStderr(c, &wg, pStderr) if e.Print { @@ -110,6 +125,7 @@ func (e *ExecCmd) LaunchPlugin(c *cc.CommonCtx) error { if err != nil { return err } + if (fi.Mode() & os.ModeCharDevice) == 0 { pluginCmd.Stdin = os.Stdin } @@ -122,12 +138,14 @@ func (e *ExecCmd) LaunchPlugin(c *cc.CommonCtx) error { if !e.Print { err = e.publisher.Publish(c.Context, pStdout) } + if err != nil { wg.Wait() return err } wg.Wait() + return pluginCmd.Wait() } @@ -154,5 +172,6 @@ func listenOnStderr(c *cc.CommonCtx, wg *sync.WaitGroup, stderr io.ReadCloser) { c.Log.Fatal().Err(err) } } + wg.Done() } diff --git a/cli/pkg/app/publish.go b/cli/pkg/app/publish.go index 87353656..ffadca52 100644 --- a/cli/pkg/app/publish.go +++ b/cli/pkg/app/publish.go @@ -22,6 +22,7 @@ func (l *PublishCmd) Run(commonCtx *cc.CommonCtx) error { if err != nil { return errors.Wrap(err, "Could not connect to the directory") } + publisher = publish.NewDirectoryPublisher(commonCtx, dirClient) return l.processMessagesFromStdIn(commonCtx, publisher) diff --git a/cli/pkg/clients/directory_client.go b/cli/pkg/clients/directory_client.go index caff5de3..e438fdab 100644 --- a/cli/pkg/clients/directory_client.go +++ b/cli/pkg/clients/directory_client.go @@ -71,11 +71,14 @@ func validate(cfg *Config) error { opts := []grpc.DialOption{ grpc.WithUserAgent("ds-load " + version.GetInfo().Version), } + if cfg.Insecure { opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) } + if _, err := grpcurl.BlockingDial(ctx, "tcp", cfg.Host, creds, opts...); err != nil { return err } + return nil } diff --git a/cli/pkg/plugin/finder.go b/cli/pkg/plugin/finder.go index 63f81592..962826f3 100644 --- a/cli/pkg/plugin/finder.go +++ b/cli/pkg/plugin/finder.go @@ -37,9 +37,11 @@ func NewHomeDirFinder(env bool) (*Finder, error) { func (f Finder) Find() ([]*Plugin, error) { addedPlugins := []string{} dirs := f.dirs + if f.env { pathEnv := os.Getenv("PATH") dirs = append(dirs, strings.Split(pathEnv, string(os.PathListSeparator))...) + pwd, err := os.Getwd() if err != nil { fmt.Println(err) @@ -55,6 +57,7 @@ func (f Finder) Find() ([]*Plugin, error) { if err != nil { return nil, err } + if len(files) > 0 { for _, f := range files { p := NewPlugin(f) diff --git a/cli/pkg/plugin/plugin.go b/cli/pkg/plugin/plugin.go index e78c4b8c..01216e2d 100644 --- a/cli/pkg/plugin/plugin.go +++ b/cli/pkg/plugin/plugin.go @@ -23,8 +23,10 @@ func NewPlugin(path string) *Plugin { func pluginName(path string) string { file := filepath.Base(path) name := strings.TrimPrefix(file, constants.PluginPrefix) + if runtime.GOOS == "windows" { name = strings.TrimSuffix(name, ".exe") } + return name } diff --git a/cli/pkg/publish/publisher_v3.go b/cli/pkg/publish/publisher_v3.go index 28404945..db9d6fc8 100644 --- a/cli/pkg/publish/publisher_v3.go +++ b/cli/pkg/publish/publisher_v3.go @@ -47,13 +47,16 @@ func (p *DirectoryPublisher) Publish(ctx context.Context, reader io.Reader) erro for { var message msg.Transform + err := jsonReader.ReadProtoMessage(&message) if err == io.EOF { break } + if err != nil { return err } + err = p.publishMessages(ctx, &message) if err != nil { return err @@ -63,6 +66,7 @@ func (p *DirectoryPublisher) Publish(ctx context.Context, reader io.Reader) erro if p.objCounter != nil { printCounter(os.Stdout, p.objCounter) } + if p.relCounter != nil { printCounter(os.Stdout, p.relCounter) } @@ -77,11 +81,14 @@ func (p *DirectoryPublisher) Publish(ctx context.Context, reader io.Reader) erro func (p *DirectoryPublisher) publishMessages(ctx context.Context, message *msg.Transform) error { errGroup, iCtx := errgroup.WithContext(ctx) + stream, err := p.importerClient.Import(iCtx) if err != nil { return err } + errGroup.Go(p.receiver(stream)) + errGroup.Go(p.doneHandler(stream.Context())) opCode := message.OpCode @@ -95,9 +102,11 @@ func (p *DirectoryPublisher) publishMessages(ctx context.Context, message *msg.T fmt.Fprintf(os.Stderr, "validation failed, object: [%s] type [%s]\n", object.Id, object.Type) continue } + if (opCode == dsi3.Opcode_OPCODE_DELETE || opCode == dsi3.Opcode_OPCODE_DELETE_WITH_RELATIONS) && object.Type == "group" { continue } + fmt.Fprintf(os.Stdout, "object: [%s] type [%s]\n", object.Id, object.Type) sErr := stream.Send(&dsi3.ImportRequest{ Msg: &dsi3.ImportRequest_Object{ @@ -105,6 +114,7 @@ func (p *DirectoryPublisher) publishMessages(ctx context.Context, message *msg.T }, OpCode: opCode, }) + p.handleStreamError(sErr) } @@ -114,6 +124,7 @@ func (p *DirectoryPublisher) publishMessages(ctx context.Context, message *msg.T fmt.Fprintf(os.Stderr, "validation failed, relation: [%s] obj: [%s] subj [%s]\n", relation.Relation, relation.ObjectId, relation.SubjectId) continue } + fmt.Fprintf(os.Stdout, "relation: [%s] obj: [%s] subj [%s]\n", relation.Relation, relation.ObjectId, relation.SubjectId) sErr := stream.Send(&dsi3.ImportRequest{ Msg: &dsi3.ImportRequest_Relation{ @@ -121,6 +132,7 @@ func (p *DirectoryPublisher) publishMessages(ctx context.Context, message *msg.T }, OpCode: opCode, }) + p.handleStreamError(sErr) } @@ -162,6 +174,7 @@ func (p *DirectoryPublisher) receiver(stream dsi3.Importer_ImportClient) func() switch m := result.Msg.(type) { case *dsi3.ImportResponse_Status: p.errs = true + printStatus(os.Stderr, m.Status) case *dsi3.ImportResponse_Counter: switch m.Counter.Type { @@ -179,11 +192,13 @@ func (p *DirectoryPublisher) receiver(stream dsi3.Importer_ImportClient) func() func (p *DirectoryPublisher) doneHandler(ctx context.Context) func() error { return func() error { <-ctx.Done() + err := ctx.Err() if err != nil && !errors.Is(err, context.Canceled) { p.Log.Trace().Err(err).Msg("subscriber-doneHandler") return err } + return nil } } diff --git a/plugins/auth0/pkg/app/cli.go b/plugins/auth0/pkg/app/cli.go index 0dfef32c..cd63d006 100644 --- a/plugins/auth0/pkg/app/cli.go +++ b/plugins/auth0/pkg/app/cli.go @@ -18,13 +18,13 @@ type CLI struct { Verify VerifyCmd `cmd:"verify" help:"verify fetcher configuration and credentials"` } -type VersionCmd struct { -} +type VersionCmd struct{} func (cmd *VersionCmd) Run() error { fmt.Printf("%s - %s\n", AppName, version.GetInfo().String(), ) + return nil } diff --git a/plugins/auth0/pkg/app/exec.go b/plugins/auth0/pkg/app/exec.go index b277c767..bf00c58d 100644 --- a/plugins/auth0/pkg/app/exec.go +++ b/plugins/auth0/pkg/app/exec.go @@ -29,12 +29,15 @@ func (cmd *ExecCmd) Run(ctx *cc.CommonCtx) error { if err != nil { return err } + fetcher = fetcher.WithUserPID(cmd.UserPID).WithEmail(cmd.UserEmail).WithRoles(cmd.Roles) templateContent, err := cmd.getTemplateContent() if err != nil { return err } + transformer := transform.NewGoTemplateTransform(templateContent) + return exec.Execute(ctx.Context, ctx.Log, transformer, fetcher) } diff --git a/plugins/auth0/pkg/app/export_transform.go b/plugins/auth0/pkg/app/export_transform.go index 4c0655cd..3c28331c 100644 --- a/plugins/auth0/pkg/app/export_transform.go +++ b/plugins/auth0/pkg/app/export_transform.go @@ -7,14 +7,14 @@ import ( "github.com/aserto-dev/ds-load/sdk/transform" ) -type ExportTransformCmd struct { -} +type ExportTransformCmd struct{} func (t *ExportTransformCmd) Run(ctx *cc.CommonCtx) error { templateContent, err := Assets().ReadFile("assets/transform_template.tmpl") if err != nil { return err } + transformer := transform.NewGoTemplateTransform(templateContent) return transformer.ExportTransform(os.Stdout) diff --git a/plugins/auth0/pkg/app/fetch.go b/plugins/auth0/pkg/app/fetch.go index 8f402ac4..fe039f5d 100644 --- a/plugins/auth0/pkg/app/fetch.go +++ b/plugins/auth0/pkg/app/fetch.go @@ -36,6 +36,7 @@ func (f *FetchCmd) Run(ctx *cc.CommonCtx) error { if err != nil { return err } + fetcher = fetcher.WithUserPID(f.UserPID).WithEmail(f.UserEmail).WithRoles(f.Roles).WithOrgs(f.Orgs).WithSAML(f.SAML) return fetcher.Fetch(ctx.Context, os.Stdout, os.Stderr) diff --git a/plugins/auth0/pkg/app/transform.go b/plugins/auth0/pkg/app/transform.go index f0c49a58..be5e5ba8 100644 --- a/plugins/auth0/pkg/app/transform.go +++ b/plugins/auth0/pkg/app/transform.go @@ -22,6 +22,7 @@ func (t *TransformCmd) Run(ctx *cc.CommonCtx) error { } goTemplateTransformer := transform.NewGoTemplateTransform(templateContent) + return t.transform(ctx.Context, goTemplateTransformer) } @@ -36,6 +37,7 @@ func (t *TransformCmd) getTemplateContent() ([]byte, error) { } templateLoader := template.NewTemplateLoader(templateContent) + templateContent, err = templateLoader.Load(t.Template) if err != nil { return nil, err diff --git a/plugins/auth0/pkg/fetch/fetch.go b/plugins/auth0/pkg/fetch/fetch.go index 60e2b0ba..7889c2db 100644 --- a/plugins/auth0/pkg/fetch/fetch.go +++ b/plugins/auth0/pkg/fetch/fetch.go @@ -82,41 +82,39 @@ func (f *Fetcher) fetchUsers(ctx context.Context, outputWriter *js.JSONArrayWrit users, more, err := f.getUsers(ctx, opts) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) return err } for _, user := range users { res, err := user.MarshalJSON() if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) continue } + var obj map[string]interface{} - err = json.Unmarshal(res, &obj) - if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + if err := json.Unmarshal(res, &obj); err != nil { + common.WriteErrorWithExitCode(errorWriter, err, 1) continue } + obj["email_verified"] = user.GetEmailVerified() obj["object_type"] = "user" + if f.Roles { roles, err := f.getUserRoles(ctx, *user.ID) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) } else { obj["roles"] = roles } } + if f.Orgs { orgs, err := f.getOrgs(ctx, *user.ID) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) } else { obj["orgs"] = orgs } @@ -127,9 +125,11 @@ func (f *Fetcher) fetchUsers(ctx context.Context, outputWriter *js.JSONArrayWrit return err } } + if !more { break } + page++ } @@ -141,35 +141,37 @@ func (f *Fetcher) fetchGroups(ctx context.Context, outputWriter *js.JSONArrayWri for f.Roles { opts := []management.RequestOption{management.Page(page)} + if f.ConnectionName != "" { opts = append(opts, management.Query(f.getConnectionQuery())) } roles, more, err := f.getRoles(ctx, opts) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) return err } for _, role := range roles { res := role.String() + var obj map[string]interface{} - err = json.Unmarshal([]byte(res), &obj) - if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + + if err := json.Unmarshal([]byte(res), &obj); err != nil { + common.WriteErrorWithExitCode(errorWriter, err, 1) continue } + obj["object_type"] = "role" - err = outputWriter.Write(obj) - if err != nil { + if err := outputWriter.Write(obj); err != nil { return err } } + if !more { break } + page++ } @@ -187,34 +189,41 @@ func (f *Fetcher) getUsers(ctx context.Context, opts []management.RequestOption) if err != nil { return nil, false, err } + if user == nil { return nil, false, errors.Wrapf(err, "failed to get user by pid %s", f.UserPID) } + return []*management.User{user}, false, nil - } else if f.UserEmail != "" { + } + + if f.UserEmail != "" { // List only users that have the provided email users, err := f.client.Mgmt.User.ListByEmail(ctx, f.UserEmail) if err != nil { return nil, false, err } + return users, false, nil - } else { - // List all users - if !f.SAML { - userList, err := f.client.Mgmt.User.List(ctx, opts...) - if err != nil { - return nil, false, err - } - return userList.Users, userList.HasNext(), nil - } else { - // Use special SAML user list, to avoid known unmarshal errors, see notes below. - ul := &UserList{} - if err := ListUsers(ctx, f.client.Mgmt, &ul, opts...); err != nil { - return nil, false, err - } - return ul.UserList(), ul.HasNext(), nil + } + + // List all users + if !f.SAML { + userList, err := f.client.Mgmt.User.List(ctx, opts...) + if err != nil { + return nil, false, err } + + return userList.Users, userList.HasNext(), nil + } + + // Use special SAML user list, to avoid known unmarshal errors, see notes below. + ul := &UserList{} + if err := ListUsers(ctx, f.client.Mgmt, &ul, opts...); err != nil { + return nil, false, err } + + return ul.UserList(), ul.HasNext(), nil } func (f *Fetcher) getRoles(ctx context.Context, opts []management.RequestOption) ([]*management.Role, bool, error) { @@ -222,6 +231,7 @@ func (f *Fetcher) getRoles(ctx context.Context, opts []management.RequestOption) if err != nil { return nil, false, err } + if roles == nil { return nil, false, errors.Wrap(err, "failed to get roles") } @@ -241,28 +251,33 @@ func (f *Fetcher) getUserRoles(ctx context.Context, uID string) ([]map[string]in } reqOpts := management.Page(page) + roles, err := f.client.Mgmt.User.Roles(ctx, uID, reqOpts) if err != nil { return nil, err } + for _, role := range roles.Roles { res, err := json.Marshal(role) if err != nil { return nil, err } + var obj map[string]interface{} - err = json.Unmarshal(res, &obj) - if err != nil { + if err := json.Unmarshal(res, &obj); err != nil { return nil, err } + results = append(results, obj) } + if !roles.HasNext() { finished = true } page++ } + return results, nil } @@ -278,22 +293,26 @@ func (f *Fetcher) getOrgs(ctx context.Context, uID string) ([]map[string]interfa } reqOpts := management.Page(page) + orgs, err := f.client.Mgmt.User.Organizations(ctx, uID, reqOpts) if err != nil { return nil, err } + for _, org := range orgs.Organizations { res, err := json.Marshal(org) if err != nil { return nil, err } + var obj map[string]interface{} - err = json.Unmarshal(res, &obj) - if err != nil { + if err := json.Unmarshal(res, &obj); err != nil { return nil, err } + results = append(results, obj) } + if !orgs.HasNext() { finished = true } @@ -308,6 +327,7 @@ func (f *Fetcher) getConnectionQuery() string { if f.ConnectionName == "" { return "" } + return `identities.connection:"` + f.ConnectionName + `"` } @@ -358,11 +378,13 @@ func (u *User) UnmarshalJSON(b []byte) error { } type tTmpUser User + var tmpUser tTmpUser if err := json.Unmarshal(buf, &tmpUser); err != nil { return err } + tmpUser.VerifyEmail = &verified *u = User(tmpUser) @@ -398,7 +420,7 @@ func ListUsers(ctx context.Context, m *management.Management, payload interface{ } if len(responseBody) > 0 && string(responseBody) != "{}" { - if err = json.Unmarshal(responseBody, &payload); err != nil { + if err := json.Unmarshal(responseBody, &payload); err != nil { return fmt.Errorf("failed to unmarshal response payload: %w", err) } } diff --git a/plugins/auth0/pkg/httpclient/client.go b/plugins/auth0/pkg/httpclient/client.go index 4378e457..000ed4a4 100644 --- a/plugins/auth0/pkg/httpclient/client.go +++ b/plugins/auth0/pkg/httpclient/client.go @@ -18,6 +18,7 @@ func (rl RateLimit) Wait() { if rl.ResetTime.Before(time.Now()) { return } + duration := time.Until(rl.ResetTime) time.Sleep(duration) } @@ -37,10 +38,12 @@ func (c *Transport) RoundTrip(r *http.Request) (*http.Response, error) { if err != nil || resp.StatusCode >= 400 { return resp, err } + rl, err := parseRateLimit(resp) if err != nil { return resp, err } + c.rateLimiter = rl return resp, err @@ -63,6 +66,7 @@ func parseRateLimit(resp *http.Response) (*RateLimit, error) { if err != nil { return nil, errors.Wrapf(err, "failed to parse X-RateLimit-Reset header") } + rl.ResetTime = time.Unix(int64(reset), 0) return &rl, nil diff --git a/plugins/auth0/pkg/verify/verifier.go b/plugins/auth0/pkg/verify/verifier.go index dc96d9a4..1a3d68cf 100644 --- a/plugins/auth0/pkg/verify/verifier.go +++ b/plugins/auth0/pkg/verify/verifier.go @@ -16,7 +16,6 @@ func New(ctx context.Context, client *auth0client.Auth0Client) (*Verifier, error return &Verifier{ client: client, }, nil - } func (v *Verifier) Verify(ctx context.Context) error { diff --git a/plugins/azuread/cmd/ds-load-azuread/main.go b/plugins/azuread/cmd/ds-load-azuread/main.go index bdb65bce..bbeb2b07 100644 --- a/plugins/azuread/cmd/ds-load-azuread/main.go +++ b/plugins/azuread/cmd/ds-load-azuread/main.go @@ -33,7 +33,9 @@ func main() { } ctx := cc.NewCommonContext(cli.Verbosity, string(cli.Config)) + kongCtx := kong.Parse(&cli, options...) + if err := kongCtx.Run(ctx); err != nil { kongCtx.FatalIfErrorf(err) } diff --git a/plugins/azuread/pkg/app/cli.go b/plugins/azuread/pkg/app/cli.go index 78c62536..73dc6cb2 100644 --- a/plugins/azuread/pkg/app/cli.go +++ b/plugins/azuread/pkg/app/cli.go @@ -20,14 +20,14 @@ type CLI struct { Verify VerifyCmd `cmd:"verify" help:"verify fetcher configuration and credentials"` } -type VersionCmd struct { -} +type VersionCmd struct{} func (cmd *VersionCmd) Run() error { fmt.Printf("%s - %s\n", AppName, version.GetInfo().String(), ) + return nil } diff --git a/plugins/azuread/pkg/app/exec.go b/plugins/azuread/pkg/app/exec.go index 26a88d5b..85e2dd82 100644 --- a/plugins/azuread/pkg/app/exec.go +++ b/plugins/azuread/pkg/app/exec.go @@ -27,6 +27,8 @@ func (cmd *ExecCmd) Run(ctx *cc.CommonCtx) error { if err != nil { return err } + transformer := transform.NewGoTemplateTransform(templateContent) + return exec.Execute(ctx.Context, ctx.Log, transformer, fetcher.WithGroups(cmd.Groups)) } diff --git a/plugins/azuread/pkg/app/export_transform.go b/plugins/azuread/pkg/app/export_transform.go index 4c0655cd..3c28331c 100644 --- a/plugins/azuread/pkg/app/export_transform.go +++ b/plugins/azuread/pkg/app/export_transform.go @@ -7,14 +7,14 @@ import ( "github.com/aserto-dev/ds-load/sdk/transform" ) -type ExportTransformCmd struct { -} +type ExportTransformCmd struct{} func (t *ExportTransformCmd) Run(ctx *cc.CommonCtx) error { templateContent, err := Assets().ReadFile("assets/transform_template.tmpl") if err != nil { return err } + transformer := transform.NewGoTemplateTransform(templateContent) return transformer.ExportTransform(os.Stdout) diff --git a/plugins/azuread/pkg/app/transform.go b/plugins/azuread/pkg/app/transform.go index 2cd7f116..8e7501da 100644 --- a/plugins/azuread/pkg/app/transform.go +++ b/plugins/azuread/pkg/app/transform.go @@ -21,6 +21,7 @@ func (t *TransformCmd) Run(ctx *cc.CommonCtx) error { } goTemplateTransformer := transform.NewGoTemplateTransform(template) + return t.transform(ctx.Context, goTemplateTransformer) } @@ -29,8 +30,11 @@ func (t *TransformCmd) transform(ctx context.Context, transformer plugin.Transfo } func (t *TransformCmd) getTemplateContent() ([]byte, error) { - var templateContent []byte - var err error + var ( + templateContent []byte + err error + ) + if t.Template == "" { templateContent, err = Assets().ReadFile("assets/transform_template.tmpl") if err != nil { @@ -42,5 +46,6 @@ func (t *TransformCmd) getTemplateContent() ([]byte, error) { return nil, err } } + return templateContent, nil } diff --git a/plugins/azuread/pkg/azureclient/azure.go b/plugins/azuread/pkg/azureclient/azure.go index a1f12872..5c52b763 100644 --- a/plugins/azuread/pkg/azureclient/azure.go +++ b/plugins/azuread/pkg/azureclient/azure.go @@ -37,6 +37,7 @@ func NewAzureADClient(ctx context.Context, tenant, clientID, clientSecret string if err != nil { return nil, err } + return c, nil } @@ -52,6 +53,7 @@ func NewAzureADClientWithRefreshToken(ctx context.Context, tenant, clientID, cli if err != nil { return nil, err } + return c, nil } @@ -76,6 +78,7 @@ func (c *AzureADClient) GetUserByEmail(ctx context.Context, email string, groups filter := fmt.Sprintf("userPrincipalName eq '%s'", email) return c.listUsers(ctx, filter, groups) } + return azureadUsers, err } @@ -176,5 +179,6 @@ func (c *AzureADClient) initClient(credential azcore.TokenCredential) error { // Create a Graph client using request adapter c.appClient = msgraphsdk.NewMsgraph(adapter) c.requestAdaptor = adapter + return nil } diff --git a/plugins/azuread/pkg/azureclient/credential.go b/plugins/azuread/pkg/azureclient/credential.go index 7df5b470..c8e83da2 100644 --- a/plugins/azuread/pkg/azureclient/credential.go +++ b/plugins/azuread/pkg/azureclient/credential.go @@ -29,6 +29,7 @@ func NewRefreshTokenCredential(ctx context.Context, tenantID, clientID, clientSe tenantID: tenantID, refreshToken: refreshToken, } + return c, nil } @@ -47,6 +48,7 @@ func (c *RefreshTokenCredential) GetToken(ctx context.Context, options policy.To } req.Header.Add("content-type", "application/x-www-form-urlencoded") + res, err := http.DefaultClient.Do(req) if err != nil { return accessToken, err @@ -54,12 +56,13 @@ func (c *RefreshTokenCredential) GetToken(ctx context.Context, options policy.To // process the response defer res.Body.Close() + var responseData map[string]interface{} + body, _ := io.ReadAll(res.Body) // unmarshal the json into a string map - err = json.Unmarshal(body, &responseData) - if err != nil { + if err := json.Unmarshal(body, &responseData); err != nil { return accessToken, err } @@ -73,5 +76,6 @@ func (c *RefreshTokenCredential) GetToken(ctx context.Context, options policy.To accessToken.Token = responseData["access_token"].(string) expiresIn := int(responseData["expires_in"].(float64)) accessToken.ExpiresOn = time.Now().Add(time.Second * time.Duration(expiresIn)) + return accessToken, nil } diff --git a/plugins/azuread/pkg/fetch/fetch.go b/plugins/azuread/pkg/fetch/fetch.go index 6385d7e9..2a8b012b 100644 --- a/plugins/azuread/pkg/fetch/fetch.go +++ b/plugins/azuread/pkg/fetch/fetch.go @@ -34,35 +34,33 @@ func (f *Fetcher) Fetch(ctx context.Context, outputWriter, errorWriter io.Writer if f.Groups { aadGroups, err := f.azureClient.ListGroups(ctx) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) } for _, group := range aadGroups { writer := kiota.NewJsonSerializationWriter() + err := group.Serialize(writer) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) return err } + groupBytes, err := writer.GetSerializedContent() if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) return err } groupString := "{" + string(groupBytes) + "}" + var obj map[string]interface{} - err = json.Unmarshal([]byte(groupString), &obj) - if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + if err := json.Unmarshal([]byte(groupString), &obj); err != nil { + common.WriteErrorWithExitCode(errorWriter, err, 1) return err } - err = jsonWriter.Write(obj) - if err != nil { + + if err := jsonWriter.Write(obj); err != nil { _, _ = errorWriter.Write([]byte(err.Error())) } } @@ -70,35 +68,33 @@ func (f *Fetcher) Fetch(ctx context.Context, outputWriter, errorWriter io.Writer aadUsers, err := f.azureClient.ListUsers(ctx, f.Groups) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) } for _, user := range aadUsers { writer := kiota.NewJsonSerializationWriter() + err := user.Serialize(writer) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) return err } + userBytes, err := writer.GetSerializedContent() if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) return err } userString := "{" + string(userBytes) + "}" + var obj map[string]interface{} - err = json.Unmarshal([]byte(userString), &obj) - if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + if err := json.Unmarshal([]byte(userString), &obj); err != nil { + common.WriteErrorWithExitCode(errorWriter, err, 1) return err } - err = jsonWriter.Write(obj) - if err != nil { + + if err := jsonWriter.Write(obj); err != nil { _, _ = errorWriter.Write([]byte(err.Error())) } } diff --git a/plugins/azuread/pkg/verify/verifier.go b/plugins/azuread/pkg/verify/verifier.go index 4f2daa03..8bdb3bdc 100644 --- a/plugins/azuread/pkg/verify/verifier.go +++ b/plugins/azuread/pkg/verify/verifier.go @@ -16,7 +16,6 @@ func New(ctx context.Context, client *azureclient.AzureADClient) (*Verifier, err return &Verifier{ client: client, }, nil - } func (v *Verifier) WithGroups(groups bool) *Verifier { diff --git a/plugins/azureadb2c/cmd/ds-load-azureadb2c/main.go b/plugins/azureadb2c/cmd/ds-load-azureadb2c/main.go index 8661dc5d..4611700b 100644 --- a/plugins/azureadb2c/cmd/ds-load-azureadb2c/main.go +++ b/plugins/azureadb2c/cmd/ds-load-azureadb2c/main.go @@ -34,6 +34,7 @@ func main() { ctx := cc.NewCommonContext(cli.Verbosity, string(cli.Config)) kongCtx := kong.Parse(&cli, options...) + if err := kongCtx.Run(ctx); err != nil { kongCtx.FatalIfErrorf(err) } diff --git a/plugins/azureadb2c/pkg/app/cli.go b/plugins/azureadb2c/pkg/app/cli.go index c3c158a9..2153cf0f 100644 --- a/plugins/azureadb2c/pkg/app/cli.go +++ b/plugins/azureadb2c/pkg/app/cli.go @@ -20,14 +20,14 @@ type CLI struct { Verify VerifyCmd `cmd:"verify" help:"verify fetcher configuration and credentials"` } -type VersionCmd struct { -} +type VersionCmd struct{} func (cmd *VersionCmd) Run() error { fmt.Printf("%s - %s\n", AppName, version.GetInfo().String(), ) + return nil } diff --git a/plugins/azureadb2c/pkg/app/exec.go b/plugins/azureadb2c/pkg/app/exec.go index 93de35cf..53e47dda 100644 --- a/plugins/azureadb2c/pkg/app/exec.go +++ b/plugins/azureadb2c/pkg/app/exec.go @@ -27,6 +27,8 @@ func (cmd *ExecCmd) Run(ctx *cc.CommonCtx) error { if err != nil { return err } + transformer := transform.NewGoTemplateTransform(templateContent) + return exec.Execute(ctx.Context, ctx.Log, transformer, fetcher.WithGroups(cmd.Groups)) } diff --git a/plugins/azureadb2c/pkg/app/export_transform.go b/plugins/azureadb2c/pkg/app/export_transform.go index 4c0655cd..3c28331c 100644 --- a/plugins/azureadb2c/pkg/app/export_transform.go +++ b/plugins/azureadb2c/pkg/app/export_transform.go @@ -7,14 +7,14 @@ import ( "github.com/aserto-dev/ds-load/sdk/transform" ) -type ExportTransformCmd struct { -} +type ExportTransformCmd struct{} func (t *ExportTransformCmd) Run(ctx *cc.CommonCtx) error { templateContent, err := Assets().ReadFile("assets/transform_template.tmpl") if err != nil { return err } + transformer := transform.NewGoTemplateTransform(templateContent) return transformer.ExportTransform(os.Stdout) diff --git a/plugins/azureadb2c/pkg/app/transform.go b/plugins/azureadb2c/pkg/app/transform.go index 2cd7f116..8e7501da 100644 --- a/plugins/azureadb2c/pkg/app/transform.go +++ b/plugins/azureadb2c/pkg/app/transform.go @@ -21,6 +21,7 @@ func (t *TransformCmd) Run(ctx *cc.CommonCtx) error { } goTemplateTransformer := transform.NewGoTemplateTransform(template) + return t.transform(ctx.Context, goTemplateTransformer) } @@ -29,8 +30,11 @@ func (t *TransformCmd) transform(ctx context.Context, transformer plugin.Transfo } func (t *TransformCmd) getTemplateContent() ([]byte, error) { - var templateContent []byte - var err error + var ( + templateContent []byte + err error + ) + if t.Template == "" { templateContent, err = Assets().ReadFile("assets/transform_template.tmpl") if err != nil { @@ -42,5 +46,6 @@ func (t *TransformCmd) getTemplateContent() ([]byte, error) { return nil, err } } + return templateContent, nil } diff --git a/plugins/azureadb2c/pkg/azureclient/azure.go b/plugins/azureadb2c/pkg/azureclient/azure.go index 8f82fdc2..d382ba51 100644 --- a/plugins/azureadb2c/pkg/azureclient/azure.go +++ b/plugins/azureadb2c/pkg/azureclient/azure.go @@ -37,6 +37,7 @@ func NewAzureADClient(ctx context.Context, tenant, clientID, clientSecret string if err != nil { return nil, err } + return c, nil } @@ -52,6 +53,7 @@ func NewAzureADClientWithRefreshToken(ctx context.Context, tenant, clientID, cli if err != nil { return nil, err } + return c, nil } @@ -76,6 +78,7 @@ func (c *AzureADClient) GetUserByEmail(ctx context.Context, email string, groups filter := fmt.Sprintf("userPrincipalName eq '%s'", email) return c.listUsers(ctx, filter, groups) } + return azureadUsers, err } @@ -111,8 +114,10 @@ func (c *AzureADClient) ListGroups(ctx context.Context) ([]models.Groupable, err if err != nil { return false } + group.SetMembers(members) result = append(result, group) + return true }) if err != nil { @@ -217,9 +222,12 @@ func (c *AzureADClient) listUsers(ctx context.Context, filter string, groups boo if err != nil { return false } + user.SetMemberOf(members) } + result = append(result, user) + return true }) if err != nil { @@ -246,5 +254,6 @@ func (c *AzureADClient) initClient(credential azcore.TokenCredential) error { // Create a Graph client using request adapter c.appClient = msgraphsdk.NewMsgraph(adapter) c.requestAdaptor = adapter + return nil } diff --git a/plugins/azureadb2c/pkg/azureclient/credential.go b/plugins/azureadb2c/pkg/azureclient/credential.go index 7df5b470..c8e83da2 100644 --- a/plugins/azureadb2c/pkg/azureclient/credential.go +++ b/plugins/azureadb2c/pkg/azureclient/credential.go @@ -29,6 +29,7 @@ func NewRefreshTokenCredential(ctx context.Context, tenantID, clientID, clientSe tenantID: tenantID, refreshToken: refreshToken, } + return c, nil } @@ -47,6 +48,7 @@ func (c *RefreshTokenCredential) GetToken(ctx context.Context, options policy.To } req.Header.Add("content-type", "application/x-www-form-urlencoded") + res, err := http.DefaultClient.Do(req) if err != nil { return accessToken, err @@ -54,12 +56,13 @@ func (c *RefreshTokenCredential) GetToken(ctx context.Context, options policy.To // process the response defer res.Body.Close() + var responseData map[string]interface{} + body, _ := io.ReadAll(res.Body) // unmarshal the json into a string map - err = json.Unmarshal(body, &responseData) - if err != nil { + if err := json.Unmarshal(body, &responseData); err != nil { return accessToken, err } @@ -73,5 +76,6 @@ func (c *RefreshTokenCredential) GetToken(ctx context.Context, options policy.To accessToken.Token = responseData["access_token"].(string) expiresIn := int(responseData["expires_in"].(float64)) accessToken.ExpiresOn = time.Now().Add(time.Second * time.Duration(expiresIn)) + return accessToken, nil } diff --git a/plugins/azureadb2c/pkg/fetch/fetch.go b/plugins/azureadb2c/pkg/fetch/fetch.go index 4ae70de2..7b7e7bc5 100644 --- a/plugins/azureadb2c/pkg/fetch/fetch.go +++ b/plugins/azureadb2c/pkg/fetch/fetch.go @@ -34,35 +34,33 @@ func (f *Fetcher) Fetch(ctx context.Context, outputWriter, errorWriter io.Writer if f.Groups { aadGroups, err := f.azureClient.ListGroups(ctx) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) } for _, group := range aadGroups { writer := kiota.NewJsonSerializationWriter() + err := group.Serialize(writer) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) return err } + groupBytes, err := writer.GetSerializedContent() if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) return err } groupString := "{" + string(groupBytes) + "}" + var obj map[string]interface{} - err = json.Unmarshal([]byte(groupString), &obj) - if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + if err := json.Unmarshal([]byte(groupString), &obj); err != nil { + common.WriteErrorWithExitCode(errorWriter, err, 1) return err } - err = jsonWriter.Write(obj) - if err != nil { + + if err := jsonWriter.Write(obj); err != nil { _, _ = errorWriter.Write([]byte(err.Error())) } } @@ -70,35 +68,33 @@ func (f *Fetcher) Fetch(ctx context.Context, outputWriter, errorWriter io.Writer aadUsers, err := f.azureClient.ListUsers(ctx, f.Groups) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) } for _, user := range aadUsers { writer := kiota.NewJsonSerializationWriter() + err := user.Serialize(writer) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) return err } + userBytes, err := writer.GetSerializedContent() if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) return err } userString := "{" + string(userBytes) + "}" + var obj map[string]interface{} - err = json.Unmarshal([]byte(userString), &obj) - if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + if err := json.Unmarshal([]byte(userString), &obj); err != nil { + common.WriteErrorWithExitCode(errorWriter, err, 1) return err } - err = jsonWriter.Write(obj) - if err != nil { + + if err := jsonWriter.Write(obj); err != nil { _, _ = errorWriter.Write([]byte(err.Error())) } } diff --git a/plugins/azureadb2c/pkg/verify/verifier.go b/plugins/azureadb2c/pkg/verify/verifier.go index a3a6124d..8f3aa7f4 100644 --- a/plugins/azureadb2c/pkg/verify/verifier.go +++ b/plugins/azureadb2c/pkg/verify/verifier.go @@ -17,7 +17,6 @@ func New(ctx context.Context, client *azureclient.AzureADClient) (*Verifier, err return &Verifier{ client: client, }, nil - } func (v *Verifier) WithGroups(groups bool) *Verifier { diff --git a/plugins/cognito/pkg/app/cli.go b/plugins/cognito/pkg/app/cli.go index 3d0812dd..3ad6567b 100644 --- a/plugins/cognito/pkg/app/cli.go +++ b/plugins/cognito/pkg/app/cli.go @@ -18,13 +18,13 @@ type CLI struct { Verify VerifyCmd `cmd:"verify" help:"verify fetcher configuration and credentials"` } -type VersionCmd struct { -} +type VersionCmd struct{} func (cmd *VersionCmd) Run() error { fmt.Printf("%s - %s\n", AppName, version.GetInfo().String(), ) + return nil } diff --git a/plugins/cognito/pkg/app/exec.go b/plugins/cognito/pkg/app/exec.go index 27fef1bd..a6872ba9 100644 --- a/plugins/cognito/pkg/app/exec.go +++ b/plugins/cognito/pkg/app/exec.go @@ -14,7 +14,6 @@ type ExecCmd struct { } func (cmd *ExecCmd) Run(ctx *cc.CommonCtx) error { - cognitoClient, err := cognitoclient.NewCognitoClient(cmd.AccessKey, cmd.SecretKey, cmd.UserPoolID, cmd.Region) if err != nil { return err @@ -24,12 +23,15 @@ func (cmd *ExecCmd) Run(ctx *cc.CommonCtx) error { if err != nil { return err } + fetcher = fetcher.WithGroups(cmd.Groups) templateContent, err := cmd.getTemplateContent() if err != nil { return err } + transformer := transform.NewGoTemplateTransform(templateContent) + return exec.Execute(ctx.Context, ctx.Log, transformer, fetcher) } diff --git a/plugins/cognito/pkg/app/export_transform.go b/plugins/cognito/pkg/app/export_transform.go index 4c0655cd..3c28331c 100644 --- a/plugins/cognito/pkg/app/export_transform.go +++ b/plugins/cognito/pkg/app/export_transform.go @@ -7,14 +7,14 @@ import ( "github.com/aserto-dev/ds-load/sdk/transform" ) -type ExportTransformCmd struct { -} +type ExportTransformCmd struct{} func (t *ExportTransformCmd) Run(ctx *cc.CommonCtx) error { templateContent, err := Assets().ReadFile("assets/transform_template.tmpl") if err != nil { return err } + transformer := transform.NewGoTemplateTransform(templateContent) return transformer.ExportTransform(os.Stdout) diff --git a/plugins/cognito/pkg/app/fetch.go b/plugins/cognito/pkg/app/fetch.go index 8ab72d7a..974fa013 100644 --- a/plugins/cognito/pkg/app/fetch.go +++ b/plugins/cognito/pkg/app/fetch.go @@ -26,6 +26,7 @@ func (cmd *FetchCmd) Run(ctx *cc.CommonCtx) error { if err != nil { return err } + fetcher = fetcher.WithGroups(cmd.Groups) return fetcher.Fetch(ctx.Context, os.Stdout, os.Stderr) diff --git a/plugins/cognito/pkg/app/transform.go b/plugins/cognito/pkg/app/transform.go index 2cd7f116..8e7501da 100644 --- a/plugins/cognito/pkg/app/transform.go +++ b/plugins/cognito/pkg/app/transform.go @@ -21,6 +21,7 @@ func (t *TransformCmd) Run(ctx *cc.CommonCtx) error { } goTemplateTransformer := transform.NewGoTemplateTransform(template) + return t.transform(ctx.Context, goTemplateTransformer) } @@ -29,8 +30,11 @@ func (t *TransformCmd) transform(ctx context.Context, transformer plugin.Transfo } func (t *TransformCmd) getTemplateContent() ([]byte, error) { - var templateContent []byte - var err error + var ( + templateContent []byte + err error + ) + if t.Template == "" { templateContent, err = Assets().ReadFile("assets/transform_template.tmpl") if err != nil { @@ -42,5 +46,6 @@ func (t *TransformCmd) getTemplateContent() ([]byte, error) { return nil, err } } + return templateContent, nil } diff --git a/plugins/cognito/pkg/cognitoclient/cognito.go b/plugins/cognito/pkg/cognitoclient/cognito.go index 0d94087f..30507001 100644 --- a/plugins/cognito/pkg/cognitoclient/cognito.go +++ b/plugins/cognito/pkg/cognitoclient/cognito.go @@ -37,9 +37,10 @@ func NewCognitoClient(accessKey, secretKey, userPoolID, region string) (*Cognito } func (c *CognitoClient) ListUsers(ctx context.Context) ([]*cognitoidentityprovider.UserType, error) { - users := make([]*cognitoidentityprovider.UserType, 0) var paginationToken *string + users := make([]*cognitoidentityprovider.UserType, 0) + for { listUsersInput := &cognitoidentityprovider.ListUsersInput{ UserPoolId: aws.String(c.userPoolID), diff --git a/plugins/cognito/pkg/fetch/fetch.go b/plugins/cognito/pkg/fetch/fetch.go index 71320d1f..95e88eb9 100644 --- a/plugins/cognito/pkg/fetch/fetch.go +++ b/plugins/cognito/pkg/fetch/fetch.go @@ -6,6 +6,7 @@ import ( "io" "github.com/aserto-dev/ds-load/plugins/cognito/pkg/cognitoclient" + "github.com/aserto-dev/ds-load/sdk/common" "github.com/aserto-dev/ds-load/sdk/common/js" ) @@ -38,17 +39,16 @@ func (f *Fetcher) Fetch(ctx context.Context, outputWriter, errorWriter io.Writer for _, group := range groups { groupBytes, err := json.Marshal(group) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) + common.WriteErrorWithExitCode(errorWriter, err, 1) return err } + var obj map[string]interface{} - err = json.Unmarshal(groupBytes, &obj) - if err != nil { + if err := json.Unmarshal(groupBytes, &obj); err != nil { _, _ = errorWriter.Write([]byte(err.Error())) } - err = writer.Write(obj) - if err != nil { + if err := writer.Write(obj); err != nil { _, _ = errorWriter.Write([]byte(err.Error())) } } @@ -67,14 +67,15 @@ func (f *Fetcher) Fetch(ctx context.Context, outputWriter, errorWriter io.Writer userBytes, err := json.Marshal(user) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) + common.WriteErrorWithExitCode(errorWriter, err, 1) return err } + var obj map[string]interface{} - err = json.Unmarshal(userBytes, &obj) - if err != nil { + if err := json.Unmarshal(userBytes, &obj); err != nil { _, _ = errorWriter.Write([]byte(err.Error())) } + obj["Attributes"] = attributes if f.groups { @@ -89,16 +90,16 @@ func (f *Fetcher) Fetch(ctx context.Context, outputWriter, errorWriter io.Writer _, _ = errorWriter.Write([]byte(err.Error())) return err } + var grps []map[string]string - err = json.Unmarshal(groupBytes, &grps) - if err != nil { + if err := json.Unmarshal(groupBytes, &grps); err != nil { _, _ = errorWriter.Write([]byte(err.Error())) } + obj["Groups"] = grps } - err = writer.Write(obj) - if err != nil { + if err := writer.Write(obj); err != nil { _, _ = errorWriter.Write([]byte(err.Error())) } } diff --git a/plugins/cognito/pkg/verify/verifier.go b/plugins/cognito/pkg/verify/verifier.go index 7ed3b14f..4e884da2 100644 --- a/plugins/cognito/pkg/verify/verifier.go +++ b/plugins/cognito/pkg/verify/verifier.go @@ -15,7 +15,6 @@ func New(ctx context.Context, client *cognitoclient.CognitoClient) (*Verifier, e return &Verifier{ client: client, }, nil - } func (v *Verifier) Verify(ctx context.Context) error { diff --git a/plugins/fusionauth/pkg/app/cli.go b/plugins/fusionauth/pkg/app/cli.go index f58baf71..26494eb0 100644 --- a/plugins/fusionauth/pkg/app/cli.go +++ b/plugins/fusionauth/pkg/app/cli.go @@ -18,13 +18,13 @@ type CLI struct { Verify VerifyCmd `cmd:"verify" help:"verify fetcher configuration and credentials"` } -type VersionCmd struct { -} +type VersionCmd struct{} func (cmd *VersionCmd) Run() error { fmt.Printf("%s - %s\n", AppName, version.GetInfo().String(), ) + return nil } diff --git a/plugins/fusionauth/pkg/app/exec.go b/plugins/fusionauth/pkg/app/exec.go index e530b95b..3122a501 100644 --- a/plugins/fusionauth/pkg/app/exec.go +++ b/plugins/fusionauth/pkg/app/exec.go @@ -14,7 +14,6 @@ type ExecCmd struct { } func (cmd *ExecCmd) Run(ctx *cc.CommonCtx) error { - fusionauthClient, err := fusionauthclient.NewFusionAuthClient(cmd.HostURL, cmd.APIKey) if err != nil { return err @@ -24,12 +23,15 @@ func (cmd *ExecCmd) Run(ctx *cc.CommonCtx) error { if err != nil { return err } + fetcher = fetcher.WithGroups(cmd.Groups).WithHost(cmd.HostURL) templateContent, err := cmd.getTemplateContent() if err != nil { return err } + transformer := transform.NewGoTemplateTransform(templateContent) + return exec.Execute(ctx.Context, ctx.Log, transformer, fetcher) } diff --git a/plugins/fusionauth/pkg/app/export_transform.go b/plugins/fusionauth/pkg/app/export_transform.go index 4c0655cd..3c28331c 100644 --- a/plugins/fusionauth/pkg/app/export_transform.go +++ b/plugins/fusionauth/pkg/app/export_transform.go @@ -7,14 +7,14 @@ import ( "github.com/aserto-dev/ds-load/sdk/transform" ) -type ExportTransformCmd struct { -} +type ExportTransformCmd struct{} func (t *ExportTransformCmd) Run(ctx *cc.CommonCtx) error { templateContent, err := Assets().ReadFile("assets/transform_template.tmpl") if err != nil { return err } + transformer := transform.NewGoTemplateTransform(templateContent) return transformer.ExportTransform(os.Stdout) diff --git a/plugins/fusionauth/pkg/app/fetch.go b/plugins/fusionauth/pkg/app/fetch.go index b5fbc1f9..baf4b456 100644 --- a/plugins/fusionauth/pkg/app/fetch.go +++ b/plugins/fusionauth/pkg/app/fetch.go @@ -24,6 +24,7 @@ func (cmd *FetchCmd) Run(ctx *cc.CommonCtx) error { if err != nil { return err } + fetcher = fetcher.WithGroups(cmd.Groups).WithHost(cmd.HostURL) return fetcher.Fetch(ctx.Context, os.Stdout, os.Stderr) diff --git a/plugins/fusionauth/pkg/app/transform.go b/plugins/fusionauth/pkg/app/transform.go index 2cd7f116..8e7501da 100644 --- a/plugins/fusionauth/pkg/app/transform.go +++ b/plugins/fusionauth/pkg/app/transform.go @@ -21,6 +21,7 @@ func (t *TransformCmd) Run(ctx *cc.CommonCtx) error { } goTemplateTransformer := transform.NewGoTemplateTransform(template) + return t.transform(ctx.Context, goTemplateTransformer) } @@ -29,8 +30,11 @@ func (t *TransformCmd) transform(ctx context.Context, transformer plugin.Transfo } func (t *TransformCmd) getTemplateContent() ([]byte, error) { - var templateContent []byte - var err error + var ( + templateContent []byte + err error + ) + if t.Template == "" { templateContent, err = Assets().ReadFile("assets/transform_template.tmpl") if err != nil { @@ -42,5 +46,6 @@ func (t *TransformCmd) getTemplateContent() ([]byte, error) { return nil, err } } + return templateContent, nil } diff --git a/plugins/fusionauth/pkg/fetch/fetch.go b/plugins/fusionauth/pkg/fetch/fetch.go index 03cd03bd..a6477945 100644 --- a/plugins/fusionauth/pkg/fetch/fetch.go +++ b/plugins/fusionauth/pkg/fetch/fetch.go @@ -43,23 +43,24 @@ func (f *Fetcher) Fetch(ctx context.Context, outputWriter, errorWriter io.Writer for i := range users { user := &users[i] + userBytes, err := json.Marshal(user) if err != nil { _, _ = errorWriter.Write([]byte(err.Error())) return err } + var obj map[string]interface{} - err = json.Unmarshal(userBytes, &obj) - if err != nil { + if err := json.Unmarshal(userBytes, &obj); err != nil { _, _ = errorWriter.Write([]byte(err.Error())) return err } + if user.ImageUrl != "" { obj["picture"] = fmt.Sprintf("%s%s", f.host, user.ImageUrl) } - err = writer.Write(obj) - if err != nil { + if err := writer.Write(obj); err != nil { _, _ = errorWriter.Write([]byte(err.Error())) } } @@ -72,8 +73,7 @@ func (f *Fetcher) Fetch(ctx context.Context, outputWriter, errorWriter io.Writer for i := range groups { group := &groups[i] - err = writer.Write(group) - if err != nil { + if err := writer.Write(group); err != nil { _, _ = errorWriter.Write([]byte(err.Error())) } } diff --git a/plugins/fusionauth/pkg/fusionauthclient/fusionauth.go b/plugins/fusionauth/pkg/fusionauthclient/fusionauth.go index f7d2a989..17bed314 100644 --- a/plugins/fusionauth/pkg/fusionauthclient/fusionauth.go +++ b/plugins/fusionauth/pkg/fusionauthclient/fusionauth.go @@ -18,14 +18,15 @@ type FusionAuthClient struct { func NewFusionAuthClient(host, apiKey string) (*FusionAuthClient, error) { c := &FusionAuthClient{} - var httpClient = &http.Client{ + httpClient := &http.Client{ Timeout: time.Second * 10, } - var baseURL, _ = url.Parse(host) + baseURL, _ := url.Parse(host) c.fusionauthClient = fusionauth.NewClient(httpClient, baseURL, apiKey) c.host = host + return c, nil } @@ -52,6 +53,7 @@ func (c *FusionAuthClient) ListUsers(ctx context.Context) ([]fusionauth.User, er fmt.Println("Failed to list users:", err) return nil, err } + if faErrs != nil { fmt.Println("Failed to list users:", faErrs) return nil, faErrs @@ -64,6 +66,7 @@ func (c *FusionAuthClient) ListUsers(ctx context.Context) ([]fusionauth.User, er break } } + return users, nil } diff --git a/plugins/fusionauth/pkg/verify/verifier.go b/plugins/fusionauth/pkg/verify/verifier.go index a74ee741..61a78279 100644 --- a/plugins/fusionauth/pkg/verify/verifier.go +++ b/plugins/fusionauth/pkg/verify/verifier.go @@ -15,7 +15,6 @@ func New(ctx context.Context, client *fusionauthclient.FusionAuthClient) (*Verif return &Verifier{ client: client, }, nil - } func (v *Verifier) Verify(ctx context.Context) error { diff --git a/plugins/google/cmd/ds-load-google/main.go b/plugins/google/cmd/ds-load-google/main.go index 61ac09d1..f35cb554 100644 --- a/plugins/google/cmd/ds-load-google/main.go +++ b/plugins/google/cmd/ds-load-google/main.go @@ -36,6 +36,7 @@ func main() { } ctx := cc.NewCommonContext(cli.Verbosity, string(cli.Config)) + kongCtx := kong.Parse(&cli, options...) if err := kongCtx.Run(ctx); err != nil { kongCtx.FatalIfErrorf(err) diff --git a/plugins/google/pkg/app/cli.go b/plugins/google/pkg/app/cli.go index 5cf812e4..6c09b927 100644 --- a/plugins/google/pkg/app/cli.go +++ b/plugins/google/pkg/app/cli.go @@ -19,13 +19,13 @@ type CLI struct { Verify VerifyCmd `cmd:"verify" help:"verify fetcher configuration and credentials"` } -type VersionCmd struct { -} +type VersionCmd struct{} func (cmd *VersionCmd) Run() error { fmt.Printf("%s - %s\n", AppName, version.GetInfo().String(), ) + return nil } diff --git a/plugins/google/pkg/app/exec.go b/plugins/google/pkg/app/exec.go index 5393f586..2c54565f 100644 --- a/plugins/google/pkg/app/exec.go +++ b/plugins/google/pkg/app/exec.go @@ -23,12 +23,15 @@ func (cmd *ExecCmd) Run(ctx *cc.CommonCtx) error { if err != nil { return err } + fetcher = fetcher.WithGroups(cmd.Groups) templateContent, err := cmd.getTemplateContent() if err != nil { return err } + transformer := transform.NewGoTemplateTransform(templateContent) + return exec.Execute(ctx.Context, ctx.Log, transformer, fetcher) } diff --git a/plugins/google/pkg/app/export_transform.go b/plugins/google/pkg/app/export_transform.go index 4c0655cd..3c28331c 100644 --- a/plugins/google/pkg/app/export_transform.go +++ b/plugins/google/pkg/app/export_transform.go @@ -7,14 +7,14 @@ import ( "github.com/aserto-dev/ds-load/sdk/transform" ) -type ExportTransformCmd struct { -} +type ExportTransformCmd struct{} func (t *ExportTransformCmd) Run(ctx *cc.CommonCtx) error { templateContent, err := Assets().ReadFile("assets/transform_template.tmpl") if err != nil { return err } + transformer := transform.NewGoTemplateTransform(templateContent) return transformer.ExportTransform(os.Stdout) diff --git a/plugins/google/pkg/app/fetch.go b/plugins/google/pkg/app/fetch.go index f828593a..ac443052 100644 --- a/plugins/google/pkg/app/fetch.go +++ b/plugins/google/pkg/app/fetch.go @@ -17,7 +17,6 @@ type FetchCmd struct { } func (cmd *FetchCmd) Run(ctx *cc.CommonCtx) error { - gClient, err := googleclient.NewGoogleClient(ctx.Context, cmd.ClientID, cmd.ClientSecret, cmd.RefreshToken, cmd.Customer) if err != nil { return err @@ -27,6 +26,7 @@ func (cmd *FetchCmd) Run(ctx *cc.CommonCtx) error { if err != nil { return err } + fetcher = fetcher.WithGroups(cmd.Groups) return fetcher.Fetch(ctx.Context, os.Stdout, os.Stderr) diff --git a/plugins/google/pkg/app/get_token.go b/plugins/google/pkg/app/get_token.go index d417319e..08f32404 100644 --- a/plugins/google/pkg/app/get_token.go +++ b/plugins/google/pkg/app/get_token.go @@ -24,5 +24,6 @@ func (cmd *GetTokenCmd) Run(ctx *cc.CommonCtx) error { } fmt.Println("Refresh token: ", refreshToken) + return nil } diff --git a/plugins/google/pkg/app/transform.go b/plugins/google/pkg/app/transform.go index 2cd7f116..8e7501da 100644 --- a/plugins/google/pkg/app/transform.go +++ b/plugins/google/pkg/app/transform.go @@ -21,6 +21,7 @@ func (t *TransformCmd) Run(ctx *cc.CommonCtx) error { } goTemplateTransformer := transform.NewGoTemplateTransform(template) + return t.transform(ctx.Context, goTemplateTransformer) } @@ -29,8 +30,11 @@ func (t *TransformCmd) transform(ctx context.Context, transformer plugin.Transfo } func (t *TransformCmd) getTemplateContent() ([]byte, error) { - var templateContent []byte - var err error + var ( + templateContent []byte + err error + ) + if t.Template == "" { templateContent, err = Assets().ReadFile("assets/transform_template.tmpl") if err != nil { @@ -42,5 +46,6 @@ func (t *TransformCmd) getTemplateContent() ([]byte, error) { return nil, err } } + return templateContent, nil } diff --git a/plugins/google/pkg/fetch/fetch.go b/plugins/google/pkg/fetch/fetch.go index a3234173..e289082b 100644 --- a/plugins/google/pkg/fetch/fetch.go +++ b/plugins/google/pkg/fetch/fetch.go @@ -32,28 +32,24 @@ func (f *Fetcher) Fetch(ctx context.Context, outputWriter, errorWriter io.Writer users, err := f.gClient.ListUsers() if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) return err } for _, user := range users { userBytes, err := json.Marshal(user) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) continue } + var obj map[string]interface{} - err = json.Unmarshal(userBytes, &obj) - if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + if err := json.Unmarshal(userBytes, &obj); err != nil { + common.WriteErrorWithExitCode(errorWriter, err, 1) continue } - err = writer.Write(obj) - if err != nil { + if err := writer.Write(obj); err != nil { _, _ = errorWriter.Write([]byte(err.Error())) } } @@ -61,47 +57,41 @@ func (f *Fetcher) Fetch(ctx context.Context, outputWriter, errorWriter io.Writer if f.Groups { groups, err := f.gClient.ListGroups() if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) return err } for _, group := range groups { groupBytes, err := json.Marshal(group) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) continue } + var obj map[string]interface{} - err = json.Unmarshal(groupBytes, &obj) - if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + if err := json.Unmarshal(groupBytes, &obj); err != nil { + common.WriteErrorWithExitCode(errorWriter, err, 1) continue } usersInGroup, err := f.gClient.GetUsersInGroup(group.Id) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) } else { usersInGroupBytes, err := json.Marshal(usersInGroup) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) } else { var users []map[string]interface{} - err = json.Unmarshal(usersInGroupBytes, &users) - if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + if err := json.Unmarshal(usersInGroupBytes, &users); err != nil { + common.WriteErrorWithExitCode(errorWriter, err, 1) } + obj["users"] = users } } - err = writer.Write(obj) - if err != nil { + + if err := writer.Write(obj); err != nil { _, _ = errorWriter.Write([]byte(err.Error())) } } diff --git a/plugins/google/pkg/googleclient/google.go b/plugins/google/pkg/googleclient/google.go index e92edf1d..ee1e032f 100644 --- a/plugins/google/pkg/googleclient/google.go +++ b/plugins/google/pkg/googleclient/google.go @@ -42,10 +42,14 @@ func GetRefreshToken(ctx context.Context, clientID, clientSecret string, port in // Create an HTTP server for handling the OAuth 2.0 callback server := &http.Server{Addr: fmt.Sprintf(":%d", port), ReadHeaderTimeout: 5 * time.Second} + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { code := r.URL.Query().Get("code") + fmt.Fprintf(w, "Authorization code received. You can close this tab now.") + authCode = code + go func() { // Shutdown the HTTP server once the callback is received if err := server.Shutdown(ctx); err != nil { @@ -93,6 +97,7 @@ func NewGoogleClient(ctx context.Context, clientID, clientSecret, refreshToken, c.googleClient = svc c.customer = customer + return c, nil } @@ -105,6 +110,7 @@ func (c *GoogleClient) ListUsers() ([]*admin.User, error) { if err != nil { return nil, err } + users = append(users, response.Users...) if response.NextPageToken == "" { @@ -126,6 +132,7 @@ func (c *GoogleClient) ListGroups() ([]*admin.Group, error) { if err != nil { return nil, err } + groups = append(groups, response.Groups...) if response.NextPageToken == "" { @@ -147,6 +154,7 @@ func (c *GoogleClient) GetUsersInGroup(group string) ([]*admin.Member, error) { if err != nil { return nil, err } + members = append(members, response.Members...) if response.NextPageToken == "" { diff --git a/plugins/google/pkg/verify/verifier.go b/plugins/google/pkg/verify/verifier.go index 972952e8..745a3b22 100644 --- a/plugins/google/pkg/verify/verifier.go +++ b/plugins/google/pkg/verify/verifier.go @@ -15,7 +15,6 @@ func New(ctx context.Context, client *googleclient.GoogleClient) (*Verifier, err return &Verifier{ client: client, }, nil - } func (v *Verifier) Verify(ctx context.Context) error { diff --git a/plugins/jumpcloud/cmd/ds-load-jumpcloud/main.go b/plugins/jumpcloud/cmd/ds-load-jumpcloud/main.go index 11902614..3f399f1c 100644 --- a/plugins/jumpcloud/cmd/ds-load-jumpcloud/main.go +++ b/plugins/jumpcloud/cmd/ds-load-jumpcloud/main.go @@ -37,6 +37,7 @@ func main() { ctx := cc.NewCommonContext(cli.Verbosity, string(cli.Config)) kongCtx := kong.Parse(&cli, options...) + if err := kongCtx.Run(ctx); err != nil { kongCtx.FatalIfErrorf(err) } diff --git a/plugins/jumpcloud/pkg/app/assets/transform_template.tmpl b/plugins/jumpcloud/pkg/app/assets/transform_template.tmpl index 7105c5f5..c03f72e6 100644 --- a/plugins/jumpcloud/pkg/app/assets/transform_template.tmpl +++ b/plugins/jumpcloud/pkg/app/assets/transform_template.tmpl @@ -11,30 +11,34 @@ "type": "user", "display_name": "{{ $.firstname }} {{ $.middlename -}} {{ $.lastname }}", "properties": { - {{ fromEnv "connection_id" "ASERTO_CONNECTION_ID" }}, + "enabled": "{{ not $.account_locked }}", "email": "{{ $.email }}", - "organization_id": "{{ $.organization }}", - "status": "{{ $status }}", "user_id": "{{ $.id }}", - "username": "{{ $.username }}" + "username": "{{ $.username }}", + "manager": "{{ $.manager }}", + "organization": "{{ $.company }}", + "department": "{{ $.department }}", + "title": "{{ $.jobTitle }}", + {{ range $i, $attr := $.attributes }} + {{ if eq $attr.name "roles" }} + "{{ $attr.name }}": {{ splitList "," $attr.value | marshal }}, + {{ else }} + "{{ $attr.name }}": "{{ $attr.value }}", + {{ end }} + {{ end }} + "status": "{{ $status }}" }, "created_at": "{{ $.created }}" }, { "id": "{{ $.email }}", "type": "identity", - "display_name": "{{ $.firstname }} {{ $.middlename -}} {{ $.lastname }} (email)", - "properties": { - {{ fromEnv "connection_id" "ASERTO_CONNECTION_ID" }} - } + "display_name": "{{ $.firstname }} {{ $.middlename -}} {{ $.lastname }} (email)" }, { "id": "{{ $.username }}", "type": "identity", - "display_name": "{{ $.firstname }} {{ $.middlename -}} {{ $.lastname }} (username)", - "properties": { - {{ fromEnv "connection_id" "ASERTO_CONNECTION_ID" }} - } + "display_name": "{{ $.firstname }} {{ $.middlename -}} {{ $.lastname }} (username)" } {{ end }} @@ -42,10 +46,7 @@ { "id": "{{ $.name }}", "type": "group", - "display_name": "{{ $.name }}", - "properties": { - {{ fromEnv "connection_id" "ASERTO_CONNECTION_ID" }} - } + "display_name": "{{ $.name }}" } {{ end }} ], @@ -65,6 +66,15 @@ "subject_type": "identity", "subject_id": "{{ $.username }}" } + {{ if $.manager }} + ,{ + "object_type": "user", + "object_id": "{{ $.id }}", + "relation": "manager", + "subject_type": "user", + "subject_id": "{{ $.manager }}" + } + {{ end }} {{ end }} {{ if eq $.type "user_group" }} diff --git a/plugins/jumpcloud/pkg/app/cli.go b/plugins/jumpcloud/pkg/app/cli.go index 66d83b64..e1600ac1 100644 --- a/plugins/jumpcloud/pkg/app/cli.go +++ b/plugins/jumpcloud/pkg/app/cli.go @@ -25,5 +25,6 @@ func (cmd *VersionCmd) Run() error { AppName, version.GetInfo().String(), ) + return nil } diff --git a/plugins/jumpcloud/pkg/app/exec.go b/plugins/jumpcloud/pkg/app/exec.go index 837dffcc..18f8f784 100644 --- a/plugins/jumpcloud/pkg/app/exec.go +++ b/plugins/jumpcloud/pkg/app/exec.go @@ -23,12 +23,15 @@ func (cmd *ExecCmd) Run(ctx *cc.CommonCtx) error { if err != nil { return err } + fetcher = fetcher.WithGroups(cmd.Groups) templateContent, err := cmd.getTemplateContent() if err != nil { return err } + transformer := transform.NewGoTemplateTransform(templateContent) + return exec.Execute(ctx.Context, ctx.Log, transformer, fetcher) } diff --git a/plugins/jumpcloud/pkg/app/export_transform.go b/plugins/jumpcloud/pkg/app/export_transform.go index 4c0655cd..3c28331c 100644 --- a/plugins/jumpcloud/pkg/app/export_transform.go +++ b/plugins/jumpcloud/pkg/app/export_transform.go @@ -7,14 +7,14 @@ import ( "github.com/aserto-dev/ds-load/sdk/transform" ) -type ExportTransformCmd struct { -} +type ExportTransformCmd struct{} func (t *ExportTransformCmd) Run(ctx *cc.CommonCtx) error { templateContent, err := Assets().ReadFile("assets/transform_template.tmpl") if err != nil { return err } + transformer := transform.NewGoTemplateTransform(templateContent) return transformer.ExportTransform(os.Stdout) diff --git a/plugins/jumpcloud/pkg/app/fetch.go b/plugins/jumpcloud/pkg/app/fetch.go index 22b9329b..e975fed6 100644 --- a/plugins/jumpcloud/pkg/app/fetch.go +++ b/plugins/jumpcloud/pkg/app/fetch.go @@ -23,6 +23,7 @@ func (cmd *FetchCmd) Run(ctx *cc.CommonCtx) error { if err != nil { return err } + fetcher = fetcher.WithGroups(cmd.Groups) return fetcher.Fetch(ctx.Context, os.Stdout, os.Stderr) diff --git a/plugins/jumpcloud/pkg/app/transform.go b/plugins/jumpcloud/pkg/app/transform.go index 2cd7f116..8e7501da 100644 --- a/plugins/jumpcloud/pkg/app/transform.go +++ b/plugins/jumpcloud/pkg/app/transform.go @@ -21,6 +21,7 @@ func (t *TransformCmd) Run(ctx *cc.CommonCtx) error { } goTemplateTransformer := transform.NewGoTemplateTransform(template) + return t.transform(ctx.Context, goTemplateTransformer) } @@ -29,8 +30,11 @@ func (t *TransformCmd) transform(ctx context.Context, transformer plugin.Transfo } func (t *TransformCmd) getTemplateContent() ([]byte, error) { - var templateContent []byte - var err error + var ( + templateContent []byte + err error + ) + if t.Template == "" { templateContent, err = Assets().ReadFile("assets/transform_template.tmpl") if err != nil { @@ -42,5 +46,6 @@ func (t *TransformCmd) getTemplateContent() ([]byte, error) { return nil, err } } + return templateContent, nil } diff --git a/plugins/jumpcloud/pkg/fetch/fetch.go b/plugins/jumpcloud/pkg/fetch/fetch.go index d2784ff2..76ca7a35 100644 --- a/plugins/jumpcloud/pkg/fetch/fetch.go +++ b/plugins/jumpcloud/pkg/fetch/fetch.go @@ -32,8 +32,7 @@ func (f *Fetcher) Fetch(ctx context.Context, outputWriter, errorWriter io.Writer users, err := f.jcc.ListUsers(ctx) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) return err } @@ -42,20 +41,18 @@ func (f *Fetcher) Fetch(ctx context.Context, outputWriter, errorWriter io.Writer for _, user := range users { userBytes, err := json.Marshal(user) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) continue } + var obj map[string]interface{} - err = json.Unmarshal(userBytes, &obj) - if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + + if err := json.Unmarshal(userBytes, &obj); err != nil { + common.WriteErrorWithExitCode(errorWriter, err, 1) continue } - err = writer.Write(obj) - if err != nil { + if err := writer.Write(obj); err != nil { _, _ = errorWriter.Write([]byte(err.Error())) } @@ -65,47 +62,41 @@ func (f *Fetcher) Fetch(ctx context.Context, outputWriter, errorWriter io.Writer if f.Groups { groups, err := f.jcc.ListGroups(ctx, jc.UserGroups) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) return err } for _, group := range groups { groupBytes, err := json.Marshal(group) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) continue } + var obj map[string]interface{} - err = json.Unmarshal(groupBytes, &obj) - if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + if err := json.Unmarshal(groupBytes, &obj); err != nil { + common.WriteErrorWithExitCode(errorWriter, err, 1) continue } usersInGroup, err := f.jcc.ExpandUsersInGroup(ctx, group.ID, idLookup) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) } else { usersInGroupBytes, err := json.Marshal(usersInGroup) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) } else { var users []map[string]interface{} - err = json.Unmarshal(usersInGroupBytes, &users) - if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + if err := json.Unmarshal(usersInGroupBytes, &users); err != nil { + common.WriteErrorWithExitCode(errorWriter, err, 1) } + obj["users"] = users } } - err = writer.Write(obj) - if err != nil { + + if err := writer.Write(obj); err != nil { _, _ = errorWriter.Write([]byte(err.Error())) } } diff --git a/plugins/jumpcloud/pkg/jc/jc.go b/plugins/jumpcloud/pkg/jc/jc.go index 8d7e5d81..e760f609 100644 --- a/plugins/jumpcloud/pkg/jc/jc.go +++ b/plugins/jumpcloud/pkg/jc/jc.go @@ -19,6 +19,7 @@ import ( const ( baseURL string = "https://console.jumpcloud.com/api" apiKeyHeader string = "x-api-key" + batchSize int = 50 ) type JumpCloudClient struct { @@ -83,38 +84,18 @@ const ( ) func (c *JumpCloudClient) ListGroups(ctx context.Context, groupType GroupType) ([]*Group, error) { - var url string + var fullURL string + switch groupType { case AllGroups: - url = baseURL + "/v2/groups" + fullURL = baseURL + "/v2/groups" case SystemGroups: - url = baseURL + "/v2/groups?filter=type:eq:system_group" + fullURL = baseURL + "/v2/groups?filter=type:eq:system_group" case UserGroups: - url = baseURL + "/v2/groups?filter=type:eq:user_group" + fullURL = baseURL + "/v2/groups?filter=type:eq:user_group" } - groups := []*Group{} - - if err := makeHTTPRequest(ctx, url, http.MethodGet, c.headers, nil, nil, &groups); err != nil { - return nil, err - } - - lo.ForEach(groups, func(item *Group, index int) { item.Name = strings.ReplaceAll(item.Name, " ", "_") }) - - return groups, nil -} - -const batchSize int = 10 - -type Members []struct { - To struct { - Type string `json:"type"` - ID string `json:"id"` - } `json:"to"` -} - -func (c *JumpCloudClient) GetUsersInGroup(ctx context.Context, groupID string) ([]*BaseUser, error) { - u, err := url.Parse(baseURL + "/v2/usergroups/" + groupID + "/members") + u, err := url.Parse(fullURL) if err != nil { return nil, err } @@ -125,46 +106,35 @@ func (c *JumpCloudClient) GetUsersInGroup(ctx context.Context, groupID string) ( u.RawQuery = qv.Encode() - members := []struct { - To struct { - ID string `json:"id"` - Type string `json:"type"` - Attributes any `json:"attributes"` - } - Attributes any `json:"attributes"` - }{} - - idList := []string{} + groups := []*Group{} for { - if err := makeHTTPRequest(ctx, u.String(), http.MethodGet, c.headers, nil, nil, &members); err != nil { + resp := []*Group{} + if err := makeHTTPRequest(ctx, u.String(), http.MethodGet, c.headers, nil, nil, &resp); err != nil { return nil, err } - for _, v := range members { - idList = append(idList, v.To.ID) - } + lo.ForEach(resp, func(item *Group, index int) { item.Name = strings.ReplaceAll(item.Name, " ", "_") }) - if len(members) != batchSize { + groups = append(groups, resp...) + + if len(resp) != batchSize { break } qv := u.Query() - qv.Set("skip", strconv.FormatInt(int64(len(idList)), 10)) + qv.Set("skip", strconv.FormatInt(int64(len(groups)), 10)) u.RawQuery = qv.Encode() } - users := []*BaseUser{} - - for _, id := range idList { - user, err := c.GetBaseUserByID(ctx, id) - if err != nil { - return nil, err - } - users = append(users, user) - } + return groups, nil +} - return users, nil +type Members []struct { + To struct { + Type string `json:"type"` + ID string `json:"id"` + } `json:"to"` } func (c *JumpCloudClient) ExpandUsersInGroup(ctx context.Context, groupID string, idLookup map[string]*BaseUser) ([]*BaseUser, error) { @@ -218,6 +188,7 @@ func (c *JumpCloudClient) ExpandUsersInGroup(ctx context.Context, groupID string if err != nil { return nil, err } + users = append(users, user) } } @@ -272,6 +243,7 @@ func makeHTTPRequest[T any](ctx context.Context, reqURL, method string, headers for k, v := range queryParams { q.Set(k, strings.Join(v, ",")) } + u.RawQuery = q.Encode() } @@ -303,8 +275,7 @@ func makeHTTPRequest[T any](ctx context.Context, reqURL, method string, headers return errors.Wrapf(ErrStatusNotOK, "req: %s status: %s response: %s", u.String(), res.Status, buf) } - err = json.Unmarshal(buf, &resp) - if err != nil { + if err := json.Unmarshal(buf, &resp); err != nil { return err } diff --git a/plugins/jumpcloud/pkg/jc/jc_test.go b/plugins/jumpcloud/pkg/jc/jc_test.go index 96825e3e..57e8711b 100644 --- a/plugins/jumpcloud/pkg/jc/jc_test.go +++ b/plugins/jumpcloud/pkg/jc/jc_test.go @@ -6,10 +6,8 @@ import ( "fmt" "os" "testing" - "time" "github.com/aserto-dev/ds-load/plugins/jumpcloud/pkg/jc" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -26,12 +24,11 @@ func TestMain(m *testing.M) { } func TestListDirectories(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() jcc, err := jc.NewJumpCloudClient(ctx, os.Getenv("JC_API_KEY")) require.NoError(t, err) - assert.NoError(t, err) directories, err := jcc.ListDirectories(ctx) require.NoError(t, err) @@ -39,90 +36,75 @@ func TestListDirectories(t *testing.T) { enc := json.NewEncoder(os.Stderr) enc.SetEscapeHTML(false) enc.SetIndent("", " ") + if err := enc.Encode(directories); err != nil { require.NoError(t, err) } } func TestListUsers(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() jcc, err := jc.NewJumpCloudClient(ctx, os.Getenv("JC_API_KEY")) require.NoError(t, err) - assert.NoError(t, err) users, err := jcc.ListUsers(ctx) require.NoError(t, err) + enc := json.NewEncoder(os.Stderr) enc.SetEscapeHTML(false) enc.SetIndent("", " ") + if err := enc.Encode(users); err != nil { require.NoError(t, err) } } func TestListGroups(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() jcc, err := jc.NewJumpCloudClient(ctx, os.Getenv("JC_API_KEY")) require.NoError(t, err) - assert.NoError(t, err) groups, err := jcc.ListGroups(ctx, jc.AllGroups) require.NoError(t, err) + enc := json.NewEncoder(os.Stderr) enc.SetEscapeHTML(false) enc.SetIndent("", " ") + if err := enc.Encode(groups); err != nil { require.NoError(t, err) } } func TestGetSystemGroups(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() jcc, err := jc.NewJumpCloudClient(ctx, os.Getenv("JC_API_KEY")) require.NoError(t, err) - assert.NoError(t, err) groups, err := jcc.ListGroups(ctx, jc.SystemGroups) require.NoError(t, err) - enc := json.NewEncoder(os.Stderr) - enc.SetEscapeHTML(false) - enc.SetIndent("", " ") - if err := enc.Encode(groups); err != nil { - require.NoError(t, err) - } -} - -func TestGetUserGroups(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - jcc, err := jc.NewJumpCloudClient(ctx, os.Getenv("JC_API_KEY")) - require.NoError(t, err) - assert.NoError(t, err) - groups, err := jcc.ListGroups(ctx, jc.UserGroups) - require.NoError(t, err) enc := json.NewEncoder(os.Stderr) enc.SetEscapeHTML(false) enc.SetIndent("", " ") + if err := enc.Encode(groups); err != nil { require.NoError(t, err) } } -func TestGetMembersOfGroup(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) +func TestGetUserGroups(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) defer cancel() jcc, err := jc.NewJumpCloudClient(ctx, os.Getenv("JC_API_KEY")) require.NoError(t, err) - assert.NoError(t, err) groups, err := jcc.ListGroups(ctx, jc.UserGroups) require.NoError(t, err) @@ -131,17 +113,13 @@ func TestGetMembersOfGroup(t *testing.T) { enc.SetEscapeHTML(false) enc.SetIndent("", " ") - for _, group := range groups { - users, err := jcc.GetUsersInGroup(ctx, group.ID) + if err := enc.Encode(groups); err != nil { require.NoError(t, err) - if err := enc.Encode(users); err != nil { - require.NoError(t, err) - } } } func TestExpandMembersOfGroup(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() jcc, err := jc.NewJumpCloudClient(ctx, os.Getenv("JC_API_KEY")) @@ -165,6 +143,7 @@ func TestExpandMembersOfGroup(t *testing.T) { for _, group := range groups { users, err := jcc.ExpandUsersInGroup(ctx, group.ID, idLookup) require.NoError(t, err) + if err := enc.Encode(users); err != nil { require.NoError(t, err) } diff --git a/plugins/ldap/pkg/app/cli.go b/plugins/ldap/pkg/app/cli.go index d93b6a3f..c547fb5f 100644 --- a/plugins/ldap/pkg/app/cli.go +++ b/plugins/ldap/pkg/app/cli.go @@ -18,13 +18,13 @@ type CLI struct { Verify VerifyCmd `cmd:"verify" help:"verify fetcher configuration and credentials"` } -type VersionCmd struct { -} +type VersionCmd struct{} func (cmd *VersionCmd) Run() error { fmt.Printf("%s - %s\n", AppName, version.GetInfo().String(), ) + return nil } diff --git a/plugins/ldap/pkg/app/exec.go b/plugins/ldap/pkg/app/exec.go index 47318af0..bd1111a8 100644 --- a/plugins/ldap/pkg/app/exec.go +++ b/plugins/ldap/pkg/app/exec.go @@ -40,6 +40,8 @@ func (cmd *ExecCmd) Run(ctx *cc.CommonCtx) error { if err != nil { return err } + transformer := transform.NewGoTemplateTransform(templateContent) + return exec.Execute(ctx.Context, ctx.Log, transformer, fetcher) } diff --git a/plugins/ldap/pkg/app/export_transform.go b/plugins/ldap/pkg/app/export_transform.go index 4c0655cd..3c28331c 100644 --- a/plugins/ldap/pkg/app/export_transform.go +++ b/plugins/ldap/pkg/app/export_transform.go @@ -7,14 +7,14 @@ import ( "github.com/aserto-dev/ds-load/sdk/transform" ) -type ExportTransformCmd struct { -} +type ExportTransformCmd struct{} func (t *ExportTransformCmd) Run(ctx *cc.CommonCtx) error { templateContent, err := Assets().ReadFile("assets/transform_template.tmpl") if err != nil { return err } + transformer := transform.NewGoTemplateTransform(templateContent) return transformer.ExportTransform(os.Stdout) diff --git a/plugins/ldap/pkg/app/transform.go b/plugins/ldap/pkg/app/transform.go index f0c49a58..be5e5ba8 100644 --- a/plugins/ldap/pkg/app/transform.go +++ b/plugins/ldap/pkg/app/transform.go @@ -22,6 +22,7 @@ func (t *TransformCmd) Run(ctx *cc.CommonCtx) error { } goTemplateTransformer := transform.NewGoTemplateTransform(templateContent) + return t.transform(ctx.Context, goTemplateTransformer) } @@ -36,6 +37,7 @@ func (t *TransformCmd) getTemplateContent() ([]byte, error) { } templateLoader := template.NewTemplateLoader(templateContent) + templateContent, err = templateLoader.Load(t.Template) if err != nil { return nil, err diff --git a/plugins/ldap/pkg/attribute/attribute.go b/plugins/ldap/pkg/attribute/attribute.go index fe45ba32..6940a187 100644 --- a/plugins/ldap/pkg/attribute/attribute.go +++ b/plugins/ldap/pkg/attribute/attribute.go @@ -38,7 +38,8 @@ func addMembersByType(attributes map[string][]string, userDnTOKey, groupDnTOKey } func decodeAttributes(ldapEntry *ldap.Entry) map[string][]string { - var data = make(map[string][]string) + data := make(map[string][]string) + for _, attribute := range ldapEntry.Attributes { if attribute.Name == "objectSid" { data[attribute.Name] = []string{ObjectSid(ldapEntry)} @@ -52,6 +53,7 @@ func decodeAttributes(ldapEntry *ldap.Entry) map[string][]string { data[attribute.Name] = attribute.Values } + return data } @@ -63,6 +65,7 @@ func ObjectSid(entry *ldap.Entry) string { if len(rawObjectSid) > 0 { return objectsid.Decode(rawObjectSid).String() } + return "" } @@ -73,12 +76,15 @@ func ObjectGUID(entry *ldap.Entry) string { rawObjectGUID := entry.GetRawAttributeValue("objectGUID") if len(rawObjectGUID) > 0 { objectGUID := entry.GetRawAttributeValue("objectGUID") + uuidString, err := uuid.FromBytes(objectGUID) if err != nil { return "" } + return uuidToComStyle(uuidString.String()) } + return "" } @@ -89,6 +95,7 @@ func ObjectGUID(entry *ldap.Entry) string { */ func uuidToComStyle(token string) string { token = strings.ReplaceAll(token, "-", "") + return fmt.Sprintf("%s%s%s%s-%s%s-%s%s-%s%s-%s%s%s%s%s%s", token[6:8], token[4:6], token[2:4], token[0:2], token[10:12], token[8:10], token[14:16], token[12:14], token[16:18], token[18:20], token[20:22], token[22:24], token[24:26], token[26:28], token[28:30], token[30:32]) diff --git a/plugins/ldap/pkg/fetch/fetch.go b/plugins/ldap/pkg/fetch/fetch.go index d862e813..b53de6cd 100644 --- a/plugins/ldap/pkg/fetch/fetch.go +++ b/plugins/ldap/pkg/fetch/fetch.go @@ -76,10 +76,11 @@ func entryType(ldapEntry *ldap.Entry, groupDnToKey map[string]string) string { } func buildMapFromDNToKey(ldapEntries []*ldap.Entry, key string) map[string]string { - var mapDNToKey = make(map[string]string) + mapDNToKey := make(map[string]string) for _, entry := range ldapEntries { mapDNToKey[entry.DN] = extractKey(key, entry) } + return mapDNToKey } diff --git a/plugins/ldap/pkg/ldapclient/ldapclient.go b/plugins/ldap/pkg/ldapclient/ldapclient.go index e85fe7e0..f545e13b 100644 --- a/plugins/ldap/pkg/ldapclient/ldapclient.go +++ b/plugins/ldap/pkg/ldapclient/ldapclient.go @@ -37,6 +37,7 @@ func NewLDAPClient(credentials *Credentials, conOptions *ConnectionOptions, logg if err != nil { return nil, err } + ldapClient.ldapConn = ldapConn ldapClient.credentials = credentials ldapClient.conOptions = conOptions @@ -94,6 +95,7 @@ func (l *LDAPClient) search(filter string) []*ldap.Entry { attributes, nil, ) + sr, err := l.ldapConn.SearchWithPaging(searchRequest, 1000) if err != nil { log.Fatal(err) diff --git a/plugins/okta/cmd/ds-load-okta/main.go b/plugins/okta/cmd/ds-load-okta/main.go index e0d71f0d..a10c524a 100644 --- a/plugins/okta/cmd/ds-load-okta/main.go +++ b/plugins/okta/cmd/ds-load-okta/main.go @@ -34,6 +34,7 @@ func main() { ctx := cc.NewCommonContext(cli.Verbosity, string(cli.Config)) kongCtx := kong.Parse(&cli, options...) + if err := kongCtx.Run(ctx); err != nil { kongCtx.FatalIfErrorf(err) } diff --git a/plugins/okta/pkg/app/cli.go b/plugins/okta/pkg/app/cli.go index c7cdab59..5146d2c6 100644 --- a/plugins/okta/pkg/app/cli.go +++ b/plugins/okta/pkg/app/cli.go @@ -18,13 +18,13 @@ type CLI struct { Verify VerifyCmd `cmd:"verify" help:"verify fetcher configuration and credentials"` } -type VersionCmd struct { -} +type VersionCmd struct{} func (cmd *VersionCmd) Run() error { fmt.Printf("%s - %s\n", AppName, version.GetInfo().String(), ) + return nil } diff --git a/plugins/okta/pkg/app/exec.go b/plugins/okta/pkg/app/exec.go index 2ec2a93f..c9aaacec 100644 --- a/plugins/okta/pkg/app/exec.go +++ b/plugins/okta/pkg/app/exec.go @@ -30,6 +30,8 @@ func (cmd *ExecCmd) Run(ctx *cc.CommonCtx) error { if err != nil { return err } + transformer := transform.NewGoTemplateTransform(templateContent) + return exec.Execute(ctx.Context, ctx.Log, transformer, fetcher) } diff --git a/plugins/okta/pkg/app/export_transform.go b/plugins/okta/pkg/app/export_transform.go index 4c0655cd..3c28331c 100644 --- a/plugins/okta/pkg/app/export_transform.go +++ b/plugins/okta/pkg/app/export_transform.go @@ -7,14 +7,14 @@ import ( "github.com/aserto-dev/ds-load/sdk/transform" ) -type ExportTransformCmd struct { -} +type ExportTransformCmd struct{} func (t *ExportTransformCmd) Run(ctx *cc.CommonCtx) error { templateContent, err := Assets().ReadFile("assets/transform_template.tmpl") if err != nil { return err } + transformer := transform.NewGoTemplateTransform(templateContent) return transformer.ExportTransform(os.Stdout) diff --git a/plugins/okta/pkg/app/transform.go b/plugins/okta/pkg/app/transform.go index eb246a02..50d46b93 100644 --- a/plugins/okta/pkg/app/transform.go +++ b/plugins/okta/pkg/app/transform.go @@ -25,6 +25,7 @@ func (t *TransformCmd) Run(kongContext *kong.Context) error { defer cancel() goTemplateTransformer := transform.NewGoTemplateTransform(template) + return t.transform(timeoutCtx, goTemplateTransformer) } @@ -33,8 +34,11 @@ func (t *TransformCmd) transform(ctx context.Context, transformer plugin.Transfo } func (t *TransformCmd) getTemplateContent() ([]byte, error) { - var templateContent []byte - var err error + var ( + templateContent []byte + err error + ) + if t.Template == "" { templateContent, err = Assets().ReadFile("assets/transform_template.tmpl") if err != nil { @@ -46,5 +50,6 @@ func (t *TransformCmd) getTemplateContent() ([]byte, error) { return nil, err } } + return templateContent, nil } diff --git a/plugins/okta/pkg/fetch/fetch.go b/plugins/okta/pkg/fetch/fetch.go index fdf01e9d..efdeb686 100644 --- a/plugins/okta/pkg/fetch/fetch.go +++ b/plugins/okta/pkg/fetch/fetch.go @@ -51,8 +51,7 @@ func (fetcher *Fetcher) Fetch(ctx context.Context, outputWriter, errorWriter io. func (fetcher *Fetcher) fetchUsers(ctx context.Context, writer *js.JSONArrayWriter, errorWriter io.Writer) error { users, response, err := fetcher.oktaClient.User.ListUsers(ctx).Execute() if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) return err } @@ -61,11 +60,12 @@ func (fetcher *Fetcher) fetchUsers(ctx context.Context, writer *js.JSONArrayWrit for i := range users { user := &users[i] + userResult, err := fetcher.processUser(ctx, user, errorWriter) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) } + err = writer.Write(userResult) if err != nil { _, _ = errorWriter.Write([]byte(err.Error())) @@ -75,8 +75,7 @@ func (fetcher *Fetcher) fetchUsers(ctx context.Context, writer *js.JSONArrayWrit if response != nil && response.HasNextPage() { response, err = response.Next(&users) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) } } else { break @@ -89,8 +88,7 @@ func (fetcher *Fetcher) fetchUsers(ctx context.Context, writer *js.JSONArrayWrit func (fetcher *Fetcher) fetchGroups(ctx context.Context, writer *js.JSONArrayWriter, errorWriter io.Writer) error { groups, response, err := fetcher.oktaClient.Group.ListGroups(ctx).Execute() if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) return err } @@ -100,9 +98,9 @@ func (fetcher *Fetcher) fetchGroups(ctx context.Context, writer *js.JSONArrayWri for _, group := range groups { groupResult, err := fetcher.processGroup(ctx, &group, errorWriter) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) } + err = writer.Write(groupResult) if err != nil { _, _ = errorWriter.Write([]byte(err.Error())) @@ -112,8 +110,7 @@ func (fetcher *Fetcher) fetchGroups(ctx context.Context, writer *js.JSONArrayWri if response != nil && response.HasNextPage() { response, err = response.Next(&groups) if err != nil { - _, _ = errorWriter.Write([]byte(err.Error())) - common.SetExitCode(1) + common.WriteErrorWithExitCode(errorWriter, err, 1) } } else { break @@ -129,9 +126,10 @@ func (fetcher *Fetcher) processUser(ctx context.Context, user *okta.User, errorW common.SetExitCode(1) return nil, err } + var userResult map[string]interface{} - err = json.Unmarshal(userBytes, &userResult) - if err != nil { + + if err := json.Unmarshal(userBytes, &userResult); err != nil { common.SetExitCode(1) return nil, err } @@ -143,6 +141,7 @@ func (fetcher *Fetcher) processUser(ctx context.Context, user *okta.User, errorW common.SetExitCode(1) return nil, err } + userResult["groups"] = groups } @@ -153,8 +152,10 @@ func (fetcher *Fetcher) processUser(ctx context.Context, user *okta.User, errorW common.SetExitCode(1) return nil, err } + userResult["roles"] = roles } + return userResult, nil } @@ -164,9 +165,10 @@ func (fetcher *Fetcher) processGroup(ctx context.Context, group *okta.Group, err common.SetExitCode(1) return nil, err } + var groupResult map[string]interface{} - err = json.Unmarshal(userBytes, &groupResult) - if err != nil { + + if err := json.Unmarshal(userBytes, &groupResult); err != nil { common.SetExitCode(1) return nil, err } @@ -178,16 +180,20 @@ func (fetcher *Fetcher) processGroup(ctx context.Context, group *okta.Group, err common.SetExitCode(1) return nil, err } + groupResult["roles"] = roles } + return groupResult, nil } func (fetcher *Fetcher) getGroups(ctx context.Context, userID string, errorWriter io.Writer) ([]map[string]interface{}, error) { - var response *okta.APIResponse - var result []map[string]interface{} - var groups []okta.Group - var err error + var ( + response *okta.APIResponse + result []map[string]interface{} + groups []okta.Group + err error + ) groups, response, err = fetcher.oktaClient.User.ListUserGroups(ctx, userID).Execute() if err != nil { @@ -202,11 +208,12 @@ func (fetcher *Fetcher) getGroups(ctx context.Context, userID string, errorWrite if err != nil { return nil, err } + var obj map[string]interface{} - err = json.Unmarshal(groupBytes, &obj) - if err != nil { + if err := json.Unmarshal(groupBytes, &obj); err != nil { return nil, err } + result = append(result, obj) } @@ -219,14 +226,17 @@ func (fetcher *Fetcher) getGroups(ctx context.Context, userID string, errorWrite break } } + return result, nil } func (fetcher *Fetcher) getUserRoles(ctx context.Context, userID string, errorWriter io.Writer) ([]map[string]interface{}, error) { - var response *okta.APIResponse - var result []map[string]interface{} - var roles []okta.Role - var err error + var ( + response *okta.APIResponse + result []map[string]interface{} + roles []okta.Role + err error + ) roles, response, err = fetcher.oktaClient.RoleAssignments.ListAssignedRolesForUser(ctx, userID).Execute() if err != nil { @@ -241,11 +251,13 @@ func (fetcher *Fetcher) getUserRoles(ctx context.Context, userID string, errorWr if err != nil { return nil, err } + var obj map[string]interface{} - err = json.Unmarshal(roleBytes, &obj) - if err != nil { + + if err := json.Unmarshal(roleBytes, &obj); err != nil { return nil, err } + result = append(result, obj) } @@ -258,14 +270,17 @@ func (fetcher *Fetcher) getUserRoles(ctx context.Context, userID string, errorWr break } } + return result, nil } func (fetcher *Fetcher) getGroupRoles(ctx context.Context, groupID string, errorWriter io.Writer) ([]map[string]interface{}, error) { - var response *okta.APIResponse - var result []map[string]interface{} - var roles []okta.Role - var err error + var ( + response *okta.APIResponse + result []map[string]interface{} + roles []okta.Role + err error + ) roles, response, err = fetcher.oktaClient.RoleAssignments.ListGroupAssignedRoles(ctx, groupID).Execute() if err != nil { @@ -280,11 +295,12 @@ func (fetcher *Fetcher) getGroupRoles(ctx context.Context, groupID string, error if err != nil { return nil, err } + var obj map[string]interface{} - err = json.Unmarshal(roleBytes, &obj) - if err != nil { + if err := json.Unmarshal(roleBytes, &obj); err != nil { return nil, err } + result = append(result, obj) } @@ -297,6 +313,7 @@ func (fetcher *Fetcher) getGroupRoles(ctx context.Context, groupID string, error break } } + return result, nil } diff --git a/plugins/okta/pkg/oktaclient/okta.go b/plugins/okta/pkg/oktaclient/okta.go index 0e703827..aee368fb 100644 --- a/plugins/okta/pkg/oktaclient/okta.go +++ b/plugins/okta/pkg/oktaclient/okta.go @@ -23,12 +23,12 @@ func NewOktaClient(domain, token string, requestTimeout int64) (*OktaClient, err okta.WithRateLimitMaxBackOff(30), okta.WithRateLimitMaxRetries(3), ) - if err != nil { return nil, status.Errorf(codes.Internal, "failed to create Okta configuration: %s", err.Error()) } client := okta.NewAPIClient(config) + return &OktaClient{ User: client.UserAPI, Group: client.GroupAPI, diff --git a/plugins/okta/pkg/verify/verifier.go b/plugins/okta/pkg/verify/verifier.go index 2f99ddee..964c7d1d 100644 --- a/plugins/okta/pkg/verify/verifier.go +++ b/plugins/okta/pkg/verify/verifier.go @@ -15,12 +15,10 @@ func New(ctx context.Context, client *oktaclient.OktaClient) (*Verifier, error) return &Verifier{ client: client, }, nil - } func (v *Verifier) Verify(ctx context.Context) error { _, _, err := v.client.User.ListUsers(ctx).Limit(1).Execute() - if err != nil { return errors.Wrap(err, "failed to retrieve user from Okta") } diff --git a/plugins/openapi/pkg/app/cli.go b/plugins/openapi/pkg/app/cli.go index 8519fa4f..abfe6661 100644 --- a/plugins/openapi/pkg/app/cli.go +++ b/plugins/openapi/pkg/app/cli.go @@ -26,5 +26,6 @@ func (cmd *VersionCmd) Run() error { AppName, version.GetInfo().String(), ) + return nil } diff --git a/plugins/openapi/pkg/app/exec.go b/plugins/openapi/pkg/app/exec.go index 448d2660..e8f9e1f4 100644 --- a/plugins/openapi/pkg/app/exec.go +++ b/plugins/openapi/pkg/app/exec.go @@ -23,12 +23,15 @@ func (cmd *ExecCmd) Run(ctx *cc.CommonCtx) error { if err != nil { return err } + fetcher = fetcher.WithDirectory(cmd.Directory).WithURL(cmd.URL).WithIDFormat(cmd.IDFormat).WithServiceName(cmd.ServiceName) templateContent, err := cmd.getTemplateContent() if err != nil { return err } + transformer := transform.NewGoTemplateTransform(templateContent) + return exec.Execute(ctx.Context, ctx.Log, transformer, fetcher) } diff --git a/plugins/openapi/pkg/app/export_transform.go b/plugins/openapi/pkg/app/export_transform.go index e1de22ca..3c28331c 100644 --- a/plugins/openapi/pkg/app/export_transform.go +++ b/plugins/openapi/pkg/app/export_transform.go @@ -14,6 +14,7 @@ func (t *ExportTransformCmd) Run(ctx *cc.CommonCtx) error { if err != nil { return err } + transformer := transform.NewGoTemplateTransform(templateContent) return transformer.ExportTransform(os.Stdout) diff --git a/plugins/openapi/pkg/app/fetch.go b/plugins/openapi/pkg/app/fetch.go index e168dd22..f9e2b3d4 100644 --- a/plugins/openapi/pkg/app/fetch.go +++ b/plugins/openapi/pkg/app/fetch.go @@ -25,6 +25,7 @@ func (cmd *FetchCmd) Run(ctx *cc.CommonCtx) error { if err != nil { return err } + fetcher = fetcher.WithDirectory(cmd.Directory).WithURL(cmd.URL).WithIDFormat(cmd.IDFormat).WithServiceName(cmd.ServiceName) return fetcher.Fetch(ctx.Context, os.Stdout, os.Stderr) diff --git a/plugins/openapi/pkg/app/transform.go b/plugins/openapi/pkg/app/transform.go index 2cd7f116..8e7501da 100644 --- a/plugins/openapi/pkg/app/transform.go +++ b/plugins/openapi/pkg/app/transform.go @@ -21,6 +21,7 @@ func (t *TransformCmd) Run(ctx *cc.CommonCtx) error { } goTemplateTransformer := transform.NewGoTemplateTransform(template) + return t.transform(ctx.Context, goTemplateTransformer) } @@ -29,8 +30,11 @@ func (t *TransformCmd) transform(ctx context.Context, transformer plugin.Transfo } func (t *TransformCmd) getTemplateContent() ([]byte, error) { - var templateContent []byte - var err error + var ( + templateContent []byte + err error + ) + if t.Template == "" { templateContent, err = Assets().ReadFile("assets/transform_template.tmpl") if err != nil { @@ -42,5 +46,6 @@ func (t *TransformCmd) getTemplateContent() ([]byte, error) { return nil, err } } + return templateContent, nil } diff --git a/plugins/openapi/pkg/fetch/fetch.go b/plugins/openapi/pkg/fetch/fetch.go index 26c7e00d..7d189fa7 100644 --- a/plugins/openapi/pkg/fetch/fetch.go +++ b/plugins/openapi/pkg/fetch/fetch.go @@ -52,8 +52,7 @@ func (f *Fetcher) Fetch(ctx context.Context, outputWriter, errorWriter io.Writer } for _, service := range services { - err = writer.Write(service) - if err != nil { + if err := writer.Write(service); err != nil { _, _ = errorWriter.Write([]byte(err.Error())) } } @@ -64,8 +63,7 @@ func (f *Fetcher) Fetch(ctx context.Context, outputWriter, errorWriter io.Writer } for _, api := range apis { - err = writer.Write(api) - if err != nil { + if err := writer.Write(api); err != nil { _, _ = errorWriter.Write([]byte(err.Error())) } } diff --git a/plugins/openapi/pkg/openapi/openapi.go b/plugins/openapi/pkg/openapi/openapi.go index 73357ec2..86c059d4 100644 --- a/plugins/openapi/pkg/openapi/openapi.go +++ b/plugins/openapi/pkg/openapi/openapi.go @@ -47,16 +47,20 @@ func New(directory, specURL, idFormat, serviceName string) (*Client, error) { if err != nil { return nil, errors.Wrapf(err, "url not parsed: %s", specURL) } + doc, err := openapi3.NewLoader().LoadFromURI(parsedURL) if err != nil { return nil, errors.Wrapf(err, "cannot load OpenAPI spec from URL : %s", specURL) } + if serviceName != "" { if doc.Info.Extensions == nil { doc.Info.Extensions = make(map[string]interface{}, 0) } + doc.Info.Extensions["ServiceName"] = canonicalizeServiceName(serviceName, Canonical) } + c.docs = append(c.docs, doc) } @@ -64,17 +68,21 @@ func New(directory, specURL, idFormat, serviceName string) (*Client, error) { if _, err := os.Stat(directory); errors.Is(err, os.ErrNotExist) { return nil, errors.Wrapf(err, "directory not found: %s", directory) } + files, err := os.ReadDir(directory) if err != nil { return nil, errors.Wrapf(err, "cannot read directory: %s", directory) } + for _, file := range files { if !file.IsDir() { filename := fmt.Sprintf("%s/%s", directory, file.Name()) + doc, err := openapi3.NewLoader().LoadFromFile(filename) if err != nil { return nil, errors.Wrapf(err, "cannot open file: %s", file.Name()) } + c.docs = append(c.docs, doc) } } @@ -85,54 +93,65 @@ func New(directory, specURL, idFormat, serviceName string) (*Client, error) { func (c *Client) ListServices() ([]Service, error) { services := make([]Service, 0) + for _, service := range c.docs { id := "" if service.Info.Extensions["ServiceName"] != "" { id = service.Info.Extensions["ServiceName"].(string) } + svc := newService(service.Info.Title, id, c.idFormat) services = append(services, *svc) } + return services, nil } func (c *Client) ListAPIs() ([]API, error) { apis := make([]API, 0) + for _, service := range c.docs { apiList := c.ListAPIsInService(service, c.idFormat) apis = append(apis, apiList...) } + return apis, nil } func (c *Client) ListAPIsInService(service *openapi3.T, idFormat string) []API { apis := make([]API, 0) + serviceID := service.Info.Extensions["ServiceName"].(string) if serviceID == "" { serviceID = service.Info.Title } - for pathKey, pathItem := range service.Paths.Map() { + for pathKey, pathItem := range service.Paths.Map() { if pathItem.Get != nil { api := newAPI(serviceID, service.Info.Title, "GET", pathKey, idFormat) apis = append(apis, *api) } + if pathItem.Post != nil { api := newAPI(serviceID, service.Info.Title, "POST", pathKey, idFormat) apis = append(apis, *api) } + if pathItem.Put != nil { api := newAPI(serviceID, service.Info.Title, "PUT", pathKey, idFormat) apis = append(apis, *api) } + if pathItem.Patch != nil { api := newAPI(serviceID, service.Info.Title, "PATCH", pathKey, idFormat) apis = append(apis, *api) } + if pathItem.Delete != nil { api := newAPI(serviceID, service.Info.Title, "DELETE", pathKey, idFormat) apis = append(apis, *api) } + if pathItem.Options != nil { api := newAPI(serviceID, service.Info.Title, "OPTIONS", pathKey, idFormat) apis = append(apis, *api) @@ -147,9 +166,11 @@ func newService(name, id, idFormat string) *Service { service.DisplayName = name service.Type = "service" service.ID = id + if id == "" { service.ID = canonicalizeServiceName(name, idFormat) } + return service } @@ -162,6 +183,7 @@ func newAPI(serviceID, serviceName, method, path, idFormat string) *API { api.Path = path api.DisplayName = fmt.Sprintf("%s %s", method, path) api.ID = canonicalizeEndpoint(api.ServiceID, method, path, idFormat) + return api } @@ -172,6 +194,7 @@ func canonicalizePath(uri string) string { func canonicalizeEndpoint(service, method, path, idFormat string) string { parts := []string{service, method} + switch idFormat { case Base64: parts = append(parts, path) diff --git a/sdk/common/cc/cc.go b/sdk/common/cc/cc.go index bd659af8..4e843f89 100644 --- a/sdk/common/cc/cc.go +++ b/sdk/common/cc/cc.go @@ -22,6 +22,7 @@ func NewCommonContext(verbosity int, config string) *CommonCtx { Prod: false, LogLevelParsed: logLevelParsed, } + newLogger, err := logger.NewLogger(os.Stdout, os.Stderr, logCfg) if err != nil { log.Fatalf("failed to initialize logger: %s", err.Error()) diff --git a/sdk/common/exit_code.go b/sdk/common/exit_code.go index b1d058ff..efc92c2a 100644 --- a/sdk/common/exit_code.go +++ b/sdk/common/exit_code.go @@ -1,11 +1,12 @@ package common -import "sync/atomic" - -var ( - exitCode int32 +import ( + "io" + "sync/atomic" ) +var exitCode int32 + func GetExitCode() int { return int(atomic.LoadInt32(&exitCode)) } @@ -13,3 +14,9 @@ func GetExitCode() int { func SetExitCode(code int) { atomic.StoreInt32(&exitCode, int32(code)) //nolint:gosec } + +func WriteErrorWithExitCode(w io.Writer, err error, code int) { + _, _ = w.Write([]byte(err.Error())) + + SetExitCode(code) +} diff --git a/sdk/common/js/reader.go b/sdk/common/js/reader.go index 60fa4712..d1d7cb48 100644 --- a/sdk/common/js/reader.go +++ b/sdk/common/js/reader.go @@ -41,6 +41,7 @@ func NewJSONArrayReader(r io.Reader) (*JSONArrayReader, error) { // returns io.EOF at the end of the input stream. func (r *JSONArrayReader) ReadProtoMessage(message proto.Message) error { more, err := r.more() + switch { case err != nil: return err @@ -56,11 +57,13 @@ func (r *JSONArrayReader) Read(message any) error { if err != nil { return err } + if more { if err := r.decoder.Decode(&message); err != nil { return err } } + return nil } @@ -74,6 +77,7 @@ func (r *JSONArrayReader) more() (bool, error) { if err != nil { return false, err } + if delim, ok := tok.(json.Delim); !ok && delim.String() != "]" { return false, errors.Errorf("file does not contain a JSON array") } diff --git a/sdk/common/js/writer.go b/sdk/common/js/writer.go index edf7f4d4..124ad27b 100644 --- a/sdk/common/js/writer.go +++ b/sdk/common/js/writer.go @@ -23,34 +23,40 @@ func NewJSONArrayWriter(w io.Writer) *JSONArrayWriter { } func (w *JSONArrayWriter) WriteProtoMessage(message protoreflect.ProtoMessage) error { - err := w.writeDelimiters() - if err != nil { + if err := w.writeDelimiters(); err != nil { return err } + jsonObj, err := protojson.Marshal(message) if err != nil { return err } + _, err = w.writer.Write(jsonObj) + if !w.addDelimiter { w.addDelimiter = true } + return err } func (w *JSONArrayWriter) Write(message any) error { - err := w.writeDelimiters() - if err != nil { + if err := w.writeDelimiters(); err != nil { return err } + jsonObj, err := json.Marshal(message) if err != nil { return err } + _, err = w.writer.Write(jsonObj) + if !w.addDelimiter { w.addDelimiter = true } + return err } @@ -67,9 +73,11 @@ func (w *JSONArrayWriter) Close() error { if err != nil { return err } + w.addDelimiter = false w.writer = nil } + return nil } @@ -79,13 +87,16 @@ func (w *JSONArrayWriter) writeDelimiters() error { if err != nil { return err } + w.arrayInitialized = true } + if w.addDelimiter { _, err := w.writer.Write([]byte{','}) if err != nil { return err } } + return nil } diff --git a/sdk/common/kongyaml/kongyaml.go b/sdk/common/kongyaml/kongyaml.go index 6c9a755d..07df3790 100644 --- a/sdk/common/kongyaml/kongyaml.go +++ b/sdk/common/kongyaml/kongyaml.go @@ -22,14 +22,15 @@ func NewYAMLResolver(yamlKey string) *YAMLResolver { func (y *YAMLResolver) Loader(r io.Reader) (kong.Resolver, error) { decoder := yaml.NewDecoder(r) config := map[interface{}]interface{}{} - err := decoder.Decode(config) - if err != nil { + + if err := decoder.Decode(config); err != nil { return nil, err } if y.yamlKey != "" { var ok bool config, ok = config[y.yamlKey].(map[interface{}]interface{}) + if !ok { return kong.ResolverFunc(func(context *kong.Context, parent *kong.Path, flag *kong.Flag) (interface{}, error) { return nil, nil @@ -43,14 +44,17 @@ func (y *YAMLResolver) Loader(r io.Reader) (kong.Resolver, error) { path = append(path, flag.Name) path = strings.Split(strings.Join(path, "-"), "-") s := find(config, path) + if s == nil { fullPath := []string{} for n := parent.Node(); n != nil && n.Type != kong.ApplicationNode; n = n.Parent { fullPath = append([]string{n.Name}, fullPath...) } + fullPath = append(fullPath, path...) s = find(config, fullPath) } + return s, nil }), nil } @@ -62,5 +66,6 @@ func find(config map[interface{}]interface{}, path []string) interface{} { return find(child, path[i+1:]) } } + return config[strings.Join(path, "-")] } diff --git a/sdk/exec/exec.go b/sdk/exec/exec.go index b950fa84..0a2eb506 100644 --- a/sdk/exec/exec.go +++ b/sdk/exec/exec.go @@ -18,6 +18,7 @@ func Execute(ctx context.Context, log *zerolog.Logger, transformer plugin.Transf if err != nil { log.Printf("Could not fetch data %s", err.Error()) } + pipeWriter.Close() }() diff --git a/sdk/plugin/plugin.go b/sdk/plugin/plugin.go index ad8bb5e2..f5b6f0f5 100644 --- a/sdk/plugin/plugin.go +++ b/sdk/plugin/plugin.go @@ -58,7 +58,9 @@ func NewDSPlugin(options ...PluginOption) *DSPlugin { // json encodes results and prints to plugin writer. func (plugin *DSPlugin) WriteFetchOutput(results chan map[string]interface{}, errCh chan error) error { var wg sync.WaitGroup + wg.Add(1) + go func() { for err := range errCh { _, wErr := plugin.errWriter.Write([]byte(err.Error() + "\n")) @@ -66,22 +68,27 @@ func (plugin *DSPlugin) WriteFetchOutput(results chan map[string]interface{}, er log.Fatalf("cannot write to output: %s", wErr.Error()) } } + wg.Done() }() wg.Add(1) + go func() { writer := js.NewJSONArrayWriter(plugin.outWriter) defer writer.Close() + for result := range results { err := writer.Write(result) if err != nil { log.Printf("Could not write result [%s] to output", result) } } + wg.Done() }() wg.Wait() + return nil } diff --git a/sdk/transform/functions.go b/sdk/transform/functions.go index f5f2e8c4..2984a8f0 100644 --- a/sdk/transform/functions.go +++ b/sdk/transform/functions.go @@ -46,11 +46,13 @@ func customFunctions() template.FuncMap { func separator(s string) func() string { i := -1 + return func() string { i++ if i == 0 { return "" } + return s } } @@ -63,6 +65,7 @@ func marshal(v interface{}) string { func fromEnv(key, envName string) string { value := os.Getenv(envName) strValue, _ := json.Marshal(value) + return fmt.Sprintf("%q:%s", key, string(strValue)) } diff --git a/sdk/transform/transform.go b/sdk/transform/transform.go index 9d0678fd..b8799b0d 100644 --- a/sdk/transform/transform.go +++ b/sdk/transform/transform.go @@ -37,6 +37,7 @@ func (t *GoTemplateTransform) ExportTransform(outputWriter io.Writer) error { func (t *GoTemplateTransform) Transform(ctx context.Context, ioReader io.Reader, outputWriter, errorWriter io.Writer) error { jsonWriter := js.NewJSONArrayWriter(outputWriter) defer jsonWriter.Close() + reader, err := js.NewJSONArrayReader(ioReader) if err != nil { return err @@ -44,15 +45,17 @@ func (t *GoTemplateTransform) Transform(ctx context.Context, ioReader io.Reader, for { var idpData map[string]interface{} + err := reader.Read(&idpData) if err == io.EOF { break } + if err != nil { return errors.Wrap(err, "failed to read idpData into map[string]interface{}") } - err = t.doTransform(idpData, jsonWriter) - if err != nil { + + if err := t.doTransform(idpData, jsonWriter); err != nil { return err } } @@ -69,6 +72,7 @@ func (t *GoTemplateTransform) doTransform(idpData map[string]interface{}, jsonWr if err := jsonWriter.WriteProtoMessage(dirV3msg); err != nil { return errors.Wrap(err, "failed to write directory objects to output") } + return nil } @@ -77,9 +81,11 @@ func (t *GoTemplateTransform) TransformObject(idpData map[string]interface{}) (* if err != nil { return nil, errors.Wrap(err, "GoTemplateTransform transformTemplate execute failed") } + if os.Getenv("DEBUG") != "" { os.Stdout.WriteString(output) } + var dirV3msg msg.Transform opts := protojson.UnmarshalOptions{ @@ -87,8 +93,7 @@ func (t *GoTemplateTransform) TransformObject(idpData map[string]interface{}) (* DiscardUnknown: false, } - err = opts.Unmarshal([]byte(output), &dirV3msg) - if err != nil { + if err := opts.Unmarshal([]byte(output), &dirV3msg); err != nil { return nil, errors.Wrap(err, "failed to unmarshal transformed data into directory v3 objects and relations") } @@ -97,13 +102,15 @@ func (t *GoTemplateTransform) TransformObject(idpData map[string]interface{}) (* func (t *GoTemplateTransform) transformToTemplate(input map[string]interface{}, templateString string) (string, error) { temp := template.New("GoTemplateTransform") + parsed, err := temp.Funcs(customFunctions()).Parse(templateString) if err != nil { return "", err } + var filled bytes.Buffer - err = parsed.Execute(&filled, input) - if err != nil { + + if err := parsed.Execute(&filled, input); err != nil { return "", err } diff --git a/sdk/transform/transform_test.go b/sdk/transform/transform_test.go index ede2aa53..781ff5d1 100644 --- a/sdk/transform/transform_test.go +++ b/sdk/transform/transform_test.go @@ -18,8 +18,10 @@ import ( func TestTransform(t *testing.T) { // Arrange content, err := sdk.Assets().ReadFile("assets/peoplefinder.json") - contentReader := strings.NewReader(string(content)) assert.NoError(t, err) + + contentReader := strings.NewReader(string(content)) + template, err := sdk.Assets().ReadFile("assets/test_template.tmpl") assert.NoError(t, err) @@ -36,7 +38,7 @@ func TestTransform(t *testing.T) { // Assert bufLen := transformBuffer.Len() - var transformOutput = make([]byte, bufLen) + transformOutput := make([]byte, bufLen) reader := bufio.NewReader(&transformBuffer) _, err = reader.Read(transformOutput) @@ -95,7 +97,9 @@ func TestTransformEscapedChars(t *testing.T) { assert.NoError(t, err) contentReader := strings.NewReader(string(content)) + var transformBuffer bytes.Buffer + writer := bufio.NewWriter(&transformBuffer) transformer := transform.NewGoTemplateTransform(transformTemplate) @@ -108,7 +112,7 @@ func TestTransformEscapedChars(t *testing.T) { // Assert bufLen := transformBuffer.Len() - var transformOutput = make([]byte, bufLen) + transformOutput := make([]byte, bufLen) reader := bufio.NewReader(&transformBuffer) _, err = reader.Read(transformOutput) assert.NoError(t, err) @@ -124,6 +128,7 @@ func TestTransformEscapedChars(t *testing.T) { objectCount := len(directoryObject.Objects) assert.Equal(t, objectCount, 2) + relationCount := len(directoryObject.Relations) assert.Equal(t, relationCount, 2)