diff --git a/cmd/dev/app/rule_type/rttst.go b/cmd/dev/app/rule_type/rttst.go index 7c263bb956..2236414313 100644 --- a/cmd/dev/app/rule_type/rttst.go +++ b/cmd/dev/app/rule_type/rttst.go @@ -67,6 +67,7 @@ func CmdTest() *cobra.Command { testCmd.Flags().StringP("token", "t", "", "token to authenticate to the provider."+ "Can also be set via the TEST_AUTH_TOKEN environment variable.") testCmd.Flags().StringArrayP("data-source", "d", []string{}, "YAML file containing the data source to test the rule with") + testCmd.Flags().BoolP("debug", "", false, "Start REGO debugger (only works for REGO-based rules types)") if err := testCmd.MarkFlagRequired("rule-type"); err != nil { fmt.Fprintf(os.Stderr, "Error marking flag as required: %s\n", err) @@ -98,6 +99,7 @@ func testCmdRun(cmd *cobra.Command, _ []string) error { token := viper.GetString("test.auth.token") providerclass := cmd.Flag("provider") providerconfig := cmd.Flag("provider-config") + debug := cmd.Flag("debug").Value.String() == "true" dataSourceFileStrings, err := cmd.Flags().GetStringArray("data-source") if err != nil { @@ -197,7 +199,10 @@ func testCmdRun(cmd *cobra.Command, _ []string) error { // TODO: use cobra context here ctx := context.Background() - eng, err := rtengine.NewRuleTypeEngine(ctx, ruletype, prov, nil /*experiments*/, options.WithDataSources(dsRegistry)) + eng, err := rtengine.NewRuleTypeEngine(ctx, ruletype, prov, nil, /*experiments*/ + options.WithDataSources(dsRegistry), + options.WithDebugger(debug), + ) if err != nil { return fmt.Errorf("cannot create rule type engine: %w", err) } diff --git a/internal/engine/eval/rego/debug.go b/internal/engine/eval/rego/debug.go new file mode 100644 index 0000000000..0c2a0a8c3d --- /dev/null +++ b/internal/engine/eval/rego/debug.go @@ -0,0 +1,670 @@ +// SPDX-FileCopyrightText: Copyright 2024 The Minder Authors +// SPDX-License-Identifier: Apache-2.0 + +// Package rego provides the rego rule evaluator +package rego + +import ( + "bufio" + "context" + "errors" + "fmt" + "math" + "os" + "regexp" + "slices" + "strconv" + "strings" + + "github.com/open-policy-agent/opa/ast/location" + "github.com/open-policy-agent/opa/debug" + "github.com/open-policy-agent/opa/rego" + + "github.com/mindersec/minder/internal/util/cli" + "github.com/mindersec/minder/pkg/engine/v1/interfaces" +) + +type eventHandler struct { + ch chan *debug.Event +} + +func newEventHandler() *eventHandler { + return &eventHandler{ + ch: make(chan *debug.Event), + } +} + +func (eh *eventHandler) HandleEvent(event debug.Event) { + eh.ch <- &event +} + +func (eh *eventHandler) WaitFor( + ctx context.Context, + eventTypes ...debug.EventType, +) *debug.Event { + for { + select { + case e := <-eh.ch: + if slices.Contains(eventTypes, e.Type) { + return e + } + case <-ctx.Done(): + return nil + } + } +} + +var ( + errInvalidInstr = errors.New("invalid instruction") + errInvalidBP = errors.New("invalid breakpoint") +) + +func (e *Evaluator) Debug( + ctx context.Context, + _ *interfaces.Result, + input *Input, + funcs ...func(*rego.Rego), +) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + ds := newDebugSession( + e.cfg.Def, + input, + e.reseval.getQueryString(), + e.regoOpts, + funcs, + ) + + return ds.Start(ctx) +} + +type debugSession struct { + src string + lines int + input *Input + query string + opts []func(*rego.Rego) + funcs []func(*rego.Rego) + + // fields initialized after starting the session + session debug.Session + eh *eventHandler +} + +func newDebugSession( + src string, + input *Input, + query string, + opts []func(*rego.Rego), + funcs []func(*rego.Rego), +) *debugSession { + ds := &debugSession{} + ds.src = src + ds.lines = len(strings.Split(src, "\n")) + ds.input = input + ds.query = query + ds.opts = opts + ds.funcs = funcs + + return ds +} + +func (ds *debugSession) startDebugger( + ctx context.Context, +) error { + eh := newEventHandler() + debugger := debug.NewDebugger( + debug.SetEventHandler(eh.HandleEvent), + ) + launchProps := debug.LaunchEvalProperties{ + LaunchProperties: debug.LaunchProperties{ + StopOnEntry: false, + StopOnFail: false, + StopOnResult: true, + EnablePrint: true, + RuleIndexing: false, + }, + Input: ds.input, + Query: ds.query, + } + + regoOpts := make([]debug.LaunchOption, 0, len(ds.opts)+len(ds.funcs)) + for _, f := range ds.opts { + regoOpts = append(regoOpts, debug.RegoOption(f)) + } + for _, f := range ds.funcs { + regoOpts = append(regoOpts, debug.RegoOption(f)) + } + + session, err := debugger.LaunchEval(ctx, launchProps, regoOpts...) + if err != nil { + return err + } + + ds.session = session + ds.eh = eh + + return nil +} + +func (ds *debugSession) Start(ctx context.Context) error { + err := ds.startDebugger(ctx) + if err != nil { + return fmt.Errorf("error launching debugger: %w", err) + } + + thr := debug.ThreadID(1) + fmt.Print("(mindbg) ") + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + line := scanner.Text() + + var b strings.Builder + switch { + case line == "": + + case line == "r": + err = ds.startDebugger(ctx) + if err != nil { + return fmt.Errorf("error restarting debugger: %w", err) + } + fmt.Fprintf(&b, "Restarted") + case line == "c": + if err := ds.session.Resume(thr); err != nil { + return fmt.Errorf("error resuming execution: %w", err) + } + + evt := ds.eh.WaitFor(ctx, + debug.ExceptionEventType, + debug.StoppedEventType, + debug.StdoutEventType, + ) + switch evt.Type { + case debug.ExceptionEventType: + fmt.Fprintf(&b, "\nException\n") + if err := printLocals(&b, ds.session, evt.Thread); err != nil { + return fmt.Errorf("error printing locals: %w", err) + } + case debug.StoppedEventType: + fmt.Fprintf(&b, "\nStopped\n") + if err := printLocals(&b, ds.session, evt.Thread); err != nil { + return fmt.Errorf("error printing locals: %w", err) + } + case debug.StdoutEventType: + fmt.Fprintf(&b, "\nFinished\n") + if err := printLocals(&b, ds.session, evt.Thread); err != nil { + return fmt.Errorf("error printing locals: %w", err) + } + fmt.Fprintf(&b, "\nResult: ") + err := printVar(&b, + fmt.Sprintf("%s.*", RegoQueryPrefix), + ds.session, + evt.Thread, + ) + if err != nil { + return fmt.Errorf("error printing variable: %w", err) + } + } + case line == "locals": + if err := printLocals(&b, ds.session, thr); err != nil { + return fmt.Errorf("error printing locals: %w", err) + } + case line == "bp": + bps, err := ds.session.Breakpoints() + if err != nil { + return fmt.Errorf("error getting breakpoints: %w", err) + } + printBreakpoints(&b, bps) + case line == "list", line == "l": + stack, err := ds.session.StackTrace(thr) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + printStackTrace(&b, ds.src, stack) + case line == "trs": + threads, err := ds.session.Threads() + if err != nil { + return fmt.Errorf("error getting threads: %w", err) + } + printThreads(&b, threads) + case line == "cla", + line == "clearall": + if err := ds.session.ClearBreakpoints(); err != nil { + return fmt.Errorf("error clearing breakpoints: %w", err) + } + // "next" is a bit quirky, since it requires adding an + // internal breakpoint, running until it's reached, + // and finally removing the breakpoint. + case line == "n", + line == "next": + stack, err := ds.session.StackTrace(thr) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + if loc := getCurrentLocation(stack); loc != nil { + loc.Row += 1 // let's hope it always exists... + loc.Col = 0 + + // add internal breakpoint + bp, err := ds.session.AddBreakpoint(*loc) + if err != nil { + return fmt.Errorf("error setting breakpoint: %w", err) + } + + // resume execution + if err := ds.session.Resume(thr); err != nil { + return fmt.Errorf("error resuming execution: %w", err) + } + + evt := ds.eh.WaitFor(ctx, debug.StoppedEventType) + stack, err := ds.session.StackTrace(evt.Thread) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + + // clear internal breakpoint, even if + // we stopped for another reason. + if _, err := ds.session.RemoveBreakpoint(bp.ID()); err != nil { + return fmt.Errorf("error removing breakpoing: %w", err) + } + + printStackTrace(&b, ds.src, stack) + } + case line == "s", + line == "sv": + go func() { + if err := ds.session.StepOver(thr); err != nil { + panic(err) + } + }() + evt := ds.eh.WaitFor(ctx, debug.StoppedEventType) + stack, err := ds.session.StackTrace(evt.Thread) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + printStackTrace(&b, ds.src, stack) + case line == "si": + go func() { + if err := ds.session.StepIn(thr); err != nil { + panic(err) + } + }() + evt := ds.eh.WaitFor(ctx, debug.StoppedEventType) + stack, err := ds.session.StackTrace(evt.Thread) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + printStackTrace(&b, ds.src, stack) + case line == "so": + go func() { + if err := ds.session.StepOut(thr); err != nil { + panic(err) + } + }() + evt := ds.eh.WaitFor(ctx, debug.StoppedEventType) + stack, err := ds.session.StackTrace(evt.Thread) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + printStackTrace(&b, ds.src, stack) + case line == "q": + return fmt.Errorf("user abort") + case line == "h", + line == "help": + printHelp(&b) + case strings.HasPrefix(line, "p"): + varname, err := toVarName(line) + if err != nil { + fmt.Fprintln(&b, err) + continue + } + if err := printVar(&b, fmt.Sprintf("^%s$", varname), ds.session, thr); err != nil { + return fmt.Errorf("error printing variables: %w", err) + } + case strings.HasPrefix(line, "b"): + loc, err := toLocation(line) + if err != nil { + fmt.Fprintln(&b, err) + continue + } + bp, err := ds.session.AddBreakpoint(*loc) + if err != nil { + return fmt.Errorf("error setting breakpoint: %w", err) + } + fmt.Fprintln(&b) + printBreakpoint(&b, bp) + case strings.HasPrefix(line, "cl "), + strings.HasPrefix(line, "clear "): + id, err := toBreakpointID(line) + if err != nil { + fmt.Fprintln(&b, err) + continue + } + if _, err := ds.session.RemoveBreakpoint(id); err != nil { + return fmt.Errorf("error removing breakpoing: %w", err) + } + default: + fmt.Fprintf(&b, "Invalid command: %s\nPress h for help\n", line) + } + + output := b.String() + if output != "" { + fmt.Printf("%s\n(mindbg) ", output) + } else { + fmt.Printf("(mindbg) ") + } + } + + return scanner.Err() +} + +func toLocation(line string) (*location.Location, error) { + num, ok := strings.CutPrefix(line, "b ") + if !ok { + return nil, fmt.Errorf(`%w: "%s"`, errInvalidInstr, line) + } + i, err := strconv.ParseUint(num, 10, 64) + if err != nil { + return nil, fmt.Errorf(`%w: invalid line %s`, errInvalidBP, num) + } + return &location.Location{File: "minder.rego", Row: int(i)}, nil +} + +func toBreakpointID(line string) (debug.BreakpointID, error) { + num1, ok1 := strings.CutPrefix(line, "cl ") + num2, ok2 := strings.CutPrefix(line, "clear ") + if !ok1 && !ok2 { + return debug.BreakpointID(-1), fmt.Errorf(`%w: "%s"`, errInvalidInstr, line) + } + + var num string + if !ok1 { + num = num2 + } + if !ok2 { + num = num1 + } + + i, err := strconv.ParseUint(num, 10, 64) + if err != nil { + return debug.BreakpointID(-1), fmt.Errorf(`%w: invalid breakpoint id %s`, errInvalidBP, num) + } + return debug.BreakpointID(i), nil +} + +func toVarName(line string) (string, error) { + varname, ok := strings.CutPrefix(line, "p ") + if !ok { + return "", fmt.Errorf(`%w: "%s"`, errInvalidInstr, line) + } + return varname, nil +} + +func printBreakpoints(b *strings.Builder, bps []debug.Breakpoint) { + fmt.Fprintln(b) + for _, bp := range bps { + printBreakpoint(b, bp) + } +} + +func printBreakpoint(b *strings.Builder, bp debug.Breakpoint) { + fmt.Fprintf(b, "Breakpoint %d set at %s:%d\n", bp.ID(), bp.Location().File, bp.Location().Row) +} + +func printThreads(b *strings.Builder, threads []debug.Thread) { + fmt.Fprintln(b) + for _, thread := range threads { + fmt.Fprintf(b, "Thread %d\n", thread.ID()) + } +} + +func getCurrentLocation(stack debug.StackTrace) *location.Location { + if len(stack) == 0 { + return nil + } + + frame := stack[0] + return frame.Location() +} + +func printStackTrace(b *strings.Builder, src string, stack debug.StackTrace) { + if len(stack) == 0 { + printSource(b, src) + return + } + + lines := strings.Split(src, "\n") + padding := int64(math.Floor(math.Log10(float64(len(lines)))) + 1) + + fmt.Fprintln(b) + frame := stack[0] + if loc := frame.Location(); loc != nil { + fmt.Fprintf(b, "Frame %d at %s:%d.%d\n", + frame.ID(), + loc.File, + loc.Row, + loc.Col, + ) + + for idx, line := range strings.Split(src, "\n") { + fmt.Fprintf(b, "%*d: %s", padding, idx+1, line) + if idx+1 == loc.Row { + theline := strings.Split(string(loc.Text), "\n")[0] + fmt.Fprintf(b, "\n%s%s", + strings.Repeat(" ", loc.Col+int(padding)+2-1), + cli.SimpleBoldStyle.Render(strings.Repeat("^", len(theline))), + ) + } + fmt.Fprintln(b) + } + } +} + +func printSource(b *strings.Builder, source string) { + fmt.Fprintln(b) + lines := strings.Split(source, "\n") + padding := int64(math.Floor(math.Log10(float64(len(lines)))) + 1) + for idx, line := range lines { + fmt.Fprintf(b, "%*d: %s\n", padding, idx+1, line) + } +} + +func printLocals(b *strings.Builder, s debug.Session, thrID debug.ThreadID) error { + trace, err := s.StackTrace(thrID) + if err != nil { + return fmt.Errorf("error getting stacktrace: %w", err) + } + + if len(trace) == 0 { + return nil + } + + // The first trace in the list is the one related to the + // current stack frame. + scopes, err := s.Scopes(trace[0].ID()) + if err != nil { + return fmt.Errorf("error getting scopes: %w", err) + } + + for _, scope := range scopes { + vars, err := s.Variables(scope.VariablesReference()) + if err != nil { + return fmt.Errorf("error getting variables: %w", err) + } + for _, v := range vars { + fmt.Fprintf(b, "%s %s = %s\n", v.Type(), v.Name(), v.Value()) + } + } + + return nil +} + +func printVar( + b *strings.Builder, + varname string, + s debug.Session, + thrID debug.ThreadID, +) error { + r, err := regexp.Compile(varname) + if err != nil { + return fmt.Errorf("error instantiating regex: %w", err) + } + + trace, err := s.StackTrace(thrID) + if err != nil { + return fmt.Errorf("error getting stacktrace: %w", err) + } + + if len(trace) == 0 { + return nil + } + + // The first trace in the list is the one related to the + // current stack frame. + scopes, err := s.Scopes(trace[0].ID()) + if err != nil { + return fmt.Errorf("error getting scopes: %w", err) + } + + for _, scope := range scopes { + if err := printVariablesInScope(b, r, s, scope.VariablesReference()); err != nil { + return err + } + } + + return nil +} + +func printVariablesInScope( + b *strings.Builder, + r *regexp.Regexp, + s debug.Session, + varRef debug.VarRef, +) error { + if varRef == 0 { + return nil + } + + vars, err := s.Variables(varRef) + if err != nil { + return fmt.Errorf("error getting variables: %w", err) + } + for _, v := range vars { + if r.MatchString(v.Name()) { + var b1 strings.Builder + if err := varToString(&b1, v, s, 0); err != nil { + return err + } + fmt.Fprintf(b, "%s %s = %s\n", v.Type(), v.Name(), b1.String()) + + // We break early here despite the fact that + // multiple variables might match the given + // `varname`. This is done to honour lexical + // scope, showing just the only variable that + // is actually being used for evaluation in + // the given frame. + return nil + } + } + + return nil +} + +func varToString( + b *strings.Builder, + v debug.Variable, + s debug.Session, + indentation int, +) error { + padding := strings.Repeat(" ", indentation) + switch v.Type() { + case "array": + return elementsToString(b, v, s, indentation, "[", "]", + func(elem debug.Variable) error { + fmt.Fprintf(b, " %s", padding) + err := varToString(b, elem, s, indentation) + if err != nil { + return err + } + fmt.Fprintf(b, ",\n") + return nil + }, + ) + case "set": + return elementsToString(b, v, s, indentation, "{", "}", + func(elem debug.Variable) error { + fmt.Fprintf(b, " %s", padding) + err := varToString(b, elem, s, indentation) + if err != nil { + return err + } + fmt.Fprintf(b, ",\n") + return nil + }, + ) + case "object": + return elementsToString(b, v, s, indentation, "{", "}", + func(elem debug.Variable) error { + fmt.Fprintf(b, " %s%s: ", padding, elem.Name()) + err := varToString(b, elem, s, indentation) + if err != nil { + return err + } + fmt.Fprintf(b, ",\n") + return nil + }, + ) + default: + fmt.Fprintf(b, "%s%s", padding, v.Value()) + } + + return nil +} + +func elementsToString( + b *strings.Builder, + v debug.Variable, + s debug.Session, + indentation int, + leftDelimiter string, + rightDelimiter string, + formatter func(debug.Variable) error, +) error { + padding := strings.Repeat(" ", indentation) + fmt.Fprintf(b, "%s%s\n", padding, leftDelimiter) + elems, err := s.Variables(v.VariablesReference()) + if err != nil { + return err + } + for _, elem := range elems { + if err := formatter(elem); err != nil { + return err + } + } + fmt.Fprintf(b, "%s%s", padding, rightDelimiter) + + return nil +} + +var helpMsg = ` +Available commands: + c ------------- continue + r ------------- restart debugging session + b ------- set breakpoint at line + bp ------------ show breakpoints + clear/cl - clear breakpoint with id + clearall/cla -- clear all breakpoints + trs ----------- print threads + s/sv ---------- step over + so ------------ step out + si ------------ step into + list/l -------- list source + locals -------- print local variables + q ------------- quit + help/h -------- print help +` + +func printHelp(b *strings.Builder) { + fmt.Fprintln(b, helpMsg) +} diff --git a/internal/engine/eval/rego/eval.go b/internal/engine/eval/rego/eval.go index fd2a597360..aee86cd719 100644 --- a/internal/engine/eval/rego/eval.go +++ b/internal/engine/eval/rego/eval.go @@ -44,6 +44,15 @@ type Evaluator struct { regoOpts []func(*rego.Rego) reseval resultEvaluator datasources *v1datasources.DataSourceRegistry + debug bool +} + +var _ eoptions.HasDebuggerSupport = (*Evaluator)(nil) + +// SetDebugFlag implements `HasDebuggerSupport` interface. +func (e *Evaluator) SetDebugFlag(flag bool) error { + e.debug = flag + return nil } // Input is the input for the rego evaluator @@ -132,6 +141,26 @@ func (e *Evaluator) Eval( // If the evaluator has data sources defined, expose their functions regoFuncOptions = append(regoFuncOptions, buildDataSourceOptions(res, e.datasources)...) + input := &Input{ + Profile: pol, + Ingested: obj, + OutputFormat: e.cfg.ViolationFormat, + } + enrichInputWithEntityProps(input, entity) + + if e.debug { + err := e.Debug( + ctx, + res, + input, + regoFuncOptions..., + ) + if err != nil { + return nil, fmt.Errorf("error initializing debugger: %w", err) + } + return nil, nil + } + // Create the rego object r := e.newRegoFromOptions( regoFuncOptions..., @@ -142,13 +171,6 @@ func (e *Evaluator) Eval( return nil, fmt.Errorf("could not prepare Rego: %w", err) } - input := &Input{ - Profile: pol, - Ingested: obj, - OutputFormat: e.cfg.ViolationFormat, - } - - enrichInputWithEntityProps(input, entity) rs, err := pq.Eval(ctx, rego.EvalInput(input)) if err != nil { return nil, fmt.Errorf("error evaluating profile. Might be wrong input: %w", err) diff --git a/internal/engine/eval/rego/result.go b/internal/engine/eval/rego/result.go index d713aa2557..300b903b3e 100644 --- a/internal/engine/eval/rego/result.go +++ b/internal/engine/eval/rego/result.go @@ -53,6 +53,7 @@ func (c ConstraintsViolationsFormat) String() string { } type resultEvaluator interface { + getQueryString() string getQuery() func(*rego.Rego) parseResult(rego.ResultSet, protoreflect.ProtoMessage) (*interfaces.EvaluationResult, error) } @@ -60,8 +61,12 @@ type resultEvaluator interface { type denyByDefaultEvaluator struct { } -func (*denyByDefaultEvaluator) getQuery() func(r *rego.Rego) { - return rego.Query(RegoQueryPrefix) +func (*denyByDefaultEvaluator) getQueryString() string { + return RegoQueryPrefix +} + +func (d *denyByDefaultEvaluator) getQuery() func(r *rego.Rego) { + return rego.Query(d.getQueryString()) } func (*denyByDefaultEvaluator) parseResult(rs rego.ResultSet, entity protoreflect.ProtoMessage, @@ -168,8 +173,12 @@ type constraintsEvaluator struct { format ConstraintsViolationsFormat } -func (*constraintsEvaluator) getQuery() func(r *rego.Rego) { - return rego.Query(fmt.Sprintf("%s.violations[details]", RegoQueryPrefix)) +func (*constraintsEvaluator) getQueryString() string { + return fmt.Sprintf("%s.violations[details]", RegoQueryPrefix) +} + +func (c *constraintsEvaluator) getQuery() func(r *rego.Rego) { + return rego.Query(c.getQueryString()) } func (c *constraintsEvaluator) parseResult(rs rego.ResultSet, _ protoreflect.ProtoMessage) (*interfaces.EvaluationResult, error) { diff --git a/internal/engine/options/options.go b/internal/engine/options/options.go index 0da6223418..b71b1eb3ab 100644 --- a/internal/engine/options/options.go +++ b/internal/engine/options/options.go @@ -35,6 +35,26 @@ func WithFlagsClient(client openfeature.IClient) Option { } } +// HasDebuggerSupport interface should be implemented by evaluation +// engines that support interactive debugger. Currently, only +// REGO-based engines should implement this. +type HasDebuggerSupport interface { + SetDebugFlag(bool) error +} + +// WithDebugger sets the evaluation engine to start an interactive +// debugging session. This MUST NOT be used in backend servers, and is +// only meant to be used in CLI tools. +func WithDebugger(flag bool) Option { + return func(e interfaces.Evaluator) error { + inner, ok := e.(HasDebuggerSupport) + if !ok { + return nil + } + return inner.SetDebugFlag(flag) + } +} + // SupportsDataSources interface advertises the fact that the implementer // can register data sources with the evaluator. type SupportsDataSources interface { diff --git a/internal/util/cli/styles.go b/internal/util/cli/styles.go index 4973c54b85..453db8c72b 100644 --- a/internal/util/cli/styles.go +++ b/internal/util/cli/styles.go @@ -27,7 +27,8 @@ var ( // Common styles var ( - CursorStyle = lipgloss.NewStyle().Foreground(SecondaryColor) + CursorStyle = lipgloss.NewStyle().Foreground(SecondaryColor) + SimpleBoldStyle = lipgloss.NewStyle().Bold(true) ) // Banner styles