Commit 9f5f002

Microbatch first last batch serial (#11072) (#11107)

* microbatch: split out first and last batch to run in serial
* only run pre_hook on first batch, post_hook on last batch
* refactor: internalize parallel to RunTask._submit_batch
* Add optional `force_sequential` to `_submit_batch` to allow for skipping the parallelism check
* Force last batch to run sequentially
* Force first batch to run sequentially
* Remove batch_idx check in `should_run_in_parallel`

  `should_run_in_parallel` shouldn't, and no longer needs to, take into consideration where the batch exists in the larger run. The first and last batch for a microbatch model are now forced to run sequentially by `handle_microbatch_model`.
* Begin skipping batches if first batch fails
* Write custom `on_skip` for `MicrobatchModelRunner` to better handle when batches are skipped

  This was necessary specifically because the default on_skip set the `X of Y` part of the skipped log using the `node_index` and the `num_nodes`. If there were 2 nodes and we were on the 4th batch of the second node, we'd get a message like `SKIPPED 4 of 2...`, which didn't make much sense. In a future commit we're likely to add a custom event for logging the start, result, and skipping of batches for better readability of the logs.
* Add microbatch pre-hook, post-hook, and sequential first/last batch tests
* Fix/Add tests around first batch failure vs latter batch failure
* Correct `MicrobatchModelRunner.on_skip` to handle skipping the entire node

  Previously `MicrobatchModelRunner.on_skip` only handled when a _batch_ of the model was being skipped. However, that method is also used when the entire microbatch model is being skipped due to an upstream node error. Because we previously _weren't_ handling this second case, it caused an unhandled runtime exception. We now check whether we're running a batch; if there is no batch, we fall back to the super's `on_skip` method.
* Correct conditional logic for setting pre- and post-hooks for batches

  Previously we used an if+elif for setting pre- and post-hooks for batches, wherein the `if` matched when the batch wasn't the first batch, and the `elif` matched when the batch wasn't the last batch. The issue is that if the `if` was hit, the `elif` _wouldn't_ be hit. This caused the first batch to appropriately not run the `post-hook`, but every batch after it would run the `post-hook`.
* Add two new event types `LogStartBatch` and `LogBatchResult`
* Update MicrobatchModelRunner to use new batch specific log events
* Fix event testing
* Update microbatch integration tests to catch batch specific event types

---------

Co-authored-by: Quigley Malcolm <[email protected]>
(cherry picked from commit 03fdb4c)

Co-authored-by: Michelle Ark <[email protected]>
1 parent 4e74e69 commit 9f5f002
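The scheduling pattern the commit message describes — first batch serial (so pre-hooks finish first), middle batches possibly parallel, last batch serial (so post-hooks run at the very end), and remaining batches skipped when the first batch fails — can be sketched as follows. This is a simplified stand-in with hypothetical names (`run_batch`), not dbt's actual `RunTask` implementation:

```python
from concurrent.futures import ThreadPoolExecutor

def run_microbatch(batches, run_batch):
    """Run the first and last batch serially; middle batches may run in parallel.

    `run_batch(idx, is_first, is_last)` returns True on success (a hypothetical,
    simplified contract). If the first batch fails, remaining batches are skipped.
    """
    results = {}
    n = len(batches)
    if n == 0:
        return results

    # First batch runs alone so pre-hooks complete before any parallelism
    results[0] = run_batch(0, True, n == 1)
    skip_rest = not results[0]

    # Middle batches may run concurrently; skipped if the first batch failed
    with ThreadPoolExecutor() as pool:
        futures = {
            i: pool.submit(lambda i=i: False if skip_rest else run_batch(i, False, False))
            for i in range(1, n - 1)
        }
        for i, fut in futures.items():
            results[i] = fut.result()

    # Last batch runs after all others so post-hooks fire at the very end
    if n > 1:
        results[n - 1] = False if skip_rest else run_batch(n - 1, False, True)
    return results
```

The key ordering property: batch 0 completes before any other batch starts, and batch `n - 1` starts only after all middle batches have resolved.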

File tree

8 files changed (+575, −219 lines)

Lines changed: 7 additions & 0 deletions
```diff
@@ -0,0 +1,7 @@
+kind: Features
+body: Ensure pre/post hooks only run on first/last batch respectively for microbatch
+  model batches
+time: 2024-12-06T19:53:08.928793-06:00
+custom:
+  Author: MichelleArk QMalcolm
+  Issue: 11094 11104
```
core/dbt/events/core_types.proto

Lines changed: 29 additions & 0 deletions
```diff
@@ -1690,6 +1690,35 @@ message MicrobatchExecutionDebugMsg {
     MicrobatchExecutionDebug data = 2;
 }

+// Q045
+message LogStartBatch {
+    NodeInfo node_info = 1;
+    string description = 2;
+    int32 batch_index = 3;
+    int32 total_batches = 4;
+}
+
+message LogStartBatchMsg {
+    CoreEventInfo info = 1;
+    LogStartBatch data = 2;
+}
+
+// Q046
+message LogBatchResult {
+    NodeInfo node_info = 1;
+    string description = 2;
+    string status = 3;
+    int32 batch_index = 4;
+    int32 total_batches = 5;
+    float execution_time = 6;
+    Group group = 7;
+}
+
+message LogBatchResultMsg {
+    CoreEventInfo info = 1;
+    LogBatchResult data = 2;
+}
+
 // W - Node testing

 // Skipped W001
```
core/dbt/events/core_types_pb2.py

Lines changed: 179 additions & 171 deletions
Some generated files are not rendered by default.

core/dbt/events/types.py

Lines changed: 45 additions & 0 deletions
```diff
@@ -1710,6 +1710,51 @@ def message(self) -> str:
         return self.msg


+class LogStartBatch(InfoLevel):
+    def code(self) -> str:
+        return "Q045"
+
+    def message(self) -> str:
+        msg = f"START {self.description}"
+
+        # TODO update common so that we can append "batch" in `format_fancy_output_line`
+        formatted = format_fancy_output_line(
+            msg=msg,
+            status="RUN",
+            index=self.batch_index,
+            total=self.total_batches,
+        )
+        return f"Batch {formatted}"
+
+
+class LogBatchResult(DynamicLevel):
+    def code(self) -> str:
+        return "Q046"
+
+    def message(self) -> str:
+        if self.status == "error":
+            info = "ERROR creating"
+            status = red(self.status.upper())
+        elif self.status == "skipped":
+            info = "SKIP"
+            status = yellow(self.status.upper())
+        else:
+            info = "OK created"
+            status = green(self.status)
+
+        msg = f"{info} {self.description}"
+
+        # TODO update common so that we can append "batch" in `format_fancy_output_line`
+        formatted = format_fancy_output_line(
+            msg=msg,
+            status=status,
+            index=self.batch_index,
+            total=self.total_batches,
+            execution_time=self.execution_time,
+        )
+        return f"Batch {formatted}"
+
+
 # =======================================================
 # W - Node testing
 # =======================================================
```
core/dbt/task/run.py

Lines changed: 131 additions & 40 deletions
```diff
@@ -30,9 +30,11 @@
 from dbt.contracts.graph.nodes import BatchContext, HookNode, ModelNode, ResultNode
 from dbt.events.types import (
     GenericExceptionOnRun,
+    LogBatchResult,
     LogHookEndLine,
     LogHookStartLine,
     LogModelResult,
+    LogStartBatch,
     LogStartLine,
     MicrobatchExecutionDebug,
 )
```
```diff
@@ -397,15 +399,18 @@ def print_batch_result_line(
         if result.status == NodeStatus.Error:
             status = result.status
             level = EventLevel.ERROR
+        elif result.status == NodeStatus.Skipped:
+            status = result.status
+            level = EventLevel.INFO
         else:
             status = result.message
             level = EventLevel.INFO
         fire_event(
-            LogModelResult(
+            LogBatchResult(
                 description=description,
                 status=status,
-                index=self.batch_idx + 1,
-                total=len(self.batches),
+                batch_index=self.batch_idx + 1,
+                total_batches=len(self.batches),
                 execution_time=result.execution_time,
                 node_info=self.node.node_info,
                 group=group,
```
```diff
@@ -423,10 +428,10 @@ def print_batch_start_line(self) -> None:

         batch_description = self.describe_batch(batch_start)
         fire_event(
-            LogStartLine(
+            LogStartBatch(
                 description=batch_description,
-                index=self.batch_idx + 1,
-                total=len(self.batches),
+                batch_index=self.batch_idx + 1,
+                total_batches=len(self.batches),
                 node_info=self.node.node_info,
             )
         )
```
```diff
@@ -472,6 +477,25 @@ def merge_batch_results(self, result: RunResult, batch_results: List[RunResult])
         if self.node.previous_batch_results is not None:
             result.batch_results.successful += self.node.previous_batch_results.successful

+    def on_skip(self):
+        # If batch_idx is None, then we're dealing with skipping of the entire node
+        if self.batch_idx is None:
+            return super().on_skip()
+        else:
+            result = RunResult(
+                node=self.node,
+                status=RunStatus.Skipped,
+                timing=[],
+                thread_id=threading.current_thread().name,
+                execution_time=0.0,
+                message="SKIPPED",
+                adapter_response={},
+                failures=1,
+                batch_results=BatchResults(failed=[self.batches[self.batch_idx]]),
+            )
+            self.print_batch_result_line(result=result)
+            return result
+
     def _build_succesful_run_batch_result(
         self,
         model: ModelNode,
```
```diff
@@ -602,13 +626,10 @@ def _has_relation(self, model) -> bool:
         )
         return relation is not None

-    def _should_run_in_parallel(
-        self,
-        relation_exists: bool,
-    ) -> bool:
+    def should_run_in_parallel(self) -> bool:
         if not self.adapter.supports(Capability.MicrobatchConcurrency):
             run_in_parallel = False
-        elif not relation_exists:
+        elif not self.relation_exists:
             # If the relation doesn't exist, we can't run in parallel
             run_in_parallel = False
         elif self.node.config.concurrent_batches is not None:
```
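The checks in `should_run_in_parallel` reduce to three ordered conditions. As a standalone sketch — a hypothetical pure function mirroring the logic, not dbt's API; the diff is truncated before the default case, so the `True` fallback here is an assumption:

```python
def should_run_in_parallel(adapter_supports_concurrency: bool,
                           relation_exists: bool,
                           concurrent_batches_config) -> bool:
    """Ordered parallelism checks for a microbatch batch (simplified sketch).

    concurrent_batches_config is True/False when the user set
    `concurrent_batches` explicitly, or None when unset.
    """
    if not adapter_supports_concurrency:
        # Adapter lacks the MicrobatchConcurrency capability
        return False
    if not relation_exists:
        # If the relation doesn't exist yet, batches must build it serially
        return False
    if concurrent_batches_config is not None:
        # An explicit user setting wins
        return bool(concurrent_batches_config)
    # Assumed default when unset (the truncated hunk does not show this branch)
    return True
```

Note what is gone: no `batch_idx` check. First/last-batch serialization is now enforced by the caller via `force_sequential_run`, so this predicate only answers "is parallelism safe at all?".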
```diff
@@ -703,52 +724,122 @@ def handle_microbatch_model(
         runner: MicrobatchModelRunner,
         pool: ThreadPool,
     ) -> RunResult:
-        # Initial run computes batch metadata, unless model is skipped
+        # Initial run computes batch metadata
         result = self.call_runner(runner)
+        batches, node, relation_exists = runner.batches, runner.node, runner.relation_exists
+
+        # Return early if model should be skipped, or there are no batches to execute
         if result.status == RunStatus.Skipped:
             return result
+        elif len(runner.batches) == 0:
+            return result

         batch_results: List[RunResult] = []
-
-        # Execute batches serially until a relation exists, at which point future batches are run in parallel
-        relation_exists = runner.relation_exists
         batch_idx = 0
-        while batch_idx < len(runner.batches):
-            batch_runner = MicrobatchModelRunner(
-                self.config, runner.adapter, deepcopy(runner.node), self.run_count, self.num_nodes
-            )
-            batch_runner.set_batch_idx(batch_idx)
-            batch_runner.set_relation_exists(relation_exists)
-            batch_runner.set_batches(runner.batches)
-
-            if runner._should_run_in_parallel(relation_exists):
-                fire_event(
-                    MicrobatchExecutionDebug(
-                        msg=f"{batch_runner.describe_batch} is being run concurrently"
-                    )
-                )
-                self._submit(pool, [batch_runner], batch_results.append)
-            else:
-                fire_event(
-                    MicrobatchExecutionDebug(
-                        msg=f"{batch_runner.describe_batch} is being run sequentially"
-                    )
-                )
-                batch_results.append(self.call_runner(batch_runner))
-                relation_exists = batch_runner.relation_exists

+        # Run first batch not in parallel
+        relation_exists = self._submit_batch(
+            node=node,
+            adapter=runner.adapter,
+            relation_exists=relation_exists,
+            batches=batches,
+            batch_idx=batch_idx,
+            batch_results=batch_results,
+            pool=pool,
+            force_sequential_run=True,
+        )
+        batch_idx += 1
+        skip_batches = batch_results[0].status != RunStatus.Success
+
+        # Run all batches except first and last batch, in parallel if possible
+        while batch_idx < len(runner.batches) - 1:
+            relation_exists = self._submit_batch(
+                node=node,
+                adapter=runner.adapter,
+                relation_exists=relation_exists,
+                batches=batches,
+                batch_idx=batch_idx,
+                batch_results=batch_results,
+                pool=pool,
+                skip=skip_batches,
+            )
             batch_idx += 1

-        # Wait until all batches have completed
-        while len(batch_results) != len(runner.batches):
+        # Wait until all submitted batches have completed
+        while len(batch_results) != batch_idx:
             pass
+        # Final batch runs once all others complete to ensure post_hook runs at the end
+        self._submit_batch(
+            node=node,
+            adapter=runner.adapter,
+            relation_exists=relation_exists,
+            batches=batches,
+            batch_idx=batch_idx,
+            batch_results=batch_results,
+            pool=pool,
+            force_sequential_run=True,
+            skip=skip_batches,
+        )

+        # Finalize run: merge results, track model run, and print final result line
         runner.merge_batch_results(result, batch_results)
         track_model_run(runner.node_index, runner.num_nodes, result, adapter=runner.adapter)
         runner.print_result_line(result)

         return result

+    def _submit_batch(
+        self,
+        node: ModelNode,
+        adapter: BaseAdapter,
+        relation_exists: bool,
+        batches: Dict[int, BatchType],
+        batch_idx: int,
+        batch_results: List[RunResult],
+        pool: ThreadPool,
+        force_sequential_run: bool = False,
+        skip: bool = False,
+    ):
+        node_copy = deepcopy(node)
+        # Only run pre_hook(s) for first batch
+        if batch_idx != 0:
+            node_copy.config.pre_hook = []
+
+        # Only run post_hook(s) for last batch
+        if batch_idx != len(batches) - 1:
+            node_copy.config.post_hook = []
+
+        # TODO: We should be doing self.get_runner, however doing so
+        # currently causes the tracking of how many nodes there are to
+        # increment when we don't want it to
+        batch_runner = MicrobatchModelRunner(
+            self.config, adapter, node_copy, self.run_count, self.num_nodes
+        )
+        batch_runner.set_batch_idx(batch_idx)
+        batch_runner.set_relation_exists(relation_exists)
+        batch_runner.set_batches(batches)
+
+        if skip:
+            batch_runner.do_skip()
+
+        if not force_sequential_run and batch_runner.should_run_in_parallel():
+            fire_event(
+                MicrobatchExecutionDebug(
+                    msg=f"{batch_runner.describe_batch} is being run concurrently"
+                )
+            )
+            self._submit(pool, [batch_runner], batch_results.append)
+        else:
+            fire_event(
+                MicrobatchExecutionDebug(
+                    msg=f"{batch_runner.describe_batch} is being run sequentially"
+                )
+            )
+            batch_results.append(self.call_runner(batch_runner))
+            relation_exists = batch_runner.relation_exists
+
+        return relation_exists
+
     def _hook_keyfunc(self, hook: HookNode) -> Tuple[str, Optional[int]]:
         package_name = hook.package_name
         if package_name == self.config.project_name:
```