Skip to content

Commit faf050c

Browse files
colin-hoColin Ho
andauthored
[FEAT] Respect resource request for projections in swordfish (Eventual-Inc#3460)
Make swordfish respect resource requests (CPU only for now) on projections. --------- Co-authored-by: Colin Ho <[email protected]>
1 parent 56f5089 commit faf050c

File tree

5 files changed

+54
-16
lines changed

5 files changed

+54
-16
lines changed

src/common/resource-request/src/lib.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,18 @@ impl ResourceRequest {
139139
self.memory_bytes.map(|x| x * (factor as usize)),
140140
)
141141
}
142+
143+
pub fn num_cpus(&self) -> Option<f64> {
144+
self.num_cpus
145+
}
146+
147+
pub fn num_gpus(&self) -> Option<f64> {
148+
self.num_gpus
149+
}
150+
151+
pub fn memory_bytes(&self) -> Option<usize> {
152+
self.memory_bytes
153+
}
142154
}
143155

144156
impl Add for &ResourceRequest {

src/daft-local-execution/src/intermediate_ops/actor_pool_project.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,8 @@ impl IntermediateOperator for ActorPoolProjectOperator {
181181
}))
182182
}
183183

184-
fn max_concurrency(&self) -> usize {
185-
self.concurrency
184+
fn max_concurrency(&self) -> DaftResult<usize> {
185+
Ok(self.concurrency)
186186
}
187187

188188
fn dispatch_spawner(

src/daft-local-execution/src/intermediate_ops/intermediate_op.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use common_display::tree::TreeDisplay;
44
use common_error::DaftResult;
55
use common_runtime::{get_compute_runtime, RuntimeRef};
66
use daft_micropartition::MicroPartition;
7+
use snafu::ResultExt;
78
use tracing::{info_span, instrument};
89

910
use crate::{
@@ -14,7 +15,7 @@ use crate::{
1415
dispatcher::{DispatchSpawner, RoundRobinDispatcher, UnorderedDispatcher},
1516
pipeline::PipelineNode,
1617
runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext},
17-
ExecutionRuntimeContext, OperatorOutput, NUM_CPUS,
18+
ExecutionRuntimeContext, OperatorOutput, PipelineExecutionSnafu, NUM_CPUS,
1819
};
1920

2021
pub(crate) trait IntermediateOpState: Send + Sync {
@@ -49,8 +50,9 @@ pub trait IntermediateOperator: Send + Sync {
4950
}
5051
/// The maximum number of concurrent workers that can be spawned for this operator.
5152
/// Each worker will has its own IntermediateOperatorState.
52-
fn max_concurrency(&self) -> usize {
53-
*NUM_CPUS
53+
/// This method should be overridden if the operator needs to limit the number of concurrent workers, i.e. UDFs with resource requests.
54+
fn max_concurrency(&self) -> DaftResult<usize> {
55+
Ok(*NUM_CPUS)
5456
}
5557

5658
fn dispatch_spawner(
@@ -208,7 +210,9 @@ impl PipelineNode for IntermediateNode {
208210
));
209211
}
210212
let op = self.intermediate_op.clone();
211-
let num_workers = op.max_concurrency();
213+
let num_workers = op.max_concurrency().context(PipelineExecutionSnafu {
214+
node_name: self.name(),
215+
})?;
212216
let (destination_sender, destination_receiver) = create_channel(1);
213217
let counting_sender = CountingSender::new(destination_sender, self.runtime_stats.clone());
214218

src/daft-local-execution/src/intermediate_ops/project.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
use std::sync::Arc;
22

3+
use common_error::{DaftError, DaftResult};
34
use common_runtime::RuntimeRef;
4-
use daft_dsl::ExprRef;
5+
use daft_dsl::{functions::python::get_resource_request, ExprRef};
56
use daft_micropartition::MicroPartition;
67
use tracing::instrument;
78

89
use super::intermediate_op::{
910
IntermediateOpExecuteResult, IntermediateOpState, IntermediateOperator,
1011
IntermediateOperatorResult,
1112
};
13+
use crate::NUM_CPUS;
1214

1315
pub struct ProjectOperator {
1416
projection: Arc<Vec<ExprRef>>,
@@ -45,4 +47,28 @@ impl IntermediateOperator for ProjectOperator {
4547
fn name(&self) -> &'static str {
4648
"ProjectOperator"
4749
}
50+
51+
fn max_concurrency(&self) -> DaftResult<usize> {
52+
let resource_request = get_resource_request(&self.projection);
53+
match resource_request {
54+
// If the resource request specifies a number of CPUs, the max concurrency is the number of CPUs
55+
// divided by the requested number of CPUs, clamped to (1, NUM_CPUS).
56+
// E.g. if the resource request specifies 2 CPUs and NUM_CPUS is 4, the max concurrency is 2.
57+
Some(resource_request) if resource_request.num_cpus().is_some() => {
58+
let requested_num_cpus = resource_request.num_cpus().unwrap();
59+
if requested_num_cpus > *NUM_CPUS as f64 {
60+
Err(DaftError::ValueError(format!(
61+
"Requested {} CPUs but found only {} available",
62+
requested_num_cpus, *NUM_CPUS
63+
)))
64+
} else {
65+
Ok(
66+
(*NUM_CPUS as f64 / requested_num_cpus).clamp(1.0, *NUM_CPUS as f64)
67+
as usize,
68+
)
69+
}
70+
}
71+
_ => Ok(*NUM_CPUS),
72+
}
73+
}
4874
}

tests/test_resource_requests.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,6 @@
1313
from daft.internal.gpu import cuda_visible_devices
1414
from tests.conftest import get_tests_daft_runner_name
1515

16-
pytestmark = pytest.mark.skipif(
17-
get_tests_daft_runner_name() == "native",
18-
reason="Native runner does not support resource requests",
19-
)
20-
2116

2217
def no_gpu_available() -> bool:
2318
return len(cuda_visible_devices()) == 0
@@ -81,18 +76,19 @@ def test_resource_request_pickle_roundtrip():
8176
###
8277

8378

84-
@pytest.mark.skipif(get_tests_daft_runner_name() not in {"py"}, reason="requires PyRunner to be in use")
79+
@pytest.mark.skipif(
80+
get_tests_daft_runner_name() not in {"native", "py"}, reason="requires Native or Py Runner to be in use"
81+
)
8582
def test_requesting_too_many_cpus():
8683
df = daft.from_pydict(DATA)
87-
system_info = SystemInfo()
8884

89-
my_udf_parametrized = my_udf.override_options(num_cpus=system_info.cpu_count() + 1)
85+
my_udf_parametrized = my_udf.override_options(num_cpus=1000)
9086
df = df.with_column(
9187
"foo",
9288
my_udf_parametrized(col("id")),
9389
)
9490

95-
with pytest.raises(RuntimeError):
91+
with pytest.raises(Exception):
9692
df.collect()
9793

9894

0 commit comments

Comments
 (0)