diff --git a/README.md b/README.md index 492b683..8ebc7b0 100644 --- a/README.md +++ b/README.md @@ -358,3 +358,101 @@ This package absolutely 100% could not have been written without the help of Tho Not only did the book make understanding the process of writing lexers, parsers, and asts, but it also provided the basis for the syntax of Plush itself. If you have yet to read Thorsten's book, I can't recommend it enough. Please go and buy it! + +--- + +## Render Budget + +Plush lets you attach a work-unit **budget** to any render to protect against runaway templates — deeply nested loops, recursive partials, or unexpectedly expensive helpers. + +A **nil budget = unlimited**, so all existing code is completely unaffected. + +### Quick start + +```go +b := plush.NewBudget(10_000) +ctx := plush.NewContext() +ctx.Set("products", products) +ctx.WithBudget(b) + +html, err := plush.Render(tmpl, ctx) +if errors.Is(err, plush.ErrBudgetExceeded) { + log.Printf("budget exceeded: used=%d remaining=%d", b.Used(), b.Remaining()) + return errorPage() +} + +// One-liner convenience wrapper +html, err = plush.RenderWithBudget(tmpl, 10_000, ctx) +``` + +### Default operation costs + +| Operation | Default cost | +|---|---| +| Loop iteration | 1 | +| Helper / function call | 5 | +| Filter call | 3 | +| Partial / sub-render | 10 | +| Condition check (`if`) | 1 | +| Variable assignment | 0 | +| Object traversal (per segment) | 1 | + +### Custom costs + +Pass a `BudgetCosts` struct to override any cost: + +```go +costs := plush.ZeroCosts() // start from all-zero +costs.LoopIteration = 1 +costs.SubRender = 25 + +html, err = plush.RenderWithBudgetConfig(tmpl, 5_000, costs, ctx) +``` + +### Per-function costs + +Override the cost for individual functions registered in the context: + +```go +costs := plush.DefaultBudgetCosts() +costs.FunctionCosts = map[string]int64{ + "expensiveQuery": 50, // charged 50 per call instead of the default 5 + "cheapHelper": 1, +} + +html, err = plush.RenderWithBudgetConfig(tmpl, 10_000, costs, ctx) +``` + +Functions not listed in `FunctionCosts` fall back to the `HelperCall` cost. + +### Stats report + +After rendering, call `b.Stats()` to see exactly where the budget was spent: + +```go +b := plush.NewBudget(10_000) +ctx.WithBudget(b) +plush.Render(tmpl, ctx) + +s := b.Stats() +fmt.Printf("total=%d loops=%d calls=%d conditions=%d\n", + s.TotalUsed, s.LoopIterations, s.FunctionCalls, s.ConditionChecks) + +for name, units := range s.ByFunction { + fmt.Printf(" %s: %d units\n", name, units) +} +``` + +`BudgetStats` fields: + +| Field | What it measures | +|---|---| +| `TotalUsed` | Sum of all units spent | +| `LoopIterations` | Units from loop iterations | +| `FunctionCalls` | Units from all function/helper calls | +| `FilterCalls` | Units from filter calls | +| `SubRenders` | Units from partial renders | +| `ConditionChecks` | Units from `if`/`unless` evaluations | +| `Assignments` | Units from variable assignments | +| `ObjectTraversals` | Units from dot-notation traversal | +| `ByFunction` | Per-function breakdown (map of name → units) | diff --git a/budget.go b/budget.go new file mode 100644 index 0000000..59209d8 --- /dev/null +++ b/budget.go @@ -0,0 +1,217 @@ +package plush + +import ( + "errors" + "sync" + "sync/atomic" +) + +// ErrBudgetExceeded is returned when a render exhausts its budget. +var ErrBudgetExceeded = errors.New("render budget exceeded") + +// BudgetStats is a snapshot of work units consumed per operation category. +// Retrieve it after rendering via b.Stats(). +type BudgetStats struct { + // TotalUsed is the sum of all units spent (equals b.Used()). + TotalUsed int64 + // LoopIterations is total units charged by loop iterations. + LoopIterations int64 + // FunctionCalls is total units charged by all function/helper calls. + FunctionCalls int64 + // FilterCalls is total units charged by filter calls. + FilterCalls int64 + // SubRenders is total units charged by partial/snippet renders. + SubRenders int64 + // ConditionChecks is total units charged by if/unless evaluations. + ConditionChecks int64 + // Assignments is total units charged by variable assignments. + Assignments int64 + // ObjectTraversals is total units charged by dot-notation traversal. + ObjectTraversals int64 + // ByFunction breaks FunctionCalls down by name for calls made via + // SpendFunctionCall. Functions without a FunctionCosts override appear + // here using the generic HelperCall cost. + ByFunction map[string]int64 +} + +// Budget tracks render work units during template evaluation. +// A nil Budget is always unlimited — zero breaking changes. +type Budget struct { + limit int64 + counter atomic.Int64 + costs BudgetCosts + + // per-category stat counters — all lock-free + statLoop atomic.Int64 + statFunction atomic.Int64 // total of all function/helper calls + statFilter atomic.Int64 + statSubRender atomic.Int64 + statCondition atomic.Int64 + statAssign atomic.Int64 + statTraversal atomic.Int64 + + // per-function breakdown — mutex-protected plain map + statFuncsMu sync.Mutex + statFuncsMap map[string]int64 +} + +// NewBudget creates a Budget with a limit and default costs. +func NewBudget(limit int64) *Budget { + return &Budget{ + limit: limit, + costs: DefaultBudgetCosts(), + statFuncsMap: make(map[string]int64), + } +} + +// NewBudgetWithCosts creates a Budget with fully custom per-operation costs. +func NewBudgetWithCosts(limit int64, costs BudgetCosts) *Budget { + return &Budget{ + limit: limit, + costs: costs, + statFuncsMap: make(map[string]int64), + } +} + +// WithCosts replaces the cost configuration. Returns self for chaining. +func (b *Budget) WithCosts(costs BudgetCosts) *Budget { + b.costs = costs + return b +} + +// Costs returns the active cost configuration. +func (b *Budget) Costs() BudgetCosts { + return b.costs +} + +// Used returns total units consumed so far. +func (b *Budget) Used() int64 { + return b.counter.Load() +} + +// Remaining returns units left before the limit is hit. +func (b *Budget) Remaining() int64 { + r := b.limit - b.counter.Load() + if r < 0 { + return 0 + } + return r +} + +// Stats returns a snapshot of work units consumed per operation category. +// Safe to call at any point during or after rendering. +func (b *Budget) Stats() BudgetStats { + if b == nil { + return BudgetStats{} + } + s := BudgetStats{ + TotalUsed: b.counter.Load(), + LoopIterations: b.statLoop.Load(), + FunctionCalls: b.statFunction.Load(), + FilterCalls: b.statFilter.Load(), + SubRenders: b.statSubRender.Load(), + ConditionChecks: b.statCondition.Load(), + Assignments: b.statAssign.Load(), + ObjectTraversals: b.statTraversal.Load(), + ByFunction: make(map[string]int64), + } + b.statFuncsMu.Lock() + for k, v := range b.statFuncsMap { + s.ByFunction[k] = v + } + b.statFuncsMu.Unlock() + return s +} + +// SpendLoop spends the loop iteration cost. +func (b *Budget) SpendLoop() error { + if b == nil { + return nil + } + b.statLoop.Add(b.costs.LoopIteration) + return b.spend(b.costs.LoopIteration) +} + +// SpendHelperCall spends the helper call cost. +func (b *Budget) SpendHelperCall() error { + if b == nil { + return nil + } + b.statFunction.Add(b.costs.HelperCall) + return b.spend(b.costs.HelperCall) +} + +// SpendFilter spends the filter call cost. +func (b *Budget) SpendFilter() error { + if b == nil { + return nil + } + b.statFilter.Add(b.costs.FilterCall) + return b.spend(b.costs.FilterCall) +} + +// SpendSubRender spends the sub-render cost. +func (b *Budget) SpendSubRender() error { + if b == nil { + return nil + } + b.statSubRender.Add(b.costs.SubRender) + return b.spend(b.costs.SubRender) +} + +// SpendCondition spends the condition check cost. +func (b *Budget) SpendCondition() error { + if b == nil { + return nil + } + b.statCondition.Add(b.costs.ConditionCheck) + return b.spend(b.costs.ConditionCheck) +} + +// SpendAssignment spends the assignment cost. +func (b *Budget) SpendAssignment() error { + if b == nil { + return nil + } + b.statAssign.Add(b.costs.Assignment) + return b.spend(b.costs.Assignment) +} + +// SpendFunctionCall spends the cost for a named function call. +// Uses FunctionCosts[name] if set, otherwise falls back to HelperCall cost. +func (b *Budget) SpendFunctionCall(name string) error { + if b == nil { + return nil + } + cost := b.costs.HelperCall + if c, ok := b.costs.FunctionCosts[name]; ok { + cost = c + } + b.statFunction.Add(cost) + b.statFuncsMu.Lock() + b.statFuncsMap[name] += cost + b.statFuncsMu.Unlock() + return b.spend(cost) +} + +// SpendObjectTraversal spends ObjectTraversal * segments units. +// e.g. product.variants.first = 3 segments → costs ObjectTraversal * 3 +func (b *Budget) SpendObjectTraversal(segments int) error { + if b == nil { + return nil + } + units := b.costs.ObjectTraversal * int64(segments) + b.statTraversal.Add(units) + return b.spend(units) +} + +// spend is the internal hot path. Uses atomic add with no locks. +func (b *Budget) spend(units int64) error { + if b == nil || units == 0 { + return nil + } + if b.counter.Add(units) > b.limit { + return ErrBudgetExceeded + } + return nil +} diff --git a/budget_config.go b/budget_config.go new file mode 100644 index 0000000..65e4794 --- /dev/null +++ b/budget_config.go @@ -0,0 +1,58 @@ +package plush + +// BudgetCosts defines the work-unit cost for each operation type. +type BudgetCosts struct { + // LoopIteration is spent once per for-loop iteration. + // Default: 1 + LoopIteration int64 + + // HelperCall is spent each time a registered helper is invoked. + // Default: 5 + HelperCall int64 + + // FilterCall is spent per filter applied (sort, map, where). + // Default: 3 + FilterCall int64 + + // SubRender is spent each time a partial/snippet is rendered. + // Default: 10 + SubRender int64 + + // ConditionCheck is spent per if/unless/case evaluation. + // Default: 1 + ConditionCheck int64 + + // Assignment is spent per variable assignment. + // Default: 0 (free — rarely the bottleneck) + Assignment int64 + + // ObjectTraversal is spent per dot-notation segment accessed. + // e.g. product.variants.first = 3 segments = 3 units + // Default: 1 + ObjectTraversal int64 + + // FunctionCosts overrides the default HelperCall cost for specific named + // functions. The key is the function name as registered in the context. + // If a name is present here, its cost is used instead of HelperCall. + // e.g. costs.FunctionCosts = map[string]int64{"expensiveQuery": 50} + FunctionCosts map[string]int64 +} + +// DefaultBudgetCosts returns recommended production defaults. +func DefaultBudgetCosts() BudgetCosts { + return BudgetCosts{ + LoopIteration: 1, + HelperCall: 5, + FilterCall: 3, + SubRender: 10, + ConditionCheck: 1, + Assignment: 0, + ObjectTraversal: 1, + } +} + +// ZeroCosts returns all-zero costs. +// Useful for isolating one operation type in tests. +func ZeroCosts() BudgetCosts { + return BudgetCosts{} +} diff --git a/budget_test.go b/budget_test.go new file mode 100644 index 0000000..6d5de3a --- /dev/null +++ b/budget_test.go @@ -0,0 +1,285 @@ +package plush + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +// --- Core budget enforcement --- + +func TestBudget_LoopExceedsLimit(t *testing.T) { + r := require.New(t) + // 5 iterations, each costs 1, limit is 3 → should exceed + tmpl := `<% for (i,v) in items { } %>` + ctx := NewContext() + ctx.Set("items", []int{1, 2, 3, 4, 5}) + + _, err := RenderWithBudget(tmpl, 3, ctx) + r.True(errors.Is(err, ErrBudgetExceeded), "expected ErrBudgetExceeded, got %v", err) +} + +func TestBudget_LoopWithinLimit(t *testing.T) { + r := require.New(t) + tmpl := `<% for (i,v) in items { } %>` + ctx := NewContext() + ctx.Set("items", []int{1, 2, 3}) + + _, err := RenderWithBudget(tmpl, 100, ctx) + r.NoError(err) +} + +// --- Nil budget = unlimited (backwards compat) --- + +func TestBudget_NilIsUnlimited(t *testing.T) { + r := require.New(t) + tmpl := `<% for (i,v) in items { } %>` + ctx := NewContext() + ctx.Set("items", []int{1, 2, 3, 4, 5}) + + _, err := Render(tmpl, ctx) // no budget attached + r.NoError(err, "unlimited render should not fail") +} + +// --- Zero cost fields are skipped --- + +func TestBudget_ZeroCostNeverExceeds(t *testing.T) { + r := require.New(t) + costs := ZeroCosts() + costs.LoopIteration = 0 // free + + tmpl := `<% for (i,v) in items { } %>` + ctx := NewContext() + ctx.Set("items", make([]int, 10_000)) + + _, err := RenderWithBudgetConfig(tmpl, 1, costs, ctx) + r.NoError(err, "zero cost loop should never exceed") +} + +// --- Custom helper call costs --- + +func TestBudget_CustomHelperCost(t *testing.T) { + r := require.New(t) + costs := ZeroCosts() + costs.HelperCall = 100 // each call costs 100 + + // 2 calls = 200 units, limit is 150 → should exceed + tmpl := `<%= myHelper() %><%= myHelper() %>` + ctx := NewContext() + ctx.Set("myHelper", func() string { return "ok" }) + + _, err := RenderWithBudgetConfig(tmpl, 150, costs, ctx) + r.True(errors.Is(err, ErrBudgetExceeded), "expected ErrBudgetExceeded, got %v", err) +} + +// --- Remaining / Used --- + +func TestBudget_UsedAndRemaining(t *testing.T) { + r := require.New(t) + b := NewBudget(100) + ctx := NewContext() + ctx.Set("items", []int{1, 2, 3}) // 3 loop iterations = 3 units + ctx.WithBudget(b) + + tmpl := `<% for (i,v) in items { } %>` + Render(tmpl, ctx) + + r.Greater(b.Used(), int64(0), "expected some units to be used") + r.Less(b.Remaining(), int64(100), "remaining should be less than limit after render") +} + +// --- Condition check cost --- + +func TestBudget_ConditionExceedsLimit(t *testing.T) { + r := require.New(t) + costs := ZeroCosts() + costs.ConditionCheck = 5 + + // Two if-checks = 10 units, limit is 7 → second exceeds + tmpl := `<% if (true) { %>a<% } %><% if (true) { %>b<% } %>` + ctx := NewContext() + + _, err := RenderWithBudgetConfig(tmpl, 7, costs, ctx) + r.True(errors.Is(err, ErrBudgetExceeded), "expected ErrBudgetExceeded, got %v", err) +} + +// --- Sub-render shares parent budget (unit test on Budget directly) --- + +func TestBudget_SubRenderSharesParentBudget(t *testing.T) { + r := require.New(t) + costs := ZeroCosts() + costs.SubRender = 50 // each snippet costs 50 + + b := NewBudgetWithCosts(75, costs) // limit 75 — second snippet exceeds + ctx := NewContext() + ctx.WithBudget(b) + + err1 := b.SpendSubRender() // 50 — ok + err2 := b.SpendSubRender() // 100 — exceeds 75 + + r.NoError(err1, "first snippet should succeed") + r.True(errors.Is(err2, ErrBudgetExceeded), "second snippet should exceed budget, got %v", err2) +} + +// --- NewBudget / WithCosts / Costs --- + +func TestBudget_WithCosts(t *testing.T) { + r := require.New(t) + b := NewBudget(1000) + custom := ZeroCosts() + custom.HelperCall = 42 + b.WithCosts(custom) + + r.Equal(int64(42), b.Costs().HelperCall) +} + +func TestBudget_NewBudgetWithCosts(t *testing.T) { + r := require.New(t) + costs := DefaultBudgetCosts() + b := NewBudgetWithCosts(500, costs) + + r.Equal(int64(500), b.Remaining()) + r.Equal(int64(0), b.Used()) +} + +// --- Per-function cost override --- + +func TestBudget_FunctionCostOverride_Exceeds(t *testing.T) { + r := require.New(t) + costs := ZeroCosts() + costs.FunctionCosts = map[string]int64{ + "expensive": 60, // overrides generic HelperCall + } + + // 2 calls × 60 = 120, limit 100 → second call exceeds + tmpl := `<%= expensive() %><%= expensive() %>` + ctx := NewContext() + ctx.Set("expensive", func() string { return "x" }) + + _, err := RenderWithBudgetConfig(tmpl, 100, costs, ctx) + r.True(errors.Is(err, ErrBudgetExceeded), "expected ErrBudgetExceeded, got %v", err) +} + +func TestBudget_FunctionCostOverride_FallsBackToHelperCall(t *testing.T) { + r := require.New(t) + costs := ZeroCosts() + costs.HelperCall = 10 + costs.FunctionCosts = map[string]int64{ + "cheap": 1, // cheap function — does NOT affect "other" + } + + // "other" has no override, falls back to HelperCall=10; 2 calls = 20 > limit 15 + tmpl := `<%= other() %><%= other() %>` + ctx := NewContext() + ctx.Set("other", func() string { return "y" }) + + _, err := RenderWithBudgetConfig(tmpl, 15, costs, ctx) + r.True(errors.Is(err, ErrBudgetExceeded), "expected ErrBudgetExceeded, got %v", err) +} + +func TestBudget_FunctionCostOverride_CheapFunctionDoesNotExceed(t *testing.T) { + r := require.New(t) + costs := ZeroCosts() + costs.HelperCall = 100 // default is huge + costs.FunctionCosts = map[string]int64{ + "cheap": 1, // this function is cheap + } + + // 5 cheap calls = 5, limit 50 → fine + tmpl := `<%= cheap() %><%= cheap() %><%= cheap() %><%= cheap() %><%= cheap() %>` + ctx := NewContext() + ctx.Set("cheap", func() string { return "ok" }) + + _, err := RenderWithBudgetConfig(tmpl, 50, costs, ctx) + r.NoError(err) +} + +// --- Stats report --- + +func TestBudget_Stats_LoopIterations(t *testing.T) { + r := require.New(t) + b := NewBudget(1_000) + ctx := NewContext() + ctx.Set("items", []int{1, 2, 3}) + ctx.WithBudget(b) + + _, err := Render(`<% for (i,v) in items { } %>`, ctx) + r.NoError(err) + + s := b.Stats() + r.Equal(int64(3), s.LoopIterations, "3 iterations × cost 1 = 3") + r.Equal(int64(3), s.TotalUsed) + r.Equal(int64(0), s.FunctionCalls) + r.Equal(int64(0), s.ConditionChecks) +} + +func TestBudget_Stats_FunctionCalls(t *testing.T) { + r := require.New(t) + b := NewBudget(1_000) + ctx := NewContext() + ctx.Set("greet", func() string { return "hi" }) + ctx.WithBudget(b) + + _, err := Render(`<%= greet() %><%= greet() %><%= greet() %>`, ctx) + r.NoError(err) + + s := b.Stats() + r.Equal(int64(15), s.FunctionCalls, "3 calls × default HelperCall cost 5 = 15") + r.Equal(int64(15), s.TotalUsed) + r.Equal(int64(15), s.ByFunction["greet"]) +} + +func TestBudget_Stats_ByFunctionPerFunctionCost(t *testing.T) { + r := require.New(t) + costs := ZeroCosts() + costs.FunctionCosts = map[string]int64{ + "heavy": 20, + "light": 2, + } + b := NewBudgetWithCosts(1_000, costs) + ctx := NewContext() + ctx.Set("heavy", func() string { return "h" }) + ctx.Set("light", func() string { return "l" }) + ctx.WithBudget(b) + + _, err := Render(`<%= heavy() %><%= light() %><%= light() %>`, ctx) + r.NoError(err) + + s := b.Stats() + r.Equal(int64(20), s.ByFunction["heavy"], "1 call × 20") + r.Equal(int64(4), s.ByFunction["light"], "2 calls × 2") + r.Equal(int64(24), s.FunctionCalls) + r.Equal(int64(24), s.TotalUsed) +} + +func TestBudget_Stats_MixedOperations(t *testing.T) { + r := require.New(t) + costs := BudgetCosts{ + LoopIteration: 2, + HelperCall: 10, + ConditionCheck: 3, + } + b := NewBudgetWithCosts(1_000, costs) + ctx := NewContext() + ctx.Set("calc", func() string { return "x" }) + ctx.WithBudget(b) + + // 2 calc calls × 10 = 20, 1 if × 3 = 3 → total 23 + tmpl := `<%= calc() %><%= calc() %><% if (true) { %>ok<% } %>` + _, err := Render(tmpl, ctx) + r.NoError(err) + + s := b.Stats() + r.Equal(int64(20), s.FunctionCalls) + r.Equal(int64(3), s.ConditionChecks) + r.Equal(int64(0), s.LoopIterations) + r.Equal(int64(23), s.TotalUsed) +} + +func TestBudget_Stats_NilBudgetReturnsZero(t *testing.T) { + r := require.New(t) + var b *Budget + s := b.Stats() + r.Equal(BudgetStats{}, s) +} diff --git a/compiler.go b/compiler.go index a8ae547..a0827e9 100644 --- a/compiler.go +++ b/compiler.go @@ -43,6 +43,14 @@ type compiler struct { positionStartEnds []HoleMarker } +// budget returns the active Budget from the current context, or nil if unlimited. +func (c *compiler) budget() *Budget { + if ctx, ok := c.ctx.(*Context); ok { + return ctx.Budget() + } + return nil +} + func (c *compiler) compile() (string, error) { bb := builderPool.Get().(*strings.Builder) bb.Reset() @@ -200,6 +208,9 @@ func (c *compiler) evalExpression(node ast.Expression) (interface{}, error) { } func (c *compiler) evalAssignExpression(node *ast.AssignExpression) (interface{}, error) { + if err := c.budget().SpendAssignment(); err != nil { + return nil, err + } v, err := c.evalExpression(node.Value) if err != nil { return nil, err @@ -255,6 +266,9 @@ func (c *compiler) evalPrefixExpression(node *ast.PrefixExpression) (interface{} } func (c *compiler) evalIfExpression(node *ast.IfExpression) (interface{}, error) { + if err := c.budget().SpendCondition(); err != nil { + return nil, err + } octx := c.ctx.(*Context) defer func() { c.ctx = octx @@ -436,6 +450,9 @@ func (c *compiler) evalHashLiteral(node *ast.HashLiteral) (interface{}, error) { } func (c *compiler) evalLetStatement(node *ast.LetStatement) (interface{}, error) { + if err := c.budget().SpendAssignment(); err != nil { + return nil, err + } v, err := c.evalExpression(node.Value) if err != nil { return nil, err @@ -447,6 +464,9 @@ func (c *compiler) evalLetStatement(node *ast.LetStatement) (interface{}, error) func (c *compiler) evalIdentifier(node *ast.Identifier) (interface{}, error) { if node.Callee != nil { + if err := c.budget().SpendObjectTraversal(1); err != nil { + return nil, err + } c, err := c.evalExpression(node.Callee) if err != nil { return nil, err @@ -703,6 +723,13 @@ func (c *compiler) stringsOperator(l string, r interface{}, op string) (interfac } func (c *compiler) evalCallExpression(node *ast.CallExpression) (interface{}, error) { + funcName := node.Function.String() + if i, ok := node.Function.(*ast.Identifier); ok { + funcName = i.Value + } + if err := c.budget().SpendFunctionCall(funcName); err != nil { + return nil, err + } var rv reflect.Value if node.Callee != nil { @@ -943,6 +970,9 @@ func (c *compiler) evalForExpression(node *ast.ForExpression) (interface{}, erro case reflect.Map: keys := riter.MapKeys() for i := 0; i < len(keys); i++ { + if err := c.budget().SpendLoop(); err != nil { + return nil, err + } k := keys[i] v := riter.MapIndex(k) c.ctx.Set(node.KeyName, k.Interface()) @@ -972,6 +1002,9 @@ func (c *compiler) evalForExpression(node *ast.ForExpression) (interface{}, erro } case reflect.Slice, reflect.Array: for i := 0; i < riter.Len(); i++ { + if err := c.budget().SpendLoop(); err != nil { + return nil, err + } v := riter.Index(i) c.ctx.Set(node.KeyName, i) c.ctx.Set(node.ValueName, v.Interface()) @@ -1006,6 +1039,9 @@ func (c *compiler) evalForExpression(node *ast.ForExpression) (interface{}, erro i := 0 ii := it.Next() for ii != nil { + if err := c.budget().SpendLoop(); err != nil { + return nil, err + } c.ctx.Set(node.KeyName, i) c.ctx.Set(node.ValueName, ii) diff --git a/context.go b/context.go index a2ac49c..9c3fdcc 100644 --- a/context.go +++ b/context.go @@ -12,9 +12,28 @@ var _ context.Context = &Context{} // Context holds all of the data for the template that is being rendered. type Context struct { context.Context - data *SymbolTable - outer *Context - moot *sync.RWMutex + data *SymbolTable + outer *Context + moot *sync.RWMutex + budget *Budget +} + +// WithBudget attaches a Budget to this context. Returns self for chaining. +func (c *Context) WithBudget(b *Budget) *Context { + c.budget = b + return c +} + +// Budget returns the active budget, walking up the outer chain. +// Returns nil if no budget is set (unlimited). +func (c *Context) Budget() *Budget { + if c.budget != nil { + return c.budget + } + if c.outer != nil { + return c.outer.Budget() + } + return nil } // New context containing the current context. Values set on the new context diff --git a/partial_helper.go b/partial_helper.go index 6bcc4c2..9ac38fb 100644 --- a/partial_helper.go +++ b/partial_helper.go @@ -24,6 +24,12 @@ func PartialHelper(name string, data map[string]interface{}, help HelperContext) return "", fmt.Errorf("invalid context. abort") } + if ctx, ok := help.Context.(*Context); ok { + if err := ctx.Budget().SpendSubRender(); err != nil { + return "", err + } + } + help.Context = help.New() for k, v := range data { help.Set(k, v) diff --git a/plush.go b/plush.go index 1f97409..d3bba03 100644 --- a/plush.go +++ b/plush.go @@ -97,6 +97,22 @@ func Parse(input ...string) (*Template, error) { return t, nil } +// RenderWithBudget renders a template and enforces a work-unit limit. +// Returns ErrBudgetExceeded if the template exhausts the budget. +// Existing Render() is completely unchanged. +func RenderWithBudget(input string, limit int64, ctx *Context) (string, error) { + b := NewBudget(limit) + ctx.WithBudget(b) + return Render(input, ctx) +} + +// RenderWithBudgetConfig renders with a fully custom cost configuration. +func RenderWithBudgetConfig(input string, limit int64, costs BudgetCosts, ctx *Context) (string, error) { + b := NewBudgetWithCosts(limit, costs) + ctx.WithBudget(b) + return Render(input, ctx) +} + func isHole(ctx hctx.Context) bool { if ctx.Value(holeTemplateFileKey) == nil { return false