
Commit 302fc17

feat: Flotilla scheduler and dispatcher actors (#4375)
## Changes Made

Implement scheduler and dispatcher actors for Flotilla, including unit tests for both.

## Related Issues

<!-- Link to related GitHub issues, e.g., "Closes #123" -->

## Checklist

- [ ] Documented in API Docs (if applicable)
- [ ] Documented in User Guide (if applicable)
- [ ] If adding a new documentation page, doc is added to `docs/mkdocs.yml` navigation
- [ ] Documentation builds and is formatted properly (tag @/ccmao1130 for docs review)
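For readers skimming the diffs below, here is a minimal, illustrative sketch of the "actor" shape the title refers to: each actor owns a channel receiver and runs its own loop on a tokio task. This is not the code added in `src/daft-distributed` (the scheduler and dispatcher modules themselves are not shown in this excerpt); the `Task` type, channel sizes, and forwarding logic are placeholders.

```rust
// Illustrative only: a generic scheduler -> dispatcher actor pipeline over
// tokio mpsc channels. Not the PR's implementation.
use tokio::sync::mpsc;

#[derive(Debug)]
struct Task(u32);

// The scheduler actor owns its inbox and forwards work to the dispatcher.
async fn scheduler_actor(mut inbox: mpsc::Receiver<Task>, to_dispatcher: mpsc::Sender<Task>) {
    while let Some(task) = inbox.recv().await {
        // A real scheduler would decide placement here; this sketch just forwards.
        if to_dispatcher.send(task).await.is_err() {
            break; // dispatcher has shut down
        }
    }
}

// The dispatcher actor drains its inbox and "executes" each task.
async fn dispatcher_actor(mut inbox: mpsc::Receiver<Task>) {
    while let Some(task) = inbox.recv().await {
        println!("dispatching {task:?}");
    }
}

#[tokio::main]
async fn main() {
    let (sched_tx, sched_rx) = mpsc::channel(16);
    let (disp_tx, disp_rx) = mpsc::channel(16);
    let scheduler = tokio::spawn(scheduler_actor(sched_rx, disp_tx));
    let dispatcher = tokio::spawn(dispatcher_actor(disp_rx));

    for i in 0..3 {
        sched_tx.send(Task(i)).await.unwrap();
    }
    drop(sched_tx); // closing the inbox lets both actors exit

    scheduler.await.unwrap();
    dispatcher.await.unwrap();
}
```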
1 parent 6d09e9e commit 302fc17

File tree

11 files changed (+918, -194 lines)


Cargo.lock

Lines changed: 1 addition & 0 deletions
Generated lockfile; diff not rendered.

src/daft-distributed/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ rand = {workspace = true}
 tokio = {workspace = true}
 tokio-stream = {workspace = true}
 tokio-util = {workspace = true}
+tracing = {workspace = true}
 uuid.workspace = true
 
 [features]

src/daft-distributed/src/python/ray/task.rs

Lines changed: 44 additions & 36 deletions
@@ -1,68 +1,76 @@
-use std::{any::Any, collections::HashMap, sync::Arc};
+use std::{any::Any, collections::HashMap, future::Future, sync::Arc};
 
 use common_error::DaftResult;
 use common_partitioning::{Partition, PartitionRef};
 use daft_local_plan::PyLocalPhysicalPlan;
 use pyo3::{pyclass, pymethods, FromPyObject, PyObject, PyResult, Python};
 
-use crate::scheduling::task::{SwordfishTask, TaskResultHandle};
+use crate::scheduling::{
+    task::{SwordfishTask, TaskResultHandle},
+    worker::WorkerId,
+};
 
 /// TaskHandle that wraps a Python RaySwordfishTaskHandle
 #[allow(dead_code)]
 pub(crate) struct RayTaskResultHandle {
-    /// The handle to the task
-    handle: Option<PyObject>,
+    /// The handle to the RaySwordfishTaskHandle
+    handle: PyObject,
+    /// The coroutine to await the result of the task
+    coroutine: Option<PyObject>,
     /// The task locals, i.e. the asyncio event loop
    task_locals: Option<pyo3_async_runtimes::TaskLocals>,
+    /// The worker id
+    worker_id: WorkerId,
 }
 
 impl RayTaskResultHandle {
     /// Create a new TaskHandle from a Python RaySwordfishTaskHandle
     #[allow(dead_code)]
-    pub fn new(handle: PyObject, task_locals: pyo3_async_runtimes::TaskLocals) -> Self {
+    pub fn new(
+        handle: PyObject,
+        coroutine: PyObject,
+        task_locals: pyo3_async_runtimes::TaskLocals,
+        worker_id: WorkerId,
+    ) -> Self {
         Self {
-            handle: Some(handle),
+            handle,
+            coroutine: Some(coroutine),
             task_locals: Some(task_locals),
+            worker_id,
         }
     }
 }
 
 impl TaskResultHandle for RayTaskResultHandle {
     /// Get the result of the task, awaiting if necessary
-    async fn get_result(&mut self) -> DaftResult<PartitionRef> {
-        let handle = self
-            .handle
-            .take()
-            .expect("Task handle should be present during get_result");
-        let coroutine = Python::with_gil(|py| {
-            let coroutine = handle
-                .call_method0(py, pyo3::intern!(py, "get_result"))?
-                .into_bound(py);
-            pyo3_async_runtimes::tokio::into_future(coroutine)
-        })?;
+    fn get_result(&mut self) -> impl Future<Output = DaftResult<PartitionRef>> + Send + 'static {
+        // Create a rust future that will await the coroutine
+        let coroutine = self.coroutine.take().unwrap();
+        let task_locals = self.task_locals.take().unwrap();
 
-        // await the rust future in the scope of the asyncio event loop
-        let task_locals = self
-            .task_locals
-            .take()
-            .expect("Task locals should be present during get_result");
-        let materialized_result = pyo3_async_runtimes::tokio::scope(task_locals, coroutine).await?;
+        let await_coroutine = async move {
+            let result = Python::with_gil(|py| {
+                pyo3_async_runtimes::tokio::into_future(coroutine.into_bound(py))
+            })?
+            .await?;
+            DaftResult::Ok(result)
+        };
 
-        let ray_part_ref =
-            Python::with_gil(|py| materialized_result.extract::<RayPartitionRef>(py))?;
-        Ok(Arc::new(ray_part_ref))
+        async move {
+            let materialized_result =
+                pyo3_async_runtimes::tokio::scope(task_locals, await_coroutine).await?;
+            let ray_part_ref =
+                Python::with_gil(|py| materialized_result.extract::<RayPartitionRef>(py))?;
+            let partition_ref = Arc::new(ray_part_ref) as PartitionRef;
+            Ok(partition_ref)
+        }
     }
-}
 
-impl Drop for RayTaskResultHandle {
-    fn drop(&mut self) {
-        if let Some(handle) = self.handle.take() {
-            Python::with_gil(|py| {
-                handle
-                    .call_method0(py, "cancel")
-                    .expect("Failed to cancel ray task")
-            });
-        }
+    fn cancel_callback(&mut self) -> DaftResult<()> {
+        Python::with_gil(|py| {
+            self.handle.call_method0(py, "cancel")?;
+            Ok(())
+        })
    }
 }
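The key change in this file is that `get_result` is no longer an `async fn` whose future borrows `&mut self`; it takes its one-shot state (`coroutine`, `task_locals`) out of the handle and returns an owned `impl Future + Send + 'static`. A minimal standalone sketch of that pattern follows; `ResultHandle` and its `String` payload are hypothetical stand-ins for `RayTaskResultHandle` and its Python coroutine, not Daft's code.

```rust
// Minimal sketch: by taking the one-shot state out of `&mut self`, the method
// can return an owned `impl Future + Send + 'static` instead of a future that
// borrows the handle.
use std::future::Future;

struct ResultHandle {
    payload: Option<String>, // consumed on the first call to get_result
}

impl ResultHandle {
    fn get_result(&mut self) -> impl Future<Output = String> + Send + 'static {
        let payload = self.payload.take().expect("get_result called twice");
        // The returned future owns `payload`, so it does not borrow `self`.
        async move { payload }
    }
}

#[tokio::main]
async fn main() {
    let mut handle = ResultHandle {
        payload: Some("partition".to_string()),
    };
    let fut = handle.get_result();
    drop(handle); // the future is independent of the handle's lifetime
    println!("{}", fut.await);
}
```

Because the returned future owns everything it needs, a caller can spawn or await it after the borrow of the handle has ended, which is what allows the awaiter machinery used in the next file to drive many task results concurrently.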

src/daft-distributed/src/python/ray/worker.rs

Lines changed: 39 additions & 33 deletions
@@ -8,7 +8,8 @@ use pyo3::prelude::*;
 
 use super::{task::RayTaskResultHandle, RaySwordfishTask};
 use crate::scheduling::{
-    task::{SwordfishTask, TaskDetails, TaskId},
+    scheduler::SchedulableTask,
+    task::{SwordfishTask, Task, TaskDetails, TaskId, TaskResultHandleAwaiter},
     worker::{Worker, WorkerId},
 };
 
@@ -44,49 +45,54 @@ impl RaySwordfishWorker {
 
 #[allow(dead_code)]
 impl RaySwordfishWorker {
-    pub fn mark_task_finished(&self, task_id: TaskId) {
+    pub fn mark_task_finished(&self, task_id: &TaskId) {
         self.active_task_details
             .lock()
             .expect("Active task ids should be present")
-            .remove(&task_id);
+            .remove(task_id);
     }
 
     pub fn submit_tasks(
         &self,
-        tasks: Vec<SwordfishTask>,
+        tasks: Vec<SchedulableTask<SwordfishTask>>,
         py: Python<'_>,
         task_locals: &pyo3_async_runtimes::TaskLocals,
-    ) -> DaftResult<Vec<RayTaskResultHandle>> {
-        let (task_details, ray_swordfish_tasks) = tasks
-            .into_iter()
-            .map(|task| (TaskDetails::from(&task), RaySwordfishTask::new(task)))
-            .unzip::<_, _, Vec<_>, Vec<_>>();
-
-        let py_task_handles = self
-            .actor_handle
-            .call_method1(
+    ) -> DaftResult<Vec<TaskResultHandleAwaiter<RayTaskResultHandle>>> {
+        let mut task_handles = Vec::with_capacity(tasks.len());
+        for task in tasks {
+            let (task, result_tx, cancel_token) = task.into_inner();
+            let task_id = task.task_id().clone();
+            let task_details = TaskDetails::from(&task);
+
+            let ray_swordfish_task = RaySwordfishTask::new(task);
+            let py_task_handle = self.actor_handle.call_method1(
                 py,
-                pyo3::intern!(py, "submit_tasks"),
-                (ray_swordfish_tasks,),
-            )?
-            .extract::<Vec<PyObject>>(py)?;
-
-        let task_handles = py_task_handles
-            .into_iter()
-            .map(|py_task_handle| {
-                let task_locals = task_locals.clone_ref(py);
-                RayTaskResultHandle::new(py_task_handle, task_locals)
-            })
-            .collect::<Vec<_>>();
-
-        self.active_task_details
-            .lock()
-            .expect("Active task ids should be present")
-            .extend(
-                task_details
-                    .into_iter()
-                    .map(|details| (details.id.clone(), details)),
+                pyo3::intern!(py, "submit_task"),
+                (ray_swordfish_task,),
+            )?;
+            let coroutine = py_task_handle.call_method0(py, pyo3::intern!(py, "get_result"))?;
+
+            self.active_task_details
+                .lock()
+                .expect("Active task details should be present")
+                .insert(task_id.clone(), task_details);
+
+            let task_locals = task_locals.clone_ref(py);
+            let ray_task_result_handle = RayTaskResultHandle::new(
+                py_task_handle,
+                coroutine,
+                task_locals,
+                self.worker_id.clone(),
+            );
+            let task_result_handle_awaiter = TaskResultHandleAwaiter::new(
+                task_id,
+                self.worker_id.clone(),
+                ray_task_result_handle,
+                result_tx,
+                cancel_token,
             );
+            task_handles.push(task_result_handle_awaiter);
+        }
 
         Ok(task_handles)
     }

src/daft-distributed/src/python/ray/worker_manager.rs

Lines changed: 8 additions & 7 deletions
@@ -5,7 +5,8 @@ use pyo3::prelude::*;
 
 use super::{task::RayTaskResultHandle, worker::RaySwordfishWorker};
 use crate::scheduling::{
-    task::{SwordfishTask, TaskId},
+    scheduler::SchedulableTask,
+    task::{SwordfishTask, TaskId, TaskResultHandleAwaiter},
     worker::{Worker, WorkerId, WorkerManager},
 };
 
@@ -43,11 +44,11 @@ impl WorkerManager for RayWorkerManager {
 
     fn submit_tasks_to_workers(
         &self,
-        total_tasks: usize,
-        tasks_per_worker: HashMap<WorkerId, Vec<SwordfishTask>>,
-    ) -> DaftResult<Vec<RayTaskResultHandle>> {
+        tasks_per_worker: HashMap<WorkerId, Vec<SchedulableTask<SwordfishTask>>>,
+    ) -> DaftResult<Vec<TaskResultHandleAwaiter<RayTaskResultHandle>>> {
         Python::with_gil(|py| {
-            let mut task_result_handles = Vec::with_capacity(total_tasks);
+            let mut task_result_handles =
+                Vec::with_capacity(tasks_per_worker.values().map(|v| v.len()).sum());
             for (worker_id, tasks) in tasks_per_worker {
                 let handles = self
                     .ray_workers
@@ -64,9 +65,9 @@ impl WorkerManager for RayWorkerManager {
         &self.ray_workers
     }
 
-    fn mark_task_finished(&self, task_id: TaskId, worker_id: WorkerId) {
+    fn mark_task_finished(&self, task_id: &TaskId, worker_id: &WorkerId) {
         self.ray_workers
-            .get(&worker_id)
+            .get(worker_id)
             .expect("Worker should be present in RayWorkerManager")
             .mark_task_finished(task_id);
     }
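One small change above: `submit_tasks_to_workers` no longer takes a separate `total_tasks` count; the capacity of the result vector is derived from `tasks_per_worker` itself. A tiny self-contained example of the same idiom, using toy types rather than Daft's:

```rust
// Toy example of pre-computing capacity from a per-worker task map.
use std::collections::HashMap;

fn main() {
    let tasks_per_worker: HashMap<&str, Vec<u32>> = HashMap::from([
        ("worker-a", vec![1, 2]),
        ("worker-b", vec![3]),
    ]);

    // Same idiom as the diff: sum the per-worker task counts up front so the
    // handles vector never reallocates while it is filled.
    let total: usize = tasks_per_worker.values().map(|v| v.len()).sum();
    let mut handles: Vec<u32> = Vec::with_capacity(total);
    for (_worker_id, tasks) in tasks_per_worker {
        handles.extend(tasks);
    }
    assert_eq!(handles.len(), total);
}
```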
