Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/pool/spawn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,10 @@ impl<T> Clone for Remote<T> {
/// Note that implementations of `Runner` assume `Remote` is `Sync` and `Send`,
/// so we use assertion traits to enforce the constraint at compile time
/// and avoid future breakage.
// Compile-time assertion: this impl only type-checks if `Remote<T>` is `Sync`
// whenever `T: Send`, so any change that breaks the bound fails the build.
// `dead_code` is allowed because the trait exists solely for this check and is
// never referenced at runtime.
#[allow(dead_code)]
trait AssertSync: Sync {}
impl<T: Send> AssertSync for Remote<T> {}
// Same compile-time assertion for `Send`.
#[allow(dead_code)]
trait AssertSend: Send {}
impl<T: Send> AssertSend for Remote<T> {}

Expand Down
152 changes: 144 additions & 8 deletions src/pool/worker.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.

use crate::pool::{Local, Runner};
use crate::queue::{Pop, TaskCell};
use crate::queue::{AcquireState, Pop, TaskCell};
use parking_lot_core::SpinWait;

pub(crate) struct WorkerThread<T, R> {
Expand All @@ -20,31 +20,52 @@ where
T: TaskCell + Send,
R: Runner<TaskCell = T>,
{
/// Pops a task from the queue.
/// Returns `Some((Pop<T>, AcquireState))` if a task is found, where `AcquireState` indicates
/// how the task was acquired (immediate, after spin, or after park).
#[inline]
fn pop(&mut self) -> Option<Pop<T>> {
fn pop(&mut self, retry_after_park: bool) -> Option<(Pop<T>, AcquireState)> {
// Wait some time before going to sleep, which is more expensive.
let mut spin = SpinWait::new();
let mut state = if retry_after_park {
AcquireState::AfterPark
} else {
AcquireState::Immediate
};
Comment on lines +30 to +34
Copy link

Copilot AI Feb 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When retry_after_park is true, state is initialized to AfterPark, so if a task is popped immediately in this call it will still be labeled AfterPark even though no spinning/parking happened during this acquisition. Consider deriving AcquireState based on what happened in the current pop() call (or clarify the semantics in docs), rather than carrying state across iterations via retry_after_park.

Copilot uses AI. Check for mistakes.
loop {
if let Some(t) = self.local.pop() {
return Some(t);
return Some((t, state));
}
if !spin.spin() {
break;
}
if state == AcquireState::Immediate {
state = AcquireState::AfterSpin;
}
Comment on lines +42 to +44
Copy link

Copilot AI Feb 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the current transition if state == Immediate { state = AfterSpin }, AfterSpin can never be produced when retry_after_park is true (since state starts as AfterPark). If you want to capture “acquired after spinning” even after a previous park/spurious wake, update the state machine accordingly (e.g., track spinning separately or allow transitioning from AfterPark to AfterSpin when spinning occurs).

Copilot uses AI. Check for mistakes.
}
self.runner.pause(&mut self.local);
let t = self.local.pop_or_sleep();
self.runner.resume(&mut self.local);
t
t.map(|task| (task, AcquireState::AfterPark))
}

pub fn run(mut self) {
self.runner.start(&mut self.local);
let mut retry_after_park = false;
while !self.local.core().is_shutdown() {
let task = match self.pop() {
Some(t) => t,
None => continue,
let (mut task, acquire_state) = match self.pop(retry_after_park) {
Some(t) => {
retry_after_park = false;
t
}
None => {
retry_after_park = true;
continue;
}
};
let extras = task.task_cell.mut_extras();
extras.acquire_state = Some(acquire_state);
extras.task_source = Some(task.task_source);
self.runner.handle(&mut self.local, task.task_cell);
}
self.runner.end(&mut self.local);
Expand All @@ -59,7 +80,7 @@ mod tests {
use super::*;
use crate::pool::spawn::*;
use crate::pool::SchedConfig;
use crate::queue::QueueType;
use crate::queue::{AcquireState, Extras, QueueType, TaskCell, TaskSource};
use crate::task::callback;
use std::sync::atomic::AtomicUsize;
use std::sync::*;
Expand Down Expand Up @@ -116,6 +137,56 @@ mod tests {
}
}

/// Minimal task used by the tests below to observe the metadata
/// (`task_source` / `acquire_state`) that the worker stamps on its `Extras`.
struct InspectTask {
    extras: Extras,
}

impl InspectTask {
    /// Builds a task with fresh single-level extras; the metadata fields
    /// (`task_source`, `acquire_state`) start out unset.
    fn new() -> Self {
        let extras = Extras::single_level();
        Self { extras }
    }
}

impl TaskCell for InspectTask {
    /// Exposes the task's `Extras` so the queue and worker can read and
    /// write its metadata.
    fn mut_extras(&mut self) -> &mut Extras {
        &mut self.extras
    }
}

/// Events reported by `InspectRunner` over its channel, in the order the
/// runner hooks fire.
enum Event {
    /// `Runner::pause` was called (the worker is about to park).
    Paused,
    /// A task was handled; carries the metadata the worker stamped on it.
    Handled(TaskSource, AcquireState),
}

/// Test runner that forwards every observed pause/handle event to the test
/// thread through `tx` so the test can assert on ordering and metadata.
struct InspectRunner {
    tx: mpsc::Sender<Event>,
}

impl crate::pool::Runner for InspectRunner {
    type TaskCell = InspectTask;

    /// Reports the task's stamped metadata to the test thread, then claims
    /// the task as handled.
    fn handle(
        &mut self,
        _local: &mut Local<Self::TaskCell>,
        mut task_cell: Self::TaskCell,
    ) -> bool {
        // Both fields are stamped by the worker before `handle` runs, so
        // unwrapping here doubles as an assertion that stamping happened.
        let (source, state) = {
            let extras = task_cell.mut_extras();
            (
                extras.task_source().unwrap(),
                extras.acquire_state().unwrap(),
            )
        };
        self.tx.send(Event::Handled(source, state)).unwrap();
        true
    }

    /// Reports that the worker is about to park.
    fn pause(&mut self, _local: &mut Local<Self::TaskCell>) -> bool {
        self.tx.send(Event::Paused).unwrap();
        true
    }
}

#[test]
fn test_hooks() {
let (tx, rx) = mpsc::channel();
Expand Down Expand Up @@ -151,4 +222,69 @@ mod tests {
expected_metrics.end = 1;
assert_eq!(expected_metrics, *metrics.lock().unwrap());
}

/// Verifies that a task already sitting in the worker's local queue is
/// handled with `TaskSource::LocalQueue` / `AcquireState::Immediate`,
/// without the runner pausing first.
#[test]
fn test_worker_run_task_from_local_immediate() {
    // Single-thread pool so only the worker under test can pick up the task.
    let mut config: SchedConfig = Default::default();
    config.max_thread_count = 1;
    config.core_thread_count = AtomicUsize::new(1);
    let (remote, mut locals) = build_spawn(QueueType::SingleLevel, config);
    let (tx, rx) = mpsc::channel();
    let runner = InspectRunner { tx };

    let mut local = locals.remove(0);
    // The first pop should therefore succeed immediately (no spin, no park).
    local.spawn(InspectTask::new()); // spawn a local task before worker starts
    let th = WorkerThread::new(local, runner);
    let handle = std::thread::spawn(move || {
        th.run();
    });

    // The first event must be the handled task; a `Paused` event would mean
    // the worker failed to see the pre-spawned task.
    match rx.recv_timeout(Duration::from_secs(1)).unwrap() {
        Event::Handled(task_source, acquire_state) => {
            assert_eq!(task_source, TaskSource::LocalQueue);
            assert_eq!(acquire_state, AcquireState::Immediate);
        }
        Event::Paused => panic!("did not expect pause before handling task"),
    }

    // Shut the pool down so the worker thread exits and can be joined.
    remote.stop();
    handle.join().unwrap();
}

/// Verifies that a task injected through the global queue after the worker
/// has parked is handled with `TaskSource::GlobalQueue` /
/// `AcquireState::AfterPark`.
#[test]
fn test_worker_run_task_from_global_after_park() {
    // Single-thread pool so only the worker under test can pick up the task.
    let mut config: SchedConfig = Default::default();
    config.max_thread_count = 1;
    config.core_thread_count = AtomicUsize::new(1);
    let (remote, mut locals) = build_spawn(QueueType::SingleLevel, config);
    let (tx, rx) = mpsc::channel();
    let runner = InspectRunner { tx };

    // Start the worker with nothing to do: it should spin and then pause.
    let th = WorkerThread::new(locals.remove(0), runner);
    let handle = std::thread::spawn(move || {
        th.run();
    });

    // Wait until the worker reports it is about to park.
    match rx.recv_timeout(Duration::from_secs(1)).unwrap() {
        Event::Paused => {}
        Event::Handled(_, _) => panic!("expected pause before handling task"),
    }

    // Inject a task via the remote (global) queue to wake the worker.
    remote.spawn(InspectTask::new());

    // The worker may report additional `Paused` events before handling the
    // task (e.g. waking without finding it yet), so loop until `Handled`
    // arrives, bounded by an overall deadline.
    let deadline = Instant::now() + Duration::from_secs(1);
    let (task_source, acquire_state) = loop {
        let timeout = deadline.saturating_duration_since(Instant::now());
        let event = rx.recv_timeout(timeout).unwrap();
        if let Event::Handled(task_source, acquire_state) = event {
            break (task_source, acquire_state);
        }
    };

    assert_eq!(task_source, TaskSource::GlobalQueue);
    assert_eq!(acquire_state, AcquireState::AfterPark);

    // Shut the pool down so the worker thread exits and can be joined.
    remote.stop();
    handle.join().unwrap();
}
}
7 changes: 3 additions & 4 deletions src/queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub mod priority;
mod extras;
mod single_level;

pub use self::extras::Extras;
pub use self::extras::{AcquireState, Extras, TaskSource};

use std::time::Instant;

Expand Down Expand Up @@ -74,9 +74,8 @@ pub struct Pop<T> {
/// When the task was pushed to the queue.
pub schedule_time: Instant,

/// Whether the task comes from the current [`LocalQueue`] instead of being
/// just stolen from the injector or other [`LocalQueue`]s.
pub from_local: bool,
/// The source of the task, indicating where the task comes from.
pub task_source: TaskSource,
}

Copy link

Copilot AI Feb 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pop is a public struct; replacing the public from_local: bool field with task_source: TaskSource is a breaking API change for downstream users. If backward compatibility is needed, consider keeping from_local (possibly deprecated) or adding a from_local() accessor/compat shim while introducing task_source.

Suggested change
impl<T> Pop<T> {
/// Returns whether this task was popped from a local queue.
///
/// This provides a compatibility shim for the former `from_local: bool`
/// field, using the newer `task_source` information instead.
#[deprecated(
since = "0.0.0",
note = "use `task_source` instead; this accessor is kept for backward compatibility"
)]
pub fn from_local(&self) -> bool {
self.task_source.is_local()
}
}

Copilot uses AI. Check for mistakes.
/// The local queue of a task queue.
Expand Down
42 changes: 42 additions & 0 deletions src/queue/extras.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,28 @@ use rand::prelude::*;
use std::sync::Arc;
use std::time::{Duration, Instant};

/// The source of a task, indicating where the task comes from when popped.
///
/// Stamped on `Pop::task_source` by the queue's pop path, and copied into the
/// task's `Extras` by the worker before the runner handles the task.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskSource {
    /// Task popped from the local queue (most efficient path).
    LocalQueue,
    /// Task popped from the global injector queue.
    GlobalQueue,
    /// Task stolen from another worker's local queue.
    OtherWorker,
}

/// Indicates how the worker acquired the task.
///
/// Determined by the worker's pop loop and copied into the task's `Extras`
/// before the runner handles the task.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AcquireState {
    /// Task was popped immediately from the local queue without waiting.
    Immediate,
    /// Task was acquired after the worker thread spun (busy-waited).
    AfterSpin,
    /// Task was acquired after the worker thread was parked (slept).
    AfterPark,
}

/// The extras for the task cells pushed into a queue.
#[derive(Debug, Clone)]
pub struct Extras {
Expand All @@ -32,6 +54,12 @@ pub struct Extras {
/// Extra metadata of this task. Users can use this field to store arbitrary data. It is useful
/// in some cases for implementing a more complex `TaskPriorityProvider` in the priority task queue.
pub(crate) metadata: Vec<u8>,
/// The source of the task, indicating where the task comes from when popped.
/// This field is set when the task is popped from the queue.
pub(crate) task_source: Option<TaskSource>,
/// Indicates how the worker acquired the task.
/// This field is set when the task is popped from the queue.
pub(crate) acquire_state: Option<AcquireState>,
}

impl Extras {
Expand All @@ -48,6 +76,8 @@ impl Extras {
fixed_level: None,
exec_times: 0,
metadata: Vec::new(),
task_source: None,
acquire_state: None,
}
}

Expand All @@ -71,6 +101,8 @@ impl Extras {
fixed_level,
exec_times: 0,
metadata: Vec::new(),
task_source: None,
acquire_state: None,
}
}

Expand Down Expand Up @@ -110,4 +142,14 @@ impl Extras {
pub fn set_metadata(&mut self, metadata: Vec<u8>) {
self.metadata = metadata;
}

/// Gets the source of the task.
///
/// Returns `None` until the task has been popped from a queue, since the
/// field is only filled in at pop time.
pub fn task_source(&self) -> Option<TaskSource> {
    self.task_source
}

/// Gets how the worker acquired the task.
///
/// Returns `None` until the task has been popped from a queue, since the
/// field is only filled in at pop time.
pub fn acquire_state(&self) -> Option<AcquireState> {
    self.acquire_state
}
}
36 changes: 29 additions & 7 deletions src/queue/multilevel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,20 +108,20 @@ where
}

pub(super) fn pop(&mut self) -> Option<Pop<T>> {
fn into_pop<T>(mut t: T, from_local: bool) -> Pop<T>
fn into_pop<T>(mut t: T, task_source: super::TaskSource) -> Pop<T>
where
T: TaskCell,
{
let schedule_time = t.mut_extras().schedule_time.unwrap();
Pop {
task_cell: t,
schedule_time,
from_local,
task_source,
}
}

if let Some(t) = self.local_queue.pop() {
return Some(into_pop(t, true));
return Some(into_pop(t, super::TaskSource::LocalQueue));
}
let mut rng = thread_rng();
let mut need_retry = true;
Expand All @@ -136,7 +136,9 @@ where
.unwrap_or(LEVEL_NUM - 1)
};
match self.steal_from_injector(expected_level) {
Steal::Success(t) => return Some(into_pop(t, false)),
Steal::Success(t) => {
return Some(into_pop(t, super::TaskSource::GlobalQueue));
}
Steal::Retry => need_retry = true,
_ => {}
}
Expand All @@ -145,7 +147,7 @@ where
for (idx, stealer) in self.stealers.iter().enumerate() {
match stealer.steal_batch_and_pop(&self.local_queue) {
Steal::Success(t) => {
found = Some((idx, into_pop(t, false)));
found = Some((idx, into_pop(t, super::TaskSource::OtherWorker)));
break;
}
Steal::Retry => need_retry = true,
Expand All @@ -160,7 +162,9 @@ where
}
for l in expected_level + 1..expected_level + LEVEL_NUM {
match self.steal_from_injector(l % LEVEL_NUM) {
Steal::Success(t) => return Some(into_pop(t, false)),
Steal::Success(t) => {
return Some(into_pop(t, super::TaskSource::GlobalQueue));
}
Steal::Retry => need_retry = true,
_ => {}
}
Expand Down Expand Up @@ -920,7 +924,7 @@ pub(super) fn recent() -> Instant {
mod tests {
use super::*;
use crate::pool::build_spawn;
use crate::queue::Extras;
use crate::queue::{Extras, TaskSource};

use std::sync::atomic::AtomicU64;
use std::sync::mpsc;
Expand Down Expand Up @@ -1027,6 +1031,24 @@ mod tests {
assert!(schedule_time.elapsed() >= SLEEP_DUR);
}

/// Checks that `Pop::task_source` reflects where each task actually came
/// from: the worker's own local queue, the global injector, or another
/// worker's queue.
#[test]
fn test_task_source() {
    let builder = Builder::new(Config::default());
    let (injector, mut locals) = builder.build(2);

    // Pushed to and popped from the same local queue.
    locals[0].push(MockTask::new(0, Extras::multilevel_default()));
    let pop = locals[0].pop().unwrap();
    assert_eq!(pop.task_source, TaskSource::LocalQueue);

    // Local queue is empty, so the pop falls back to the global injector.
    injector.push(MockTask::new(0, Extras::multilevel_default()));
    let pop = locals[0].pop().unwrap();
    assert_eq!(pop.task_source, TaskSource::GlobalQueue);

    // Pushed to worker 0's queue but popped by worker 1, i.e. stolen from
    // another worker.
    locals[0].push(MockTask::new(0, Extras::multilevel_default()));
    let pop = locals[1].pop().unwrap();
    assert_eq!(pop.task_source, TaskSource::OtherWorker);
}

#[test]
fn test_push_task() {
let builder = Builder::new(
Expand Down
Loading
Loading