From ca2bf0558c38a770e7de37d68dcec77b9433edca Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Tue, 4 Nov 2025 09:01:24 -0800 Subject: [PATCH 1/6] Make full_thread_state a shared_ptr Signed-off-by: Alessandro Bellina --- src/main/cpp/src/SparkResourceAdaptorJni.cpp | 386 +++++++++---------- 1 file changed, 175 insertions(+), 211 deletions(-) diff --git a/src/main/cpp/src/SparkResourceAdaptorJni.cpp b/src/main/cpp/src/SparkResourceAdaptorJni.cpp index 63a2c43807..292f0aa163 100644 --- a/src/main/cpp/src/SparkResourceAdaptorJni.cpp +++ b/src/main/cpp/src/SparkResourceAdaptorJni.cpp @@ -680,41 +680,35 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { if (shutting_down) { throw std::runtime_error("spark_resource_adaptor is shutting down"); } auto const found = threads.find(thread_id); if (found != threads.end()) { - if (found->second.task_id >= 0 && found->second.task_id != task_id) { - LOG_STATUS("FIXUP", - thread_id, - found->second.task_id, - found->second.state, - "desired task_id {}", - task_id); - remove_thread_association(thread_id, found->second.task_id, lock); + if (found->second->task_id >= 0 && found->second->task_id != task_id) { + LOG_STATUS("FIXUP", thread_id, found->second->task_id, found->second->state, + "desired task_id {}", task_id); + remove_thread_association(thread_id, found->second->task_id, lock); } } auto const was_threads_inserted = threads.emplace( - thread_id, full_thread_state(thread_state::THREAD_RUNNING, thread_id, task_id)); + thread_id, std::make_shared(thread_state::THREAD_RUNNING, thread_id, task_id)); if (was_threads_inserted.second == false) { - if (was_threads_inserted.first->second.state == thread_state::THREAD_REMOVE_THROW) { + if (was_threads_inserted.first->second->state == thread_state::THREAD_REMOVE_THROW) { std::stringstream ss; - ss << "A thread " << thread_id << " is shutting down " - << was_threads_inserted.first->second.task_id << " vs " << task_id; + ss << "A thread " << thread_id 
<< " is shutting down " + << was_threads_inserted.first->second->task_id << " vs " << task_id; auto const msg = ss.str(); - LOG_STATUS("ERROR", - thread_id, - was_threads_inserted.first->second.task_id, - was_threads_inserted.first->second.state, - msg); + LOG_STATUS("ERROR", + thread_id, was_threads_inserted.first->second->task_id, was_threads_inserted.first->second->state, + msg); throw std::invalid_argument(msg); } - if (was_threads_inserted.first->second.task_id != task_id) { + if (was_threads_inserted.first->second->task_id != task_id) { std::stringstream ss; ss << "A thread " << thread_id << " can only be dedicated to a single task." - << was_threads_inserted.first->second.task_id << " != " << task_id; + << was_threads_inserted.first->second->task_id << " != " << task_id; auto const msg = ss.str(); LOG_STATUS("ERROR", thread_id, - was_threads_inserted.first->second.task_id, - was_threads_inserted.first->second.state, + was_threads_inserted.first->second->task_id, + was_threads_inserted.first->second->state, msg); throw std::invalid_argument(msg); } @@ -742,21 +736,21 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { { std::unique_lock lock(state_mutex); auto const thread = threads.find(thread_id); - if (thread != threads.end()) { thread->second.reset_retry_state(true); } + if (thread != threads.end()) { thread->second->reset_retry_state(true); } } void end_retry_block(long const thread_id) { std::unique_lock lock(state_mutex); auto const thread = threads.find(thread_id); - if (thread != threads.end()) { thread->second.reset_retry_state(false); } + if (thread != threads.end()) { thread->second->reset_retry_state(false); } } bool is_working_on_task_as_pool_thread(long const thread_id) { std::unique_lock lock(state_mutex); auto const thread = threads.find(thread_id); - if (thread != threads.end()) { return !thread->second.pool_task_ids.empty(); } + if (thread != threads.end()) { return !thread->second->pool_task_ids.empty(); } 
return false; } @@ -775,15 +769,15 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { if (shutting_down) { throw std::runtime_error("spark_resource_adaptor is shutting down"); } auto const was_inserted = - threads.emplace(thread_id, full_thread_state(thread_state::THREAD_RUNNING, thread_id)); + threads.emplace(thread_id, std::make_shared(thread_state::THREAD_RUNNING, thread_id)); if (was_inserted.second == true) { - was_inserted.first->second.is_for_shuffle = is_for_shuffle; + was_inserted.first->second->is_for_shuffle = is_for_shuffle; LOG_TRANSITION(thread_id, -1, thread_state::UNKNOWN, thread_state::THREAD_RUNNING); - } else if (was_inserted.first->second.task_id != -1) { + } else if (was_inserted.first->second->task_id != -1) { throw std::invalid_argument("the thread is associated with a non-pool task already"); - } else if (was_inserted.first->second.state == thread_state::THREAD_REMOVE_THROW) { + } else if (was_inserted.first->second->state == thread_state::THREAD_REMOVE_THROW) { throw std::invalid_argument("the thread is in the process of shutting down."); - } else if (was_inserted.first->second.is_for_shuffle != is_for_shuffle) { + } else if (was_inserted.first->second->is_for_shuffle != is_for_shuffle) { if (is_for_shuffle) { throw std::invalid_argument( "the thread is marked as a non-shuffle thread, and we cannot change it while there are " @@ -798,13 +792,8 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { // save the metrics for all tasks before we add any new ones. 
checkpoint_metrics(was_inserted.first->second); - was_inserted.first->second.pool_task_ids.insert(task_ids.begin(), task_ids.end()); - LOG_STATUS_CONTAINER("ADD_TASKS", - thread_id, - -1, - was_inserted.first->second.state, - "CURRENT IDs", - was_inserted.first->second.pool_task_ids); + was_inserted.first->second->pool_task_ids.insert(task_ids.begin(), task_ids.end()); + LOG_STATUS_CONTAINER("ADD_TASKS", thread_id, -1, was_inserted.first->second->state, "CURRENT IDs", was_inserted.first->second->pool_task_ids); } void pool_thread_finished_for_tasks(long const thread_id, @@ -820,15 +809,10 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { // Now drop the tasks from the pool for (auto const& id : task_ids) { - thread->second.pool_task_ids.erase(id); + thread->second->pool_task_ids.erase(id); } - LOG_STATUS_CONTAINER("REMOVE_TASKS", - thread_id, - -1, - thread->second.state, - "CURRENT IDs", - thread->second.pool_task_ids); - if (thread->second.pool_task_ids.empty()) { + LOG_STATUS_CONTAINER("REMOVE_TASKS", thread_id, -1, thread->second->state, "CURRENT IDs", thread->second->pool_task_ids); + if (thread->second->pool_task_ids.empty()) { if (remove_thread_association(thread_id, -1, lock)) { wake_up_threads_after_task_finishes(lock); } @@ -876,14 +860,9 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { for (auto const& thread_id : thread_ids) { auto const thread = threads.find(thread_id); if (thread != threads.end()) { - if (thread->second.pool_task_ids.erase(task_id) != 0) { - LOG_STATUS_CONTAINER("REMOVE_TASKS", - thread_id, - -1, - thread->second.state, - "CURRENT IDs", - thread->second.pool_task_ids); - if (thread->second.pool_task_ids.empty()) { + if (thread->second->pool_task_ids.erase(task_id) != 0) { + LOG_STATUS_CONTAINER("REMOVE_TASKS", thread_id, -1, thread->second->state, "CURRENT IDs", thread->second->pool_task_ids); + if (thread->second->pool_task_ids.empty()) { run_checks = 
remove_thread_association(thread_id, task_id, lock) || run_checks; } } @@ -964,7 +943,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { std::unique_lock lock(state_mutex); auto const threads_at = threads.find(thread_id); if (threads_at != threads.end()) { - threads_at->second.retry_oom.init(num_ooms, skip_count, oom_filter); + threads_at->second->retry_oom.init(num_ooms, skip_count, oom_filter); } else { throw std::invalid_argument("the thread is not associated with any task/shuffle"); } @@ -982,7 +961,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { std::unique_lock lock(state_mutex); auto const threads_at = threads.find(thread_id); if (threads_at != threads.end()) { - threads_at->second.split_and_retry_oom.init(num_ooms, skip_count, oom_filter); + threads_at->second->split_and_retry_oom.init(num_ooms, skip_count, oom_filter); } else { throw std::invalid_argument("the thread is not associated with any task/shuffle"); } @@ -997,7 +976,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { std::unique_lock lock(state_mutex); auto const threads_at = threads.find(thread_id); if (threads_at != threads.end()) { - threads_at->second.cudf_exception_injected = num_times; + threads_at->second->cudf_exception_injected = num_times; } else { throw std::invalid_argument("the thread is not associated with any task/shuffle"); } @@ -1018,8 +997,8 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { for (auto const thread_id : task_at->second) { auto const threads_at = threads.find(thread_id); if (threads_at != threads.end()) { - ret += (threads_at->second.metrics.*MetricPtr); - (threads_at->second.metrics.*MetricPtr) = 0; + ret += (threads_at->second->metrics.*MetricPtr); + (threads_at->second->metrics.*MetricPtr) = 0; } } } @@ -1041,7 +1020,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { if (task_at != task_to_threads.end()) 
{ for (auto const thread_id : task_at->second) { auto const threads_at = threads.find(thread_id); - if (threads_at != threads.end()) { ret += (threads_at->second.metrics.*MetricPtr); } + if (threads_at != threads.end()) { ret += (threads_at->second->metrics.*MetricPtr); } } } @@ -1109,8 +1088,8 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { for (auto const thread_id : task_at->second) { auto const threads_at = threads.find(thread_id); if (threads_at != threads.end()) { - ret += threads_at->second.currently_blocked_for(); - ret += threads_at->second.metrics.time_lost_or_blocked; + ret += threads_at->second->currently_blocked_for(); + ret += threads_at->second->metrics.time_lost_or_blocked; } } } @@ -1165,7 +1144,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { std::unique_lock lock(state_mutex); auto const tid = static_cast(pthread_self()); auto const thread = threads.find(tid); - if (thread != threads.end()) { thread->second.is_in_spilling = true; } + if (thread != threads.end()) { thread->second->is_in_spilling = true; } } void spill_range_done() @@ -1173,7 +1152,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { std::unique_lock lock(state_mutex); auto const tid = static_cast(pthread_self()); auto const thread = threads.find(tid); - if (thread != threads.end()) { thread->second.is_in_spilling = false; } + if (thread != threads.end()) { thread->second->is_in_spilling = false; } } /** @@ -1198,7 +1177,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { std::unique_lock lock(state_mutex); auto const threads_at = threads.find(thread_id); if (threads_at != threads.end()) { - return static_cast(threads_at->second.state); + return static_cast(threads_at->second->state); } else { return -1; } @@ -1212,7 +1191,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { // from an operation. 
std::mutex state_mutex; std::condition_variable task_has_woken_condition; - std::map threads; + std::map> threads; std::map> task_to_threads; long gpu_memory_allocated_bytes = 0; @@ -1231,11 +1210,12 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { * of setting the state directly. This will log the transition and do a little bit of * verification. */ - void transition(full_thread_state& state, thread_state const new_state) + void transition(std::shared_ptr state, + thread_state const new_state) { - thread_state original = state.state; - state.transition_to(new_state); - LOG_TRANSITION(state.thread_id, state.task_id, original, new_state); + thread_state original = state->state; + state->transition_to(new_state); + LOG_TRANSITION(state->thread_id, state->task_id, original, new_state); } /** @@ -1252,7 +1232,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { std::unique_lock lock(state_mutex); auto const thread = threads.find(thread_id); long task_id = -1; - if (thread != threads.end()) { task_id = thread->second.task_id; } + if (thread != threads.end()) { task_id = thread->second->task_id; } if (task_id < 0) { std::stringstream ss; @@ -1260,24 +1240,24 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { throw std::invalid_argument(ss.str()); } - thread->second.pool_blocked = pool_blocked; + thread->second->pool_blocked = pool_blocked; } /** * Checkpoint all of the metrics for a thread. */ - void checkpoint_metrics(full_thread_state& state) + void checkpoint_metrics(std::shared_ptr state) { - if (state.task_id < 0) { + if (state->task_id < 0) { // save the metrics for all tasks before we add any new ones. 
- for (auto const task_id : state.pool_task_ids) { + for (auto const task_id : state->pool_task_ids) { auto const metrics_at = task_to_metrics.try_emplace(task_id, task_metrics()); - metrics_at.first->second.add(state.metrics); + metrics_at.first->second.add(state->metrics); } - state.metrics.clear(); + state->metrics.clear(); } else { - auto const metrics_at = task_to_metrics.try_emplace(state.task_id, task_metrics()); - metrics_at.first->second.take_from(state.metrics); + auto const metrics_at = task_to_metrics.try_emplace(state->task_id, task_metrics()); + metrics_at.first->second.take_from(state->metrics); } } @@ -1350,23 +1330,23 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { while (!done) { auto thread = threads.find(thread_id); if (thread != threads.end()) { - switch (thread->second.state) { + switch (thread->second->state) { case thread_state::THREAD_BLOCKED: // fall through case thread_state::THREAD_BUFN: - LOG_STATUS("WAITING", thread_id, thread->second.task_id, thread->second.state); - thread->second.before_block(); + LOG_STATUS("WAITING", thread_id, thread->second->task_id, thread->second->state); + thread->second->before_block(); do { - thread->second.wake_condition->wait(lock); + thread->second->wake_condition->wait(lock); thread = threads.find(thread_id); - } while (thread != threads.end() && is_blocked(thread->second.state)); - thread->second.after_block(); + } while (thread != threads.end() && is_blocked(thread->second->state)); + thread->second->after_block(); task_has_woken_condition.notify_all(); break; case thread_state::THREAD_BUFN_THROW: transition(thread->second, thread_state::THREAD_BUFN_WAIT); - thread->second.record_failed_retry_time(); - throw_retry_oom("rollback and retry operation", thread->second, lock); + thread->second->record_failed_retry_time(); + throw_retry_oom("rollback and retry operation", *thread->second, lock); break; case thread_state::THREAD_BUFN_WAIT: transition(thread->second, 
thread_state::THREAD_BUFN); @@ -1376,33 +1356,33 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { check_and_update_for_bufn(lock); // If that caused us to transition to a new state, then we need to adjust to it // appropriately... - if (is_blocked(thread->second.state)) { - LOG_STATUS("WAITING", thread_id, thread->second.task_id, thread->second.state); - thread->second.before_block(); + if (is_blocked(thread->second->state)) { + LOG_STATUS("WAITING", thread_id, thread->second->task_id, thread->second->state); + thread->second->before_block(); do { - thread->second.wake_condition->wait(lock); + thread->second->wake_condition->wait(lock); thread = threads.find(thread_id); - } while (thread != threads.end() && is_blocked(thread->second.state)); - thread->second.after_block(); + } while (thread != threads.end() && is_blocked(thread->second->state)); + thread->second->after_block(); task_has_woken_condition.notify_all(); } break; case thread_state::THREAD_SPLIT_THROW: transition(thread->second, thread_state::THREAD_RUNNING); - thread->second.record_failed_retry_time(); + thread->second->record_failed_retry_time(); throw_split_and_retry_oom( - "rollback, split input, and retry operation", thread->second, lock); + "rollback, split input, and retry operation", *thread->second, lock); break; case thread_state::THREAD_REMOVE_THROW: LOG_TRANSITION( - thread_id, thread->second.task_id, thread->second.state, thread_state::UNKNOWN); + thread_id, thread->second->task_id, thread->second->state, thread_state::UNKNOWN); // don't need to record failed time metric the thread is already gone... 
threads.erase(thread); task_has_woken_condition.notify_all(); throw std::runtime_error("thread removed while blocked"); default: if (!first_time) { - LOG_STATUS("DONE WAITING", thread_id, thread->second.task_id, thread->second.state); + LOG_STATUS("DONE WAITING", thread_id, thread->second->task_id, thread->second->state); } done = true; } @@ -1425,10 +1405,10 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { { bool are_any_tasks_just_blocked = false; for (auto& [thread_id, t_state] : threads) { - switch (t_state.state) { + switch (t_state->state) { case thread_state::THREAD_BLOCKED: transition(t_state, thread_state::THREAD_RUNNING); - t_state.wake_condition->notify_all(); + t_state->wake_condition->notify_all(); are_any_tasks_just_blocked = true; break; default: break; @@ -1438,14 +1418,14 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { if (!are_any_tasks_just_blocked) { // wake up all of the BUFN tasks. for (auto& [thread_id, t_state] : threads) { - switch (t_state.state) { + switch (t_state->state) { case thread_state::THREAD_BUFN: // fall through case thread_state::THREAD_BUFN_THROW: // fall through case thread_state::THREAD_BUFN_WAIT: transition(t_state, thread_state::THREAD_RUNNING); - t_state.wake_condition->notify_all(); + t_state->wake_condition->notify_all(); break; default: break; } @@ -1472,12 +1452,12 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { if (remove_task_id < 0) { thread_should_be_removed = true; } else { - auto const task_id = threads_at->second.task_id; + auto const task_id = threads_at->second->task_id; if (task_id >= 0) { if (task_id == remove_task_id) { thread_should_be_removed = true; } } else { - threads_at->second.pool_task_ids.erase(remove_task_id); - if (threads_at->second.pool_task_ids.empty()) { thread_should_be_removed = true; } + threads_at->second->pool_task_ids.erase(remove_task_id); + if (threads_at->second->pool_task_ids.empty()) 
{ thread_should_be_removed = true; } } } @@ -1492,20 +1472,20 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { if (task_at != task_to_threads.end()) { task_at->second.erase(thread_id); } } - switch (threads_at->second.state) { + switch (threads_at->second->state) { case thread_state::THREAD_BLOCKED: // fall through case thread_state::THREAD_BUFN: transition(threads_at->second, thread_state::THREAD_REMOVE_THROW); - threads_at->second.wake_condition->notify_all(); + threads_at->second->wake_condition->notify_all(); break; case thread_state::THREAD_RUNNING: ret = true; // fall through; default: LOG_TRANSITION(thread_id, - threads_at->second.task_id, - threads_at->second.state, + threads_at->second->task_id, + threads_at->second->state, thread_state::UNKNOWN); threads.erase(threads_at); } @@ -1546,7 +1526,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { { auto const thread = threads.find(thread_id); if (thread != threads.end()) { - switch (thread->second.state) { + switch (thread->second->state) { // If the thread is in one of the ALLOC or ALLOC_FREE states, we have detected a loop // likely due to spill setup required in cuDF. We will treat this allocation differently // and skip transitions. 
@@ -1559,7 +1539,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { std::stringstream ss; ss << "thread " << thread_id << " is trying to do a blocking allocate while already in the state " - << as_str(thread->second.state); + << as_str(thread->second->state); throw std::invalid_argument(ss.str()); } @@ -1568,39 +1548,39 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { default: break; } - if (thread->second.retry_oom.matches(is_for_cpu)) { - if (thread->second.retry_oom.skip_count > 0) { - thread->second.retry_oom.skip_count--; - } else if (thread->second.retry_oom.hit_count > 0) { - thread->second.retry_oom.hit_count--; - thread->second.metrics.num_times_retry_throw++; + if (thread->second->retry_oom.matches(is_for_cpu)) { + if (thread->second->retry_oom.skip_count > 0) { + thread->second->retry_oom.skip_count--; + } else if (thread->second->retry_oom.hit_count > 0) { + thread->second->retry_oom.hit_count--; + thread->second->metrics.num_times_retry_throw++; std::string const op_prefix = "INJECTED_RETRY_OOM_"; std::string const op = op_prefix + (is_for_cpu ? "CPU" : "GPU"); - LOG_STATUS(op, thread_id, thread->second.task_id, thread->second.state); - thread->second.record_failed_retry_time(); + LOG_STATUS(op, thread_id, thread->second->task_id, thread->second->state); + thread->second->record_failed_retry_time(); throw_java_exception(is_for_cpu ? 
CPU_RETRY_OOM_CLASS : GPU_RETRY_OOM_CLASS, "injected RetryOOM"); } } - if (thread->second.cudf_exception_injected > 0) { - thread->second.cudf_exception_injected--; + if (thread->second->cudf_exception_injected > 0) { + thread->second->cudf_exception_injected--; LOG_STATUS( - "INJECTED_CUDF_EXCEPTION", thread_id, thread->second.task_id, thread->second.state); - thread->second.record_failed_retry_time(); + "INJECTED_CUDF_EXCEPTION", thread_id, thread->second->task_id, thread->second->state); + thread->second->record_failed_retry_time(); throw_java_exception(cudf::jni::CUDF_EXCEPTION_CLASS, "injected CudfException"); } - if (thread->second.split_and_retry_oom.matches(is_for_cpu)) { - if (thread->second.split_and_retry_oom.skip_count > 0) { - thread->second.split_and_retry_oom.skip_count--; - } else if (thread->second.split_and_retry_oom.hit_count > 0) { - thread->second.split_and_retry_oom.hit_count--; - thread->second.metrics.num_times_split_retry_throw++; + if (thread->second->split_and_retry_oom.matches(is_for_cpu)) { + if (thread->second->split_and_retry_oom.skip_count > 0) { + thread->second->split_and_retry_oom.skip_count--; + } else if (thread->second->split_and_retry_oom.hit_count > 0) { + thread->second->split_and_retry_oom.hit_count--; + thread->second->metrics.num_times_split_retry_throw++; std::string const op_prefix = "INJECTED_SPLIT_AND_RETRY_OOM_"; std::string const op = op_prefix + (is_for_cpu ? 
"CPU" : "GPU"); - LOG_STATUS(op, thread_id, thread->second.task_id, thread->second.state); - thread->second.record_failed_retry_time(); + LOG_STATUS(op, thread_id, thread->second->task_id, thread->second->state); + thread->second->record_failed_retry_time(); if (is_for_cpu) { throw_java_exception(CPU_SPLIT_AND_RETRY_OOM_CLASS, "injected SplitAndRetryOOM"); } else { @@ -1611,15 +1591,15 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { if (blocking) { block_thread_until_ready(thread_id, lock); } - switch (thread->second.state) { + switch (thread->second->state) { case thread_state::THREAD_RUNNING: transition(thread->second, thread_state::THREAD_ALLOC); - thread->second.is_cpu_alloc = is_for_cpu; + thread->second->is_cpu_alloc = is_for_cpu; break; default: { std::stringstream ss; ss << "thread " << thread_id << " in unexpected state pre alloc " - << as_str(thread->second.state); + << as_str(thread->second->state); throw std::invalid_argument(ss.str()); } @@ -1657,41 +1637,37 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { auto const thread = threads.find(thread_id); if (!was_recursive && thread != threads.end()) { // The allocation succeeded so we are no longer doing a retry - if (thread->second.is_retry_alloc_before_bufn) { - thread->second.is_retry_alloc_before_bufn = false; - LOG_STATUS( - "DETAIL", - thread_id, - thread->second.task_id, - thread->second.state, + if (thread->second->is_retry_alloc_before_bufn) { + thread->second->is_retry_alloc_before_bufn = false; + LOG_STATUS("DETAIL", thread_id, thread->second->task_id, thread->second->state, "thread (id: {}) is_retry_alloc_before_bufn set to false in post_alloc_success_core", thread_id); } - switch (thread->second.state) { + switch (thread->second->state) { case thread_state::THREAD_ALLOC: // fall through case thread_state::THREAD_ALLOC_FREE: - if (thread->second.is_cpu_alloc != is_for_cpu) { + if (thread->second->is_cpu_alloc != is_for_cpu) { 
std::stringstream ss; ss << "thread " << thread_id << " has a mismatch on CPU vs GPU post alloc " - << as_str(thread->second.state); + << as_str(thread->second->state); throw std::invalid_argument(ss.str()); } transition(thread->second, thread_state::THREAD_RUNNING); - thread->second.is_cpu_alloc = false; + thread->second->is_cpu_alloc = false; // num_bytes is likely not padded, which could cause slight inaccuracies // but for now it shouldn't matter for watermark purposes if (!is_for_cpu) { - if (!thread->second.is_in_spilling) { - thread->second.metrics.gpu_memory_active_footprint += num_bytes; - thread->second.metrics.gpu_memory_max_footprint = - std::max(thread->second.metrics.gpu_memory_active_footprint, - thread->second.metrics.gpu_memory_max_footprint); + if (!thread->second->is_in_spilling) { + thread->second->metrics.gpu_memory_active_footprint += num_bytes; + thread->second->metrics.gpu_memory_max_footprint = + std::max(thread->second->metrics.gpu_memory_active_footprint, + thread->second->metrics.gpu_memory_max_footprint); } gpu_memory_allocated_bytes += num_bytes; - thread->second.metrics.gpu_max_memory_allocated = - std::max(thread->second.metrics.gpu_max_memory_allocated, gpu_memory_allocated_bytes); + thread->second->metrics.gpu_max_memory_allocated = + std::max(thread->second->metrics.gpu_max_memory_allocated, gpu_memory_allocated_bytes); } break; default: break; @@ -1717,9 +1693,9 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { thread_priority to_wake(-1, -1); bool is_to_wake_set = false; for (auto const& [thread_d, t_state] : threads) { - thread_state const& state = t_state.state; - if (state == thread_state::THREAD_BLOCKED && is_for_cpu == t_state.is_cpu_alloc) { - thread_priority current = t_state.priority(); + thread_state const& state = t_state->state; + if (state == thread_state::THREAD_BLOCKED && is_for_cpu == t_state->is_cpu_alloc) { + thread_priority current = t_state->priority(); if (!is_to_wake_set || 
to_wake < current) { to_wake = current; is_to_wake_set = true; @@ -1731,15 +1707,15 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { if (thread_id_to_wake > 0) { auto const thread = threads.find(thread_id_to_wake); if (thread != threads.end()) { - switch (thread->second.state) { + switch (thread->second->state) { case thread_state::THREAD_BLOCKED: transition(thread->second, thread_state::THREAD_RUNNING); - thread->second.wake_condition->notify_all(); + thread->second->wake_condition->notify_all(); break; default: { std::stringstream ss; ss << "internal error expected to only wake up blocked threads " << thread_id_to_wake - << " " << as_str(thread->second.state); + << " " << as_str(thread->second->state); throw std::runtime_error(ss.str()); } } @@ -1765,10 +1741,10 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { thread_priority to_wake(-1, -1); bool is_to_wake_set = false; for (auto const& [thread_id, t_state] : threads) { - switch (t_state.state) { + switch (t_state->state) { case thread_state::THREAD_BUFN: { - if (is_for_cpu == t_state.is_cpu_alloc) { - thread_priority current = t_state.priority(); + if (is_for_cpu == t_state->is_cpu_alloc) { + thread_priority current = t_state->priority(); if (!is_to_wake_set || to_wake < current) { to_wake = current; is_to_wake_set = true; @@ -1787,10 +1763,10 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { auto const this_id = static_cast(pthread_self()); auto const thread = threads.find(thread_id_to_wake); if (thread != threads.end() && thread->first != this_id) { - switch (thread->second.state) { + switch (thread->second->state) { case thread_state::THREAD_BUFN: transition(thread->second, thread_state::THREAD_RUNNING); - thread->second.wake_condition->notify_all(); + thread->second->wake_condition->notify_all(); break; case thread_state::THREAD_BUFN_WAIT: transition(thread->second, thread_state::THREAD_RUNNING); @@ -1804,7 
+1780,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { default: { std::stringstream ss; ss << "internal error expected to only wake up blocked threads " - << thread_id_to_wake << " " << as_str(thread->second.state); + << thread_id_to_wake << " " << as_str(thread->second->state); throw std::runtime_error(ss.str()); } } @@ -1815,13 +1791,13 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { } } - bool is_thread_bufn_or_above(JNIEnv* env, full_thread_state const& state) + bool is_thread_bufn_or_above(JNIEnv* env, std::shared_ptr state) { bool ret = false; - if (state.pool_blocked) { + if (state->pool_blocked) { ret = true; } else { - switch (state.state) { + switch (state->state) { case thread_state::THREAD_BLOCKED: ret = false; break; case thread_state::THREAD_BUFN: // empty we are looking for even a single thread that is not blocked @@ -1829,7 +1805,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { break; default: ret = env->CallStaticBooleanMethod( - ThreadStateRegistry_jclass, isThreadBlocked_method, state.thread_id); + ThreadStateRegistry_jclass, isThreadBlocked_method, state->thread_id); break; } } @@ -1894,12 +1870,12 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { // We are going to do two passes through the threads to deal with this. 
// First pass is to look at the dedicated task threads for (auto const& [thread_id, t_state] : threads) { - long const task_id = t_state.task_id; + long const task_id = t_state->task_id; if (task_id >= 0) { all_task_ids.insert(task_id); bool const is_bufn_plus = is_thread_bufn_or_above(env, t_state); if (is_bufn_plus) { bufn_task_ids.insert(task_id); } - if (is_bufn_plus || t_state.state == thread_state::THREAD_BLOCKED) { + if (is_bufn_plus || t_state->state == thread_state::THREAD_BLOCKED) { blocked_task_ids.insert(task_id); } } @@ -1907,9 +1883,9 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { // Second pass is to look at the pool threads for (auto const& [thread_id, t_state] : threads) { - long const is_pool_thread = t_state.task_id < 0; + long const is_pool_thread = t_state->task_id < 0; if (is_pool_thread) { - for (auto const& task_id : t_state.pool_task_ids) { + for (auto const& task_id : t_state->pool_task_ids) { auto const it = pool_task_thread_count.find(task_id); if (it != pool_task_thread_count.end()) { it->second += 1; @@ -1920,7 +1896,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { bool const is_bufn_plus = is_thread_bufn_or_above(env, t_state); if (is_bufn_plus) { - for (auto const& task_id : t_state.pool_task_ids) { + for (auto const& task_id : t_state->pool_task_ids) { auto const it = pool_bufn_task_thread_count.find(task_id); if (it != pool_bufn_task_thread_count.end()) { it->second += 1; @@ -1929,8 +1905,8 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { } } } - if (!is_bufn_plus && t_state.state != thread_state::THREAD_BLOCKED) { - for (auto const& task_id : t_state.pool_task_ids) { + if (!is_bufn_plus && t_state->state != thread_state::THREAD_BLOCKED) { + for (auto const& task_id : t_state->pool_task_ids) { blocked_task_ids.erase(task_id); } } @@ -1993,10 +1969,10 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { 
bool is_to_bufn_set = false; int blocked_thread_count = 0; for (auto const& [thread_id, t_state] : threads) { - switch (t_state.state) { + switch (t_state->state) { case thread_state::THREAD_BLOCKED: { blocked_thread_count++; - thread_priority const& current = t_state.priority(); + thread_priority const& current = t_state->priority(); if (!is_to_bufn_set || current < to_bufn) { to_bufn = current; is_to_bufn_set = true; @@ -2016,19 +1992,15 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { // But we are not tracking when data is made spillable // so if data was made spillable we will retry the // allocation, instead of going to BUFN. - thread->second.is_retry_alloc_before_bufn = true; - LOG_STATUS("DETAIL", - thread_id_to_bufn, - thread->second.task_id, - thread->second.state, - "thread (id: {}) is_retry_alloc_before_bufn set to true", - thread_id_to_bufn); + thread->second->is_retry_alloc_before_bufn = true; + LOG_STATUS("DETAIL", thread_id_to_bufn, thread->second->task_id, thread->second->state, + "thread (id: {}) is_retry_alloc_before_bufn set to true", thread_id_to_bufn); transition(thread->second, thread_state::THREAD_RUNNING); } else { log_all_threads_states(); transition(thread->second, thread_state::THREAD_BUFN_THROW); } - thread->second.wake_condition->notify_all(); + thread->second->wake_condition->notify_all(); } } // We now need a way to detect if we need to split the input and retry. 
@@ -2057,9 +2029,9 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { thread_priority to_wake(-1, -1); bool is_to_wake_set = false; for (auto const& [thread_id, t_state] : threads) { - switch (t_state.state) { + switch (t_state->state) { case thread_state::THREAD_BUFN: { - thread_priority const& current = t_state.priority(); + thread_priority const& current = t_state->priority(); if (!is_to_wake_set || to_wake < current) { to_wake = current; is_to_wake_set = true; @@ -2072,7 +2044,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { auto const found_thread = threads.find(thread_id); if (found_thread != threads.end()) { transition(found_thread->second, thread_state::THREAD_SPLIT_THROW); - found_thread->second.wake_condition->notify_all(); + found_thread->second->wake_condition->notify_all(); } } } @@ -2100,40 +2072,32 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { // only retry if this was due to an out of memory exception. 
bool ret = true; if (!was_recursive && thread != threads.end()) { - if (thread->second.is_cpu_alloc != is_for_cpu) { + if (thread->second->is_cpu_alloc != is_for_cpu) { std::stringstream ss; ss << "thread " << thread_id << " has a mismatch on CPU vs GPU post alloc " - << as_str(thread->second.state); + << as_str(thread->second->state); throw std::invalid_argument(ss.str()); } - switch (thread->second.state) { + switch (thread->second->state) { case thread_state::THREAD_ALLOC_FREE: transition(thread->second, thread_state::THREAD_RUNNING); break; case thread_state::THREAD_ALLOC: - if (is_oom && thread->second.is_retry_alloc_before_bufn) { - if (thread->second.is_retry_alloc_before_bufn) { - thread->second.is_retry_alloc_before_bufn = false; - LOG_STATUS( - "DETAIL", - thread_id, - thread->second.task_id, - thread->second.state, + if (is_oom && thread->second->is_retry_alloc_before_bufn) { + if (thread->second->is_retry_alloc_before_bufn) { + thread->second->is_retry_alloc_before_bufn = false; + LOG_STATUS("DETAIL", thread_id, thread->second->task_id, thread->second->state, "thread (id: {}) is_retry_alloc_before_bufn set to false in post_alloc_failed_core", thread_id); } transition(thread->second, thread_state::THREAD_BUFN_THROW); - thread->second.wake_condition->notify_all(); + thread->second->wake_condition->notify_all(); } else if (is_oom && blocking) { - if (thread->second.is_retry_alloc_before_bufn) { - thread->second.is_retry_alloc_before_bufn = false; - LOG_STATUS( - "DETAIL", - thread_id, - thread->second.task_id, - thread->second.state, + if (thread->second->is_retry_alloc_before_bufn) { + thread->second->is_retry_alloc_before_bufn = false; + LOG_STATUS("DETAIL", thread_id, thread->second->task_id, thread->second->state, "thread (id: {}) is_retry_alloc_before_bufn set to false in post_alloc_failed_core", thread_id); } @@ -2146,7 +2110,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { default: { std::stringstream ss; ss << 
"Internal error: unexpected state after alloc failed " << thread_id << " " - << as_str(thread->second.state); + << as_str(thread->second->state); throw std::runtime_error(ss.str()); } } @@ -2188,10 +2152,10 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { auto const tid = static_cast(pthread_self()); auto const thread = threads.find(tid); if (thread != threads.end()) { - LOG_STATUS("DEALLOC", tid, thread->second.task_id, thread->second.state); + LOG_STATUS("DEALLOC", tid, thread->second->task_id, thread->second->state); if (!is_for_cpu) { - if (!thread->second.is_in_spilling) { - thread->second.metrics.gpu_memory_active_footprint -= num_bytes; + if (!thread->second->is_in_spilling) { + thread->second->metrics.gpu_memory_active_footprint -= num_bytes; } gpu_memory_allocated_bytes -= num_bytes; } @@ -2211,10 +2175,10 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { // By not changing our thread's state to THREAD_ALLOC_FREE, we keep the state // the same, but we still let other threads know that there was a free and they should // handle accordingly. 
- if (t_state.thread_id != tid) { - switch (t_state.state) { + if (t_state->thread_id != tid) { + switch (t_state->state) { case thread_state::THREAD_ALLOC: - if (is_for_cpu == t_state.is_cpu_alloc) { + if (is_for_cpu == t_state->is_cpu_alloc) { transition(t_state, thread_state::THREAD_ALLOC_FREE); } break; From ad0ef971e781972d0d80387d18a5685821e9e023 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Tue, 18 Nov 2025 09:32:07 -0800 Subject: [PATCH 2/6] style fixes --- src/main/cpp/src/SparkResourceAdaptorJni.cpp | 79 ++++++++++++++------ 1 file changed, 58 insertions(+), 21 deletions(-) diff --git a/src/main/cpp/src/SparkResourceAdaptorJni.cpp b/src/main/cpp/src/SparkResourceAdaptorJni.cpp index 292f0aa163..f526574d58 100644 --- a/src/main/cpp/src/SparkResourceAdaptorJni.cpp +++ b/src/main/cpp/src/SparkResourceAdaptorJni.cpp @@ -681,22 +681,29 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { auto const found = threads.find(thread_id); if (found != threads.end()) { if (found->second->task_id >= 0 && found->second->task_id != task_id) { - LOG_STATUS("FIXUP", thread_id, found->second->task_id, found->second->state, - "desired task_id {}", task_id); + LOG_STATUS("FIXUP", + thread_id, + found->second->task_id, + found->second->state, + "desired task_id {}", + task_id); remove_thread_association(thread_id, found->second->task_id, lock); } } auto const was_threads_inserted = threads.emplace( - thread_id, std::make_shared(thread_state::THREAD_RUNNING, thread_id, task_id)); + thread_id, + std::make_shared(thread_state::THREAD_RUNNING, thread_id, task_id)); if (was_threads_inserted.second == false) { if (was_threads_inserted.first->second->state == thread_state::THREAD_REMOVE_THROW) { std::stringstream ss; - ss << "A thread " << thread_id << " is shutting down " + ss << "A thread " << thread_id << " is shutting down " << was_threads_inserted.first->second->task_id << " vs " << task_id; auto const msg = ss.str(); - LOG_STATUS("ERROR", - 
thread_id, was_threads_inserted.first->second->task_id, was_threads_inserted.first->second->state, - msg); + LOG_STATUS("ERROR", + thread_id, + was_threads_inserted.first->second->task_id, + was_threads_inserted.first->second->state, + msg); throw std::invalid_argument(msg); } @@ -768,8 +775,8 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { std::unique_lock lock(state_mutex); if (shutting_down) { throw std::runtime_error("spark_resource_adaptor is shutting down"); } - auto const was_inserted = - threads.emplace(thread_id, std::make_shared(thread_state::THREAD_RUNNING, thread_id)); + auto const was_inserted = threads.emplace( + thread_id, std::make_shared(thread_state::THREAD_RUNNING, thread_id)); if (was_inserted.second == true) { was_inserted.first->second->is_for_shuffle = is_for_shuffle; LOG_TRANSITION(thread_id, -1, thread_state::UNKNOWN, thread_state::THREAD_RUNNING); @@ -793,7 +800,12 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { checkpoint_metrics(was_inserted.first->second); was_inserted.first->second->pool_task_ids.insert(task_ids.begin(), task_ids.end()); - LOG_STATUS_CONTAINER("ADD_TASKS", thread_id, -1, was_inserted.first->second->state, "CURRENT IDs", was_inserted.first->second->pool_task_ids); + LOG_STATUS_CONTAINER("ADD_TASKS", + thread_id, + -1, + was_inserted.first->second->state, + "CURRENT IDs", + was_inserted.first->second->pool_task_ids); } void pool_thread_finished_for_tasks(long const thread_id, @@ -811,7 +823,12 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { for (auto const& id : task_ids) { thread->second->pool_task_ids.erase(id); } - LOG_STATUS_CONTAINER("REMOVE_TASKS", thread_id, -1, thread->second->state, "CURRENT IDs", thread->second->pool_task_ids); + LOG_STATUS_CONTAINER("REMOVE_TASKS", + thread_id, + -1, + thread->second->state, + "CURRENT IDs", + thread->second->pool_task_ids); if (thread->second->pool_task_ids.empty()) { if 
(remove_thread_association(thread_id, -1, lock)) { wake_up_threads_after_task_finishes(lock); @@ -861,7 +878,12 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { auto const thread = threads.find(thread_id); if (thread != threads.end()) { if (thread->second->pool_task_ids.erase(task_id) != 0) { - LOG_STATUS_CONTAINER("REMOVE_TASKS", thread_id, -1, thread->second->state, "CURRENT IDs", thread->second->pool_task_ids); + LOG_STATUS_CONTAINER("REMOVE_TASKS", + thread_id, + -1, + thread->second->state, + "CURRENT IDs", + thread->second->pool_task_ids); if (thread->second->pool_task_ids.empty()) { run_checks = remove_thread_association(thread_id, task_id, lock) || run_checks; } @@ -1210,8 +1232,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { * of setting the state directly. This will log the transition and do a little bit of * verification. */ - void transition(std::shared_ptr state, - thread_state const new_state) + void transition(std::shared_ptr state, thread_state const new_state) { thread_state original = state->state; state->transition_to(new_state); @@ -1639,7 +1660,11 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { // The allocation succeeded so we are no longer doing a retry if (thread->second->is_retry_alloc_before_bufn) { thread->second->is_retry_alloc_before_bufn = false; - LOG_STATUS("DETAIL", thread_id, thread->second->task_id, thread->second->state, + LOG_STATUS( + "DETAIL", + thread_id, + thread->second->task_id, + thread->second->state, "thread (id: {}) is_retry_alloc_before_bufn set to false in post_alloc_success_core", thread_id); } @@ -1666,8 +1691,8 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { thread->second->metrics.gpu_memory_max_footprint); } gpu_memory_allocated_bytes += num_bytes; - thread->second->metrics.gpu_max_memory_allocated = - std::max(thread->second->metrics.gpu_max_memory_allocated, 
gpu_memory_allocated_bytes); + thread->second->metrics.gpu_max_memory_allocated = std::max( + thread->second->metrics.gpu_max_memory_allocated, gpu_memory_allocated_bytes); } break; default: break; @@ -1993,8 +2018,12 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { // so if data was made spillable we will retry the // allocation, instead of going to BUFN. thread->second->is_retry_alloc_before_bufn = true; - LOG_STATUS("DETAIL", thread_id_to_bufn, thread->second->task_id, thread->second->state, - "thread (id: {}) is_retry_alloc_before_bufn set to true", thread_id_to_bufn); + LOG_STATUS("DETAIL", + thread_id_to_bufn, + thread->second->task_id, + thread->second->state, + "thread (id: {}) is_retry_alloc_before_bufn set to true", + thread_id_to_bufn); transition(thread->second, thread_state::THREAD_RUNNING); } else { log_all_threads_states(); @@ -2088,7 +2117,11 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { if (is_oom && thread->second->is_retry_alloc_before_bufn) { if (thread->second->is_retry_alloc_before_bufn) { thread->second->is_retry_alloc_before_bufn = false; - LOG_STATUS("DETAIL", thread_id, thread->second->task_id, thread->second->state, + LOG_STATUS( + "DETAIL", + thread_id, + thread->second->task_id, + thread->second->state, "thread (id: {}) is_retry_alloc_before_bufn set to false in post_alloc_failed_core", thread_id); } @@ -2097,7 +2130,11 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { } else if (is_oom && blocking) { if (thread->second->is_retry_alloc_before_bufn) { thread->second->is_retry_alloc_before_bufn = false; - LOG_STATUS("DETAIL", thread_id, thread->second->task_id, thread->second->state, + LOG_STATUS( + "DETAIL", + thread_id, + thread->second->task_id, + thread->second->state, "thread (id: {}) is_retry_alloc_before_bufn set to false in post_alloc_failed_core", thread_id); } From 6550540f1e1ea52bd70de7a4e55c876bcbda8a60 Mon Sep 17 00:00:00 2001 
From: Alessandro Bellina Date: Tue, 18 Nov 2025 09:32:07 -0800 Subject: [PATCH 3/6] style fixes Signed-off-by: Alessandro Bellina --- src/main/cpp/src/SparkResourceAdaptorJni.cpp | 79 ++++++++++++++------ 1 file changed, 58 insertions(+), 21 deletions(-) diff --git a/src/main/cpp/src/SparkResourceAdaptorJni.cpp b/src/main/cpp/src/SparkResourceAdaptorJni.cpp index 292f0aa163..f526574d58 100644 --- a/src/main/cpp/src/SparkResourceAdaptorJni.cpp +++ b/src/main/cpp/src/SparkResourceAdaptorJni.cpp @@ -681,22 +681,29 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { auto const found = threads.find(thread_id); if (found != threads.end()) { if (found->second->task_id >= 0 && found->second->task_id != task_id) { - LOG_STATUS("FIXUP", thread_id, found->second->task_id, found->second->state, - "desired task_id {}", task_id); + LOG_STATUS("FIXUP", + thread_id, + found->second->task_id, + found->second->state, + "desired task_id {}", + task_id); remove_thread_association(thread_id, found->second->task_id, lock); } } auto const was_threads_inserted = threads.emplace( - thread_id, std::make_shared(thread_state::THREAD_RUNNING, thread_id, task_id)); + thread_id, + std::make_shared(thread_state::THREAD_RUNNING, thread_id, task_id)); if (was_threads_inserted.second == false) { if (was_threads_inserted.first->second->state == thread_state::THREAD_REMOVE_THROW) { std::stringstream ss; - ss << "A thread " << thread_id << " is shutting down " + ss << "A thread " << thread_id << " is shutting down " << was_threads_inserted.first->second->task_id << " vs " << task_id; auto const msg = ss.str(); - LOG_STATUS("ERROR", - thread_id, was_threads_inserted.first->second->task_id, was_threads_inserted.first->second->state, - msg); + LOG_STATUS("ERROR", + thread_id, + was_threads_inserted.first->second->task_id, + was_threads_inserted.first->second->state, + msg); throw std::invalid_argument(msg); } @@ -768,8 +775,8 @@ class spark_resource_adaptor final : public 
rmm::mr::device_memory_resource { std::unique_lock lock(state_mutex); if (shutting_down) { throw std::runtime_error("spark_resource_adaptor is shutting down"); } - auto const was_inserted = - threads.emplace(thread_id, std::make_shared(thread_state::THREAD_RUNNING, thread_id)); + auto const was_inserted = threads.emplace( + thread_id, std::make_shared(thread_state::THREAD_RUNNING, thread_id)); if (was_inserted.second == true) { was_inserted.first->second->is_for_shuffle = is_for_shuffle; LOG_TRANSITION(thread_id, -1, thread_state::UNKNOWN, thread_state::THREAD_RUNNING); @@ -793,7 +800,12 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { checkpoint_metrics(was_inserted.first->second); was_inserted.first->second->pool_task_ids.insert(task_ids.begin(), task_ids.end()); - LOG_STATUS_CONTAINER("ADD_TASKS", thread_id, -1, was_inserted.first->second->state, "CURRENT IDs", was_inserted.first->second->pool_task_ids); + LOG_STATUS_CONTAINER("ADD_TASKS", + thread_id, + -1, + was_inserted.first->second->state, + "CURRENT IDs", + was_inserted.first->second->pool_task_ids); } void pool_thread_finished_for_tasks(long const thread_id, @@ -811,7 +823,12 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { for (auto const& id : task_ids) { thread->second->pool_task_ids.erase(id); } - LOG_STATUS_CONTAINER("REMOVE_TASKS", thread_id, -1, thread->second->state, "CURRENT IDs", thread->second->pool_task_ids); + LOG_STATUS_CONTAINER("REMOVE_TASKS", + thread_id, + -1, + thread->second->state, + "CURRENT IDs", + thread->second->pool_task_ids); if (thread->second->pool_task_ids.empty()) { if (remove_thread_association(thread_id, -1, lock)) { wake_up_threads_after_task_finishes(lock); @@ -861,7 +878,12 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { auto const thread = threads.find(thread_id); if (thread != threads.end()) { if (thread->second->pool_task_ids.erase(task_id) != 0) { - 
LOG_STATUS_CONTAINER("REMOVE_TASKS", thread_id, -1, thread->second->state, "CURRENT IDs", thread->second->pool_task_ids); + LOG_STATUS_CONTAINER("REMOVE_TASKS", + thread_id, + -1, + thread->second->state, + "CURRENT IDs", + thread->second->pool_task_ids); if (thread->second->pool_task_ids.empty()) { run_checks = remove_thread_association(thread_id, task_id, lock) || run_checks; } @@ -1210,8 +1232,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { * of setting the state directly. This will log the transition and do a little bit of * verification. */ - void transition(std::shared_ptr state, - thread_state const new_state) + void transition(std::shared_ptr state, thread_state const new_state) { thread_state original = state->state; state->transition_to(new_state); @@ -1639,7 +1660,11 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { // The allocation succeeded so we are no longer doing a retry if (thread->second->is_retry_alloc_before_bufn) { thread->second->is_retry_alloc_before_bufn = false; - LOG_STATUS("DETAIL", thread_id, thread->second->task_id, thread->second->state, + LOG_STATUS( + "DETAIL", + thread_id, + thread->second->task_id, + thread->second->state, "thread (id: {}) is_retry_alloc_before_bufn set to false in post_alloc_success_core", thread_id); } @@ -1666,8 +1691,8 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { thread->second->metrics.gpu_memory_max_footprint); } gpu_memory_allocated_bytes += num_bytes; - thread->second->metrics.gpu_max_memory_allocated = - std::max(thread->second->metrics.gpu_max_memory_allocated, gpu_memory_allocated_bytes); + thread->second->metrics.gpu_max_memory_allocated = std::max( + thread->second->metrics.gpu_max_memory_allocated, gpu_memory_allocated_bytes); } break; default: break; @@ -1993,8 +2018,12 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { // so if data was made spillable we will retry the 
// allocation, instead of going to BUFN. thread->second->is_retry_alloc_before_bufn = true; - LOG_STATUS("DETAIL", thread_id_to_bufn, thread->second->task_id, thread->second->state, - "thread (id: {}) is_retry_alloc_before_bufn set to true", thread_id_to_bufn); + LOG_STATUS("DETAIL", + thread_id_to_bufn, + thread->second->task_id, + thread->second->state, + "thread (id: {}) is_retry_alloc_before_bufn set to true", + thread_id_to_bufn); transition(thread->second, thread_state::THREAD_RUNNING); } else { log_all_threads_states(); @@ -2088,7 +2117,11 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { if (is_oom && thread->second->is_retry_alloc_before_bufn) { if (thread->second->is_retry_alloc_before_bufn) { thread->second->is_retry_alloc_before_bufn = false; - LOG_STATUS("DETAIL", thread_id, thread->second->task_id, thread->second->state, + LOG_STATUS( + "DETAIL", + thread_id, + thread->second->task_id, + thread->second->state, "thread (id: {}) is_retry_alloc_before_bufn set to false in post_alloc_failed_core", thread_id); } @@ -2097,7 +2130,11 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { } else if (is_oom && blocking) { if (thread->second->is_retry_alloc_before_bufn) { thread->second->is_retry_alloc_before_bufn = false; - LOG_STATUS("DETAIL", thread_id, thread->second->task_id, thread->second->state, + LOG_STATUS( + "DETAIL", + thread_id, + thread->second->task_id, + thread->second->state, "thread (id: {}) is_retry_alloc_before_bufn set to false in post_alloc_failed_core", thread_id); } From 4920e46a81cd4ca750691a2f93d3a3f347f61971 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Tue, 18 Nov 2025 11:01:07 -0800 Subject: [PATCH 4/6] Fix a pre-existing bug for a corner case --- src/main/cpp/src/SparkResourceAdaptorJni.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/main/cpp/src/SparkResourceAdaptorJni.cpp b/src/main/cpp/src/SparkResourceAdaptorJni.cpp index 
f526574d58..40b86ce359 100644 --- a/src/main/cpp/src/SparkResourceAdaptorJni.cpp +++ b/src/main/cpp/src/SparkResourceAdaptorJni.cpp @@ -1361,7 +1361,9 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { thread->second->wake_condition->wait(lock); thread = threads.find(thread_id); } while (thread != threads.end() && is_blocked(thread->second->state)); - thread->second->after_block(); + if (thread != threads.end()) { + thread->second->after_block(); + } task_has_woken_condition.notify_all(); break; case thread_state::THREAD_BUFN_THROW: @@ -1384,7 +1386,9 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { thread->second->wake_condition->wait(lock); thread = threads.find(thread_id); } while (thread != threads.end() && is_blocked(thread->second->state)); - thread->second->after_block(); + if (thread != threads.end()) { + thread->second->after_block(); + } task_has_woken_condition.notify_all(); } break; From b46da69819f467da623365f27efcdec489829146 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Tue, 18 Nov 2025 11:05:41 -0800 Subject: [PATCH 5/6] fix styles --- src/main/cpp/src/SparkResourceAdaptorJni.cpp | 8 ++------ thirdparty/cudf | 2 +- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/main/cpp/src/SparkResourceAdaptorJni.cpp b/src/main/cpp/src/SparkResourceAdaptorJni.cpp index 40b86ce359..e7f70d2b43 100644 --- a/src/main/cpp/src/SparkResourceAdaptorJni.cpp +++ b/src/main/cpp/src/SparkResourceAdaptorJni.cpp @@ -1361,9 +1361,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { thread->second->wake_condition->wait(lock); thread = threads.find(thread_id); } while (thread != threads.end() && is_blocked(thread->second->state)); - if (thread != threads.end()) { - thread->second->after_block(); - } + if (thread != threads.end()) { thread->second->after_block(); } task_has_woken_condition.notify_all(); break; case thread_state::THREAD_BUFN_THROW: @@ -1386,9 
+1384,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { thread->second->wake_condition->wait(lock); thread = threads.find(thread_id); } while (thread != threads.end() && is_blocked(thread->second->state)); - if (thread != threads.end()) { - thread->second->after_block(); - } + if (thread != threads.end()) { thread->second->after_block(); } task_has_woken_condition.notify_all(); } break; diff --git a/thirdparty/cudf b/thirdparty/cudf index bd2af5e9fd..d27e318977 160000 --- a/thirdparty/cudf +++ b/thirdparty/cudf @@ -1 +1 @@ -Subproject commit bd2af5e9fd73cb861caffe482069fd714e5cd13a +Subproject commit d27e318977248edc3c0a479f21d7e3cb7ec0c92b From 0685885c3a8f99121d3d1faed17fea4e043a4906 Mon Sep 17 00:00:00 2001 From: Alessandro Bellina Date: Tue, 18 Nov 2025 11:39:15 -0800 Subject: [PATCH 6/6] Resolve thirdparty/cudf ref --- thirdparty/cudf | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thirdparty/cudf b/thirdparty/cudf index d27e318977..3745c9f133 160000 --- a/thirdparty/cudf +++ b/thirdparty/cudf @@ -1 +1 @@ -Subproject commit d27e318977248edc3c0a479f21d7e3cb7ec0c92b +Subproject commit 3745c9f133eb3d204810090c05a5701643028c1b