Skip to content

Commit da0dee7

Browse files
authored
Merge pull request #15 from yaklabco/feat-cycle-detection
2 parents 3aa9175 + 2e04c16 commit da0dee7

File tree

17 files changed

+473
-38
lines changed

17 files changed

+473
-38
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111

1212
### Added
1313

14+
- `--dryrun` mode.
15+
16+
- Automated detection of circular dependencies among stavefile targets.
17+
1418
- `CHANGELOG.md`! (And first formally-versioned release of Stave.)
1519

1620
[unreleased]: https://github.com/yaklabco/stave/compare/v0.1.0...HEAD

TODO.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,4 @@
2626
- [x] Parallelize tests w/locking mechanism to prevent parallel tests in same testdata subdir
2727
- [x] Bring in CI from `goctx` with the fancy caching
2828
- [x] Validate that every target is run in its own Goroutine
29+
- [x] Dependency cycle detection

pkg/st/cycle_detector.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
package st
2+
3+
import (
4+
"log/slog"
5+
"runtime"
6+
"strings"
7+
"sync"
8+
9+
"github.com/samber/lo"
10+
"github.com/yaklabco/stave/pkg/toposort"
11+
)
12+
13+
const (
14+
// maxStackDepthToCheck defines the maximum stack depth for runtime caller inspection.
15+
maxStackDepthToCheck = 64
16+
)
17+
18+
var (
19+
depsByID = make(map[string]toposort.TopoSortable) //nolint:gochecknoglobals // Part of a mutexed pattern.
20+
depsByIDMutex sync.RWMutex //nolint:gochecknoglobals // Part of a mutexed pattern.
21+
)
22+
23+
func firstExternalCaller() *runtime.Frame {
24+
thisProgCtr, _, _, ok := runtime.Caller(0)
25+
if !ok {
26+
return nil
27+
}
28+
thisFunc := runtime.FuncForPC(thisProgCtr)
29+
pkgPrefix := getPackagePath(thisFunc)
30+
31+
// runtime.Callers (0), firstExternalCaller (1), the function calling firstExternalCaller (2)
32+
const skip = 2
33+
progCtrsAboveUs := make([]uintptr, maxStackDepthToCheck)
34+
nProgCtrsAboveUs := runtime.Callers(skip, progCtrsAboveUs)
35+
frames := runtime.CallersFrames(progCtrsAboveUs[:nProgCtrsAboveUs])
36+
37+
for {
38+
frame, more := frames.Next()
39+
40+
// frame.Function is the fully-qualified name:
41+
// "mypkg.myFunc"
42+
// "otherpkg.DoThing"
43+
// "github.com/me/foo/bar.Baz"
44+
slog.Debug("checking the frame", slog.String("function", frame.Function), slog.String("pkg_prefix", pkgPrefix))
45+
if !strings.HasPrefix(frame.Function, pkgPrefix) {
46+
return &frame
47+
}
48+
49+
if !more {
50+
break
51+
}
52+
}
53+
54+
return nil
55+
}
56+
57+
func getPackagePath(thisFunc *runtime.Func) string {
58+
pkgPrefix := thisFunc.Name()
59+
lastSlash := strings.LastIndex(pkgPrefix, "/")
60+
lastDot := strings.LastIndex(pkgPrefix, ".")
61+
if lastDot > lastSlash {
62+
pkgPrefix = pkgPrefix[:lastDot] // e.g. "github.com/me/project/mypkg"
63+
}
64+
return pkgPrefix
65+
}
66+
67+
func checkForCycle(funcs []Fn) error {
68+
callerFrame := firstExternalCaller()
69+
if callerFrame == nil {
70+
slog.Warn("could not determine caller, skipping circular-dependency check")
71+
return nil
72+
}
73+
74+
callerID := callerFrame.File + ":" + callerFrame.Function
75+
slog.Debug("checking for cycle", slog.String("caller_id", callerID))
76+
77+
funcIDs := make([]string, 0, len(funcs))
78+
for _, theFunc := range funcs {
79+
theFuncObj := theFunc.Underlying()
80+
theFile, _ := theFuncObj.FileLine(0)
81+
theFuncID := theFile + ":" + theFuncObj.Name()
82+
slog.Debug("adding dependency", slog.String("func_id", theFuncID))
83+
funcIDs = append(funcIDs, theFuncID)
84+
}
85+
86+
depsByIDMutex.Lock()
87+
defer depsByIDMutex.Unlock()
88+
depsByID[callerID] = depsNode{tpID: callerID, dependencyTPIDs: funcIDs}
89+
90+
_, err := toposort.Sort(lo.Values(depsByID), true)
91+
92+
return err
93+
}
94+
95+
type depsNode struct {
96+
tpID string
97+
dependencyTPIDs []string
98+
}
99+
100+
func (n depsNode) TPID() string { return n.tpID }
101+
func (n depsNode) DependencyTPIDs() []string { return n.dependencyTPIDs }

pkg/st/deps.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,11 @@ func checkFns(fns []interface{}) []Fn {
137137

138138
funcs[iFunc] = F(theFunc)
139139
}
140+
141+
if err := checkForCycle(funcs); err != nil {
142+
panic(fmt.Errorf("checking for cycles in dependency graph: %w", err))
143+
}
144+
140145
return funcs
141146
}
142147

@@ -175,7 +180,11 @@ func changeExit(oldExitCode, newExitCode int) int {
175180

176181
// funcName returns the unique name for the function.
177182
func funcName(i interface{}) string {
178-
return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name()
183+
return funcObj(i).Name()
184+
}
185+
186+
func funcObj(i interface{}) *runtime.Func {
187+
return runtime.FuncForPC(reflect.ValueOf(i).Pointer())
179188
}
180189

181190
func displayName(name string) string {

pkg/st/fn.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"errors"
77
"fmt"
88
"reflect"
9+
"runtime"
910
)
1011

1112
// Fn represents a function that can be run with st.Deps. Package, Name, and ID must combine to
@@ -23,6 +24,9 @@ type Fn interface {
2324

2425
// Run should run the function.
2526
Run(ctx context.Context) error
27+
28+
// Underlying should return the original, wrapped function object.
29+
Underlying() *runtime.Func
2630
}
2731

2832
// F takes a function that is compatible as a stave target, and any args that need to be passed to
@@ -78,13 +82,15 @@ func F(target interface{}, args ...interface{}) Fn {
7882
}
7983
return nil
8084
},
85+
underlying: funcObj(target),
8186
}
8287
}
8388

8489
type fn struct {
85-
name string
86-
id string
87-
f func(ctx context.Context) error
90+
name string
91+
id string
92+
f func(ctx context.Context) error
93+
underlying *runtime.Func
8894
}
8995

9096
// Name returns the fully qualified name of the function.
@@ -102,6 +108,11 @@ func (f fn) Run(ctx context.Context) error {
102108
return f.f(ctx)
103109
}
104110

111+
// Underlying returns the original, wrapped function object.
112+
func (f fn) Underlying() *runtime.Func {
113+
return f.underlying
114+
}
115+
105116
func checkF(target interface{}, args []interface{}) (bool, bool, error) {
106117
theType := reflect.TypeOf(target)
107118
if theType == nil || theType.Kind() != reflect.Func {
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package stave
2+
3+
import (
4+
"bytes"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
// TestCyclicDependencyDetection verifies proper detection of cyclic dependencies.
12+
func TestCyclicDependencyDetection(t *testing.T) {
13+
t.Parallel()
14+
testDataDir := "./testdata/cyclic_dependencies"
15+
mu := mutexByDir(testDataDir)
16+
mu.Lock()
17+
defer mu.Unlock()
18+
19+
ctx := t.Context()
20+
21+
stdout := &bytes.Buffer{}
22+
stderr := &bytes.Buffer{}
23+
24+
err := Run(RunParams{
25+
BaseCtx: ctx,
26+
Dir: testDataDir,
27+
Stdout: stdout,
28+
Stderr: stderr,
29+
Args: []string{"Step1"},
30+
})
31+
require.Error(t, err)
32+
33+
expected := "circular dependency detected"
34+
assert.Contains(t, stderr.String(), expected)
35+
}

pkg/stave/main.go

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030
"github.com/yaklabco/stave/internal/parse"
3131
"github.com/yaklabco/stave/pkg/sh"
3232
"github.com/yaklabco/stave/pkg/st"
33+
"github.com/yaklabco/stave/pkg/stave/prettylog"
3334
)
3435

3536
const (
@@ -84,7 +85,10 @@ func (i RunParams) UsesStavefiles() bool {
8485
// function to allow it to be used from other programs, specifically so you can
8586
// go run a simple file that run's stave's Run.
8687
func Run(params RunParams) error {
87-
logHandler := setupLogger(params)
88+
if params.WriterForLogger == nil {
89+
params.WriterForLogger = params.Stderr
90+
}
91+
logHandler := prettylog.SetupPrettyLogger(params.WriterForLogger)
8892

8993
if params.Debug {
9094
logHandler.SetLevel(cblog.DebugLevel)
@@ -124,24 +128,6 @@ func Run(params RunParams) error {
124128
return stave(ctx, params)
125129
}
126130

127-
func setupLogger(params RunParams) *cblog.Logger {
128-
if params.WriterForLogger == nil {
129-
params.WriterForLogger = params.Stderr
130-
}
131-
132-
logHandler := cblog.NewWithOptions(
133-
params.WriterForLogger,
134-
cblog.Options{
135-
Level: cblog.WarnLevel, // Setting this to lowest possible value, since slog will handle the actual filtering.
136-
ReportTimestamp: true,
137-
ReportCaller: true,
138-
},
139-
)
140-
logger := slog.New(logHandler)
141-
slog.SetDefault(logger)
142-
return logHandler
143-
}
144-
145131
func stave(ctx context.Context, params RunParams) error {
146132
files, err := Stavefiles(params.Dir, params.GOOS, params.GOARCH, params.UsesStavefiles())
147133
if err != nil {

pkg/stave/main_test.go

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -601,8 +601,10 @@ func TestVerbose(t *testing.T) {
601601
err = Run(runParams)
602602
require.NoError(t, err, "stderr was: %s", stderr.String())
603603

604-
expectedRegexp := `Running target: TestVerbose\n.*hi!\n`
605-
assert.Regexp(t, expectedRegexp, stderr.String())
604+
expectedOutRegexp := `\bhi!\n`
605+
assert.Regexp(t, expectedOutRegexp, stdout.String())
606+
expectedErrRegexp := `\bRunning target: TestVerbose\b`
607+
assert.Regexp(t, expectedErrRegexp, stderr.String())
606608
}
607609

608610
func TestList(t *testing.T) {
@@ -1021,11 +1023,13 @@ func TestMultipleTargets(t *testing.T) {
10211023

10221024
err := Run(runParams)
10231025
require.NoError(t, err, "stderr was: %s", stderr.String())
1024-
expectedRegexp := `Running target: TestVerbose\n.*hi!\nRunning target: ReturnsNilError\n`
1025-
assert.Regexp(t, expectedRegexp, stderr.String())
1026+
expectedOutRegexp := `\bhi!`
1027+
assert.Regexp(t, expectedOutRegexp, stdout.String())
1028+
expectedErrRegexp := `Running target: TestVerbose\n(.*\n)*Running target: ReturnsNilError\n`
1029+
assert.Regexp(t, expectedErrRegexp, stderr.String())
10261030

10271031
expectedOutStr := "stuff\n"
1028-
assert.Equal(t, expectedOutStr, stdout.String())
1032+
assert.Contains(t, stdout.String(), expectedOutStr)
10291033
}
10301034

10311035
func TestFirstTargetFails(t *testing.T) {
@@ -1364,7 +1368,7 @@ func TestCompiledFlags(t *testing.T) {
13641368
err = run(stdout, stderr, name, "-v", "testverbose")
13651369
require.NoError(t, err, "stderr was: %s", stderr.String())
13661370
want = hiExclam
1367-
assert.Contains(t, stderr.String(), want)
1371+
assert.Contains(t, stdout.String(), want)
13681372

13691373
// pass list flag -l
13701374
err = run(stdout, stderr, name, "-l")
@@ -1441,7 +1445,7 @@ func TestCompiledEnvironmentVars(t *testing.T) {
14411445
err = run(stdout, stderr, name, st.VerboseEnv+"=1", "testverbose")
14421446
require.NoError(t, err, "stderr was: %s", stderr.String())
14431447
want = hiExclam
1444-
assert.Contains(t, stderr.String(), want)
1448+
assert.Contains(t, stdout.String(), want)
14451449

14461450
err = run(stdout, stderr, name, "STAVEFILE_LIST=1")
14471451
require.NoError(t, err, "stderr was: %s", stderr.String())

pkg/stave/prettylog/prettylog.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package prettylog
2+
3+
import (
4+
"io"
5+
"log/slog"
6+
7+
"github.com/charmbracelet/log"
8+
)
9+
10+
func SetupPrettyLogger(writerForLogger io.Writer) *log.Logger {
11+
logHandler := log.NewWithOptions(
12+
writerForLogger,
13+
log.Options{
14+
Level: log.InfoLevel, // Setting this to lowest possible value, since slog will handle the actual filtering.
15+
ReportTimestamp: true,
16+
ReportCaller: true,
17+
},
18+
)
19+
logger := slog.New(logHandler)
20+
slog.SetDefault(logger)
21+
22+
return logHandler
23+
}

pkg/stave/stavefile_embedder.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@ package stave
22

33
import _ "embed"
44

5-
//go:embed templates/stavefile_tmpl.go
5+
//go:embed templates/initial_stavefile.go
66
var staveTpl string

0 commit comments

Comments
 (0)