Commit 56d0f34

colin-ho and samster25 authored
feat: Flotilla stage redesign (#4306)
## Changes Made

Add tentative redesign (and comment for rationale) for stages, subject to change.

## Related Issues

<!-- Link to related GitHub issues, e.g., "Closes #123" -->

## Checklist

- [ ] Documented in API Docs (if applicable)
- [ ] Documented in User Guide (if applicable)
- [ ] If adding a new documentation page, doc is added to `docs/mkdocs.yml` navigation
- [ ] Documentation builds and is formatted properly (tag @/ccmao1130 for docs review)

---------

Co-authored-by: Sammy Sidhu <[email protected]>
1 parent 91cba48 commit 56d0f34

File tree

12 files changed: +367 −150 lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default.

src/daft-distributed/Cargo.toml

Lines changed: 2 additions & 0 deletions
@@ -4,8 +4,10 @@ common-daft-config = {path = "../common/daft-config", default-features = false}
 common-error = {path = "../common/error", default-features = false}
 common-partitioning = {path = "../common/partitioning", default-features = false}
 common-treenode = {path = "../common/treenode", default-features = false}
+daft-dsl = {path = "../daft-dsl", default-features = false}
 daft-local-plan = {path = "../daft-local-plan", default-features = false}
 daft-logical-plan = {path = "../daft-logical-plan", default-features = false}
+daft-schema = {path = "../daft-schema", default-features = false}
 futures = {workspace = true}
 pyo3 = {workspace = true, optional = true}
 pyo3-async-runtimes = {workspace = true, optional = true}

src/daft-distributed/src/pipeline_node/collect.rs

Lines changed: 4 additions & 1 deletion
@@ -11,13 +11,15 @@ use crate::{
     stage::StageContext,
 };

+#[allow(dead_code)]
 pub(crate) struct CollectNode {
     local_physical_plans: Vec<LocalPhysicalPlanRef>,
     children: Vec<Box<dyn DistributedPipelineNode>>,
     input_psets: HashMap<String, Vec<PartitionRef>>,
 }

 impl CollectNode {
+    #[allow(dead_code)]
     pub fn new(
         local_physical_plans: Vec<LocalPhysicalPlanRef>,
         children: Vec<Box<dyn DistributedPipelineNode>>,
@@ -37,14 +39,15 @@ impl CollectNode {
         }
     }

+    #[allow(dead_code)]
     async fn execution_loop(
         _task_dispatcher_handle: TaskDispatcherHandle,
         _local_physical_plans: Vec<LocalPhysicalPlanRef>,
         _psets: HashMap<String, Vec<PartitionRef>>,
         _input_node: Option<RunningPipelineNode>,
         _result_tx: Sender<PipelineOutput>,
     ) -> DaftResult<()> {
-        todo!("Implement collect execution sloop");
+        todo!("FLOTILLA_MS1: Implement collect execution sloop");
     }
 }

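The `execution_loop` here is still a stub. As a rough, self-contained sketch of the pattern the signature implies — drain dispatched work and forward each result until the receiver hangs up — here is an analogue using plain tokio channels in place of the crate's internal `channel` module; `Task` and `Output` are illustrative stand-ins, not Daft's actual types.

```rust
use tokio::sync::mpsc;

#[derive(Debug)]
struct Task(u32);

#[derive(Debug)]
struct Output(u32);

// Drain dispatched tasks, "execute" them, and forward each output downstream,
// stopping early if the consumer hangs up (mirroring `_result_tx`).
async fn execution_loop(mut task_rx: mpsc::Receiver<Task>, result_tx: mpsc::Sender<Output>) {
    while let Some(Task(id)) = task_rx.recv().await {
        // Real code would hand the task to the dispatcher and await its handle.
        if result_tx.send(Output(id * 2)).await.is_err() {
            break; // downstream dropped; nothing left to do
        }
    }
}

#[tokio::main]
async fn main() {
    let (task_tx, task_rx) = mpsc::channel(8);
    let (result_tx, mut result_rx) = mpsc::channel(8);
    tokio::spawn(execution_loop(task_rx, result_tx));
    for i in 0..3 {
        task_tx.send(Task(i)).await.unwrap();
    }
    drop(task_tx); // close the input so the loop terminates
    while let Some(out) = result_rx.recv().await {
        println!("{out:?}");
    }
}
```
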
src/daft-distributed/src/pipeline_node/limit.rs

Lines changed: 3 additions & 1 deletion
@@ -20,6 +20,7 @@ pub(crate) struct LimitNode {
 }

 impl LimitNode {
+    #[allow(dead_code)]
     pub fn new(
         limit: usize,
         local_physical_plans: Vec<LocalPhysicalPlanRef>,
@@ -41,14 +42,15 @@
         }
     }

+    #[allow(dead_code)]
     async fn execution_loop(
         _task_dispatcher_handle: TaskDispatcherHandle,
         _local_physical_plans: Vec<LocalPhysicalPlanRef>,
         _input_node: Option<RunningPipelineNode>,
         _input_psets: HashMap<String, Vec<PartitionRef>>,
         _result_tx: Sender<PipelineOutput>,
     ) -> DaftResult<()> {
-        todo!("Implement limit execution loop");
+        todo!("FLOTILLA_MS1: Implement limit execution loop");
     }
 }

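`LimitNode`'s loop will presumably also have to track how many rows remain before cutting the stream off. A minimal sketch of that semantics — assumed from the node's `limit: usize` field, not taken from this PR — with `Vec<i64>` standing in for a partition:

```rust
use tokio::sync::mpsc;

// Forward partitions downstream, truncating once `remaining` rows have been
// emitted. Dropping `output` at the end closes the stream for the consumer.
async fn limit_loop(
    mut input: mpsc::Receiver<Vec<i64>>,
    output: mpsc::Sender<Vec<i64>>,
    mut remaining: usize,
) {
    while remaining > 0 {
        let Some(part) = input.recv().await else { break };
        let take = part.len().min(remaining);
        remaining -= take;
        if output.send(part[..take].to_vec()).await.is_err() {
            break; // consumer went away
        }
    }
}
```
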
src/daft-distributed/src/pipeline_node/mod.rs

Lines changed: 5 additions & 2 deletions
@@ -30,6 +30,7 @@ pub(crate) trait DistributedPipelineNode: Send + Sync {
     fn name(&self) -> &'static str;
     #[allow(dead_code)]
     fn children(&self) -> Vec<&dyn DistributedPipelineNode>;
+    #[allow(dead_code)]
     fn start(&mut self, stage_context: &mut StageContext) -> RunningPipelineNode;
 }

@@ -39,6 +40,7 @@ pub(crate) struct RunningPipelineNode {
 }

 impl RunningPipelineNode {
+    #[allow(dead_code)]
     fn new(result_receiver: Receiver<PipelineOutput>) -> Self {
         Self { result_receiver }
     }
@@ -53,7 +55,7 @@ impl Stream for RunningPipelineNode {
     type Item = DaftResult<PipelineOutput>;

     fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        todo!("Implement stream for running pipeline node");
+        todo!("FLOTILLA_MS1: Implement stream for running pipeline node");
     }
 }

@@ -64,6 +66,7 @@ pub(crate) enum PipelineOutput {
     Running(Box<dyn SwordfishTaskResultHandle>),
 }

+#[allow(dead_code)]
 pub(crate) fn logical_plan_to_pipeline_node(
     plan: LogicalPlanRef,
     config: Arc<DaftExecutionConfig>,
@@ -100,7 +103,7 @@
                     std::mem::take(&mut self.psets),
                 ))];
                 // Here we will have to return a placeholder, essentially cutting off the plan
-                todo!("Implement pipeline node boundary splitter for limit");
+                todo!("FLOTILLA_MS1: Implement pipeline node boundary splitter for limit");
             }
             _ if is_root => {
                 let input_nodes = std::mem::take(&mut self.current_nodes);

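The `Stream` impl for `RunningPipelineNode` is left as a `todo!`. If the internal `Receiver` wraps a `tokio::sync::mpsc::Receiver` — an assumption about this crate's channel module, not something the PR states — `poll_next` can delegate straight to `poll_recv`, which yields `Ready(None)` once every sender is dropped and so terminates the stream cleanly:

```rust
use std::{
    pin::Pin,
    task::{Context, Poll},
};

use futures::Stream;
use tokio::sync::mpsc;

// Stand-in for RunningPipelineNode; `mpsc::Receiver` is Unpin, so `get_mut`
// on the pinned reference is safe without any pin projection.
struct RunningNode<T> {
    result_receiver: mpsc::Receiver<T>,
}

impl<T> Stream for RunningNode<T> {
    type Item = T;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.get_mut().result_receiver.poll_recv(cx)
    }
}
```
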
src/daft-distributed/src/pipeline_node/translate.rs

Lines changed: 1 addition & 1 deletion
@@ -8,5 +8,5 @@ pub(crate) fn translate_pipeline_plan_to_local_physical_plans(
     _logical_plan: LogicalPlanRef,
     _execution_config: &DaftExecutionConfig,
 ) -> DaftResult<Vec<LocalPhysicalPlanRef>> {
-    todo!("Implement translate pipeline plan to local physical plans");
+    todo!("FLOTILLA_MS1: Implement translate pipeline plan to local physical plans");
 }

src/daft-distributed/src/plan/mod.rs

Lines changed: 20 additions & 53 deletions
@@ -6,20 +6,21 @@ use std::{
 };

 use common_daft_config::DaftExecutionConfig;
-use common_error::{DaftError, DaftResult};
+use common_error::DaftResult;
 use common_partitioning::PartitionRef;
 use daft_logical_plan::{LogicalPlanBuilder, LogicalPlanRef};
-use futures::{Stream, StreamExt};
+use futures::Stream;

 use crate::{
-    channel::{create_channel, Receiver, Sender},
+    channel::{create_channel, Receiver},
     runtime::{get_or_init_runtime, JoinHandle},
     scheduling::worker::WorkerManagerFactory,
-    stage::split_at_stage_boundary,
+    stage::StagePlan,
 };

 pub struct DistributedPhysicalPlan {
-    remaining_logical_plan: Option<LogicalPlanRef>,
+    #[allow(dead_code)]
+    logical_plan: LogicalPlanRef,
     config: Arc<DaftExecutionConfig>,
 }

@@ -29,74 +30,40 @@ impl DistributedPhysicalPlan {
         config: Arc<DaftExecutionConfig>,
     ) -> DaftResult<Self> {
         let plan = builder.build();
-        if !can_translate_logical_plan(&plan) {
-            return Err(DaftError::InternalError(
-                "Cannot run this physical plan on distributed swordfish yet".to_string(),
-            ));
-        }

         Ok(Self {
-            remaining_logical_plan: Some(plan),
+            logical_plan: plan,
             config,
         })
     }

-    async fn run_plan_loop(
-        logical_plan: LogicalPlanRef,
-        config: Arc<DaftExecutionConfig>,
-        worker_manager_factory: Box<dyn WorkerManagerFactory>,
-        psets: HashMap<String, Vec<PartitionRef>>,
-        result_sender: Sender<PartitionRef>,
+    async fn execute_stages(
+        _stage_plan: StagePlan,
+        _psets: HashMap<String, Vec<PartitionRef>>,
+        _worker_manager_factory: Box<dyn WorkerManagerFactory>,
     ) -> DaftResult<()> {
-        let (stage, _remaining_plan) = split_at_stage_boundary(&logical_plan, &config)?;
-        let mut running_stage = stage.run_stage(psets, worker_manager_factory)?;
-        while let Some(result) = running_stage.next().await {
-            if result_sender.send(result?).await.is_err() {
-                break;
-            }
-        }
-        todo!("Implement stage running loop");
-    }
-
-    #[allow(dead_code)]
-    fn update_plan(
-        _plan: LogicalPlanRef,
-        _results: Vec<PartitionRef>,
-    ) -> DaftResult<LogicalPlanRef> {
-        // Update the logical plan with the results of the previous stage.
-        // This is where the AQE magic happens.
-        todo!("Implement plan updating and AQE");
+        todo!("FLOTILLA_MS1: Implement execute stages");
     }

     pub fn run_plan(
         &self,
         psets: HashMap<String, Vec<PartitionRef>>,
         worker_manager_factory: Box<dyn WorkerManagerFactory>,
-    ) -> PlanResult {
-        let (result_sender, result_receiver) = create_channel(1);
+    ) -> DaftResult<PlanResult> {
+        let (_result_sender, result_receiver) = create_channel(1);
         let runtime = get_or_init_runtime();
-        let handle = runtime.spawn(Self::run_plan_loop(
-            self.remaining_logical_plan
-                .as_ref()
-                .expect("Expected remaining logical plan")
-                .clone(),
-            self.config.clone(),
-            worker_manager_factory,
-            psets,
-            result_sender,
-        ));
-        PlanResult::new(handle, result_receiver)
+        let stage_plan = StagePlan::from_logical_plan(self.logical_plan.clone())?;
+        let handle = runtime.spawn(async move {
+            Self::execute_stages(stage_plan, psets, worker_manager_factory).await
+        });
+        Ok(PlanResult::new(handle, result_receiver))
     }

     pub fn execution_config(&self) -> &Arc<DaftExecutionConfig> {
         &self.config
     }
 }

-fn can_translate_logical_plan(_plan: &LogicalPlanRef) -> bool {
-    todo!("Implement logical plan translation check");
-}
-
 // This is the output of a plan, a receiver to receive the results of the plan.
 // And the join handle to the task that runs the plan.
 pub struct PlanResult {
@@ -117,6 +84,6 @@ impl Stream for PlanResult {
     type Item = DaftResult<PartitionRef>;

     fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        todo!("Implement stream for plan result");
+        todo!("FLOTILLA_MS1: Implement stream for plan result");
     }
 }

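`execute_stages` carries the redesign's main idea — build a `StagePlan` up front, then drive stages one after another instead of re-splitting the logical plan before each stage — but is still a `todo!`. One plausible shape of that loop, with `Stage` and `Partition` as illustrative stand-ins for the crate's real types:

```rust
struct Partition(Vec<i64>);

struct Stage {
    name: &'static str,
}

impl Stage {
    // Stand-in for dispatching a stage's tasks to workers and materializing
    // the results.
    async fn run(&self, input: Vec<Partition>) -> Vec<Partition> {
        println!("running stage {} on {} partitions", self.name, input.len());
        input
    }
}

// Each stage consumes the previous stage's output; the final stage's output
// would be what flows through the plan's result channel.
async fn execute_stages(stages: Vec<Stage>, mut psets: Vec<Partition>) -> Vec<Partition> {
    for stage in stages {
        psets = stage.run(psets).await;
    }
    psets
}
```
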
src/daft-distributed/src/python/mod.rs

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ impl PyDistributedPhysicalPlan {
         );
         let part_stream = self
             .planner
-            .run_plan(psets, Box::new(worker_manager_factory));
+            .run_plan(psets, Box::new(worker_manager_factory))?;
         let part_stream = PythonPartitionRefStream {
             inner: Arc::new(Mutex::new(part_stream)),
         };

src/daft-distributed/src/python/ray/worker_manager.rs

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,8 @@ pub(crate) struct RayWorkerManager {
     task_locals: pyo3_async_runtimes::TaskLocals,
 }

+// TODO(FLOTILLA_MS1): Make Ray worker manager live for the duration of the program
+// so that we don't have to recreate it on every stage.
 impl RayWorkerManager {
     pub fn try_new(
         daft_execution_config: Arc<DaftExecutionConfig>,

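The new TODO wants a worker manager that outlives individual stages. A minimal sketch of one way to get that — a process-wide instance behind `std::sync::OnceLock`; the approach is an assumption, not what this commit implements:

```rust
use std::sync::{Arc, OnceLock};

struct WorkerManager; // stand-in for RayWorkerManager

// Created once on first use, then shared by every subsequent stage, so the
// underlying Ray actors would not be torn down and recreated per stage.
fn global_worker_manager() -> Arc<WorkerManager> {
    static INSTANCE: OnceLock<Arc<WorkerManager>> = OnceLock::new();
    INSTANCE.get_or_init(|| Arc::new(WorkerManager)).clone()
}
```
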
src/daft-distributed/src/scheduling/dispatcher.rs

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ impl TaskDispatcher {
         _dispatcher: Self,
         _task_rx: Receiver<DispatchedTask>,
     ) -> DaftResult<()> {
-        todo!("Implement run dispatch loop");
+        todo!("FLOTILLA_MS1: Implement run dispatch loop");
     }
 }

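The dispatch loop is likewise stubbed out. A plausible shape — again an assumption rather than the PR's design — drains queued tasks while capping in-flight work with a `tokio::task::JoinSet`:

```rust
use tokio::{sync::mpsc, task::JoinSet};

const MAX_IN_FLIGHT: usize = 4;

async fn run_dispatch_loop(mut task_rx: mpsc::Receiver<u32>) {
    let mut in_flight: JoinSet<u32> = JoinSet::new();
    while let Some(task) = task_rx.recv().await {
        // Backpressure: wait for a slot before dispatching the next task.
        while in_flight.len() >= MAX_IN_FLIGHT {
            let _ = in_flight.join_next().await;
        }
        // Stand-in for handing the task to a worker and tracking its handle.
        in_flight.spawn(async move { task * 2 });
    }
    // Input closed: drain whatever is still running.
    while in_flight.join_next().await.is_some() {}
}
```
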