From c2070db490896f7eba695f64e68df193983b597d Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Fri, 2 May 2025 23:51:45 -0400 Subject: [PATCH] Add periodic checkpointing during progress reporting. --- .../pkg/beam/runners/prism/internal/stage.go | 121 +++++++++++------- 1 file changed, 74 insertions(+), 47 deletions(-) diff --git a/sdks/go/pkg/beam/runners/prism/internal/stage.go b/sdks/go/pkg/beam/runners/prism/internal/stage.go index a877a887ac1a..a1d84faf1498 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/stage.go +++ b/sdks/go/pkg/beam/runners/prism/internal/stage.go @@ -91,10 +91,12 @@ type stage struct { baseProgTick atomic.Value // time.Duration } -// The minimum and maximum durations between each ProgressBundleRequest and split evaluation. const ( + // The minimum and maximum durations between each ProgressBundleRequest and split evaluation. minimumProgTick = 100 * time.Millisecond maximumProgTick = 30 * time.Second + // The number of ticks before triggering a checkpoint + checkpointTickCutoff = 10 ) func clampTick(dur time.Duration) time.Duration { @@ -177,6 +179,7 @@ func (s *stage) Execute(ctx context.Context, j *jobservices.Job, wk *worker.W, c unsplit := true baseTick := s.baseProgTick.Load().(time.Duration) ticked := false + checkpointTickCount := 0 progTick := time.NewTicker(baseTick) defer progTick.Stop() var dataFinished, bundleFinished bool @@ -186,6 +189,7 @@ func (s *stage) Execute(ctx context.Context, j *jobservices.Job, wk *worker.W, c dataFinished = true } var resp *fnpb.ProcessBundleResponse + var residualRoots []*fnpb.DelayedBundleApplication progress: for { select { @@ -206,6 +210,7 @@ progress: } case <-progTick.C: ticked = true + checkpointTickCount += 1 resp, err := b.Progress(ctx, wk) if err != nil { slog.Debug("SDK Error from progress, aborting progress", "bundle", rb, "error", err.Error()) @@ -218,58 +223,75 @@ progress: } slog.Debug("progress report", "bundle", rb, "index", index, "prevIndex", previousIndex) + var fraction float64 + // Check if there has been any measurable progress by the input, or all output pcollections since last report. slow := previousIndex == index["index"] && previousTotalCount == index["totalCount"] + checkpointReady := checkpointTickCount >= checkpointTickCutoff if slow && unsplit { - slog.Debug("splitting report", "bundle", rb, "index", index) - sr, err := b.Split(ctx, wk, 0.5 /* fraction of remainder */, nil /* allowed splits */) - if err != nil { - slog.Warn("SDK Error from split, aborting splits", "bundle", rb, "error", err.Error()) - break progress - } - if sr.GetChannelSplits() == nil { - slog.Debug("SDK returned no splits", "bundle", rb) - unsplit = false - continue progress - } + fraction = 0.5 + } else if checkpointReady && unsplit { + // splitting on 0.0 fraction to make a checkpoint + fraction = 0.0 + // reset tickCount after scheduling a checkpoint + checkpointTickCount = 0 + } else { + previousIndex = index["index"] + previousTotalCount = index["totalCount"] + continue progress + } - // TODO sort out rescheduling primary Roots on bundle failure. - var residuals []engine.Residual - for _, rr := range sr.GetResidualRoots() { - ba := rr.GetApplication() - residuals = append(residuals, engine.Residual{Element: ba.GetElement()}) - if len(ba.GetElement()) == 0 { - slog.LogAttrs(context.TODO(), slog.LevelError, "returned empty residual application", slog.Any("bundle", rb)) - panic("sdk returned empty residual application") - } - // TODO what happens to output watermarks on splits? - } - if len(sr.GetChannelSplits()) != 1 { - slog.Warn("received non-single channel split", "bundle", rb) - } - cs := sr.GetChannelSplits()[0] - fr := cs.GetFirstResidualElement() - // The first residual can be after the end of data, so filter out those cases. - if b.EstimatedInputElements >= int(fr) { - b.EstimatedInputElements = int(fr) // Update the estimate for the next split. - // Split Residuals are returned right away for rescheduling. - em.ReturnResiduals(rb, int(fr), s.inputInfo, engine.Residuals{ - Data: residuals, - }) + // Do the split (fraction > 0) or checkpoint (fraction == 0) + slog.Debug("splitting report", "bundle", rb, "index", index) + sr, err := b.Split(ctx, wk, fraction /* fraction of remainder */, nil /* allowed splits */) + if err != nil { + slog.Warn("SDK Error from split, aborting splits", "bundle", rb, "error", err.Error()) + break progress + } + if sr.GetChannelSplits() == nil { + slog.Debug("SDK returned no splits", "bundle", rb) + unsplit = false + continue progress + } + // Save residual roots for checkpoint. After checkpointing is successful, + // the bundle will be marked as finished and no residual roots will be + // returned in ProcessBundleResponse. + if fraction == 0 { + residualRoots = sr.GetResidualRoots() + } + // TODO sort out rescheduling primary Roots on bundle failure. + var residuals []engine.Residual + for _, rr := range sr.GetResidualRoots() { + ba := rr.GetApplication() + residuals = append(residuals, engine.Residual{Element: ba.GetElement()}) + if len(ba.GetElement()) == 0 { + slog.LogAttrs(context.TODO(), slog.LevelError, "returned empty residual application", slog.Any("bundle", rb)) + panic("sdk returned empty residual application") } + // TODO what happens to output watermarks on splits? + } + if len(sr.GetChannelSplits()) != 1 { + slog.Warn("received non-single channel split", "bundle", rb) + } + cs := sr.GetChannelSplits()[0] + fr := cs.GetFirstResidualElement() + // The first residual can be after the end of data, so filter out those cases. + if b.EstimatedInputElements >= int(fr) { + b.EstimatedInputElements = int(fr) // Update the estimate for the next split. + // Split Residuals are returned right away for rescheduling. + em.ReturnResiduals(rb, int(fr), s.inputInfo, engine.Residuals{ + Data: residuals, + }) + } - // Any split means we're processing slower than desired, but splitting should increase - // throughput. Back off for this and other bundles for this stage - baseTime := s.baseProgTick.Load().(time.Duration) - newTime := clampTick(baseTime * 4) - if s.baseProgTick.CompareAndSwap(baseTime, newTime) { - progTick.Reset(newTime) - } else { - progTick.Reset(s.baseProgTick.Load().(time.Duration)) - } + // Any split means we're processing slower than desired, but splitting should increase + // throughput. Back off for this and other bundles for this stage + baseTime := s.baseProgTick.Load().(time.Duration) + newTime := clampTick(baseTime * 4) + if s.baseProgTick.CompareAndSwap(baseTime, newTime) { + progTick.Reset(newTime) } else { - previousIndex = index["index"] - previousTotalCount = index["totalCount"] + progTick.Reset(s.baseProgTick.Load().(time.Duration)) } } } @@ -290,11 +312,16 @@ progress: j.AddMetricShortIDs(md) } + // Use residual roots from ProcessBundleResponse if any. + // Otherwise, use residual roots from ProcessBundleSplitResponse if a checkpoint occurs. + if len(resp.GetResidualRoots()) > 0 { + residualRoots = resp.GetResidualRoots() + } // ProcessContinuation residuals are rescheduled after the specified delay. residuals := engine.Residuals{ MinOutputWatermarks: map[string]mtime.Time{}, } - for _, rr := range resp.GetResidualRoots() { + for _, rr := range residualRoots { ba := rr.GetApplication() if len(ba.GetElement()) == 0 { slog.LogAttrs(context.TODO(), slog.LevelError, "returned empty residual application", slog.Any("bundle", rb))