diff --git a/src/main/cpp/src/SparkResourceAdaptorJni.cpp b/src/main/cpp/src/SparkResourceAdaptorJni.cpp index 63a2c43807..e7f70d2b43 100644 --- a/src/main/cpp/src/SparkResourceAdaptorJni.cpp +++ b/src/main/cpp/src/SparkResourceAdaptorJni.cpp @@ -680,41 +680,42 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { if (shutting_down) { throw std::runtime_error("spark_resource_adaptor is shutting down"); } auto const found = threads.find(thread_id); if (found != threads.end()) { - if (found->second.task_id >= 0 && found->second.task_id != task_id) { + if (found->second->task_id >= 0 && found->second->task_id != task_id) { LOG_STATUS("FIXUP", thread_id, - found->second.task_id, - found->second.state, + found->second->task_id, + found->second->state, "desired task_id {}", task_id); - remove_thread_association(thread_id, found->second.task_id, lock); + remove_thread_association(thread_id, found->second->task_id, lock); } } auto const was_threads_inserted = threads.emplace( - thread_id, full_thread_state(thread_state::THREAD_RUNNING, thread_id, task_id)); + thread_id, + std::make_shared<full_thread_state>(thread_state::THREAD_RUNNING, thread_id, task_id)); if (was_threads_inserted.second == false) { - if (was_threads_inserted.first->second.state == thread_state::THREAD_REMOVE_THROW) { + if (was_threads_inserted.first->second->state == thread_state::THREAD_REMOVE_THROW) { std::stringstream ss; ss << "A thread " << thread_id << " is shutting down " - << was_threads_inserted.first->second.task_id << " vs " << task_id; + << was_threads_inserted.first->second->task_id << " vs " << task_id; auto const msg = ss.str(); LOG_STATUS("ERROR", thread_id, - was_threads_inserted.first->second.task_id, - was_threads_inserted.first->second.state, + was_threads_inserted.first->second->task_id, + was_threads_inserted.first->second->state, msg); throw std::invalid_argument(msg); } - if (was_threads_inserted.first->second.task_id != task_id) { + if 
(was_threads_inserted.first->second->task_id != task_id) { std::stringstream ss; ss << "A thread " << thread_id << " can only be dedicated to a single task." - << was_threads_inserted.first->second.task_id << " != " << task_id; + << was_threads_inserted.first->second->task_id << " != " << task_id; auto const msg = ss.str(); LOG_STATUS("ERROR", thread_id, - was_threads_inserted.first->second.task_id, - was_threads_inserted.first->second.state, + was_threads_inserted.first->second->task_id, + was_threads_inserted.first->second->state, msg); throw std::invalid_argument(msg); } @@ -742,21 +743,21 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { { std::unique_lock lock(state_mutex); auto const thread = threads.find(thread_id); - if (thread != threads.end()) { thread->second.reset_retry_state(true); } + if (thread != threads.end()) { thread->second->reset_retry_state(true); } } void end_retry_block(long const thread_id) { std::unique_lock lock(state_mutex); auto const thread = threads.find(thread_id); - if (thread != threads.end()) { thread->second.reset_retry_state(false); } + if (thread != threads.end()) { thread->second->reset_retry_state(false); } } bool is_working_on_task_as_pool_thread(long const thread_id) { std::unique_lock lock(state_mutex); auto const thread = threads.find(thread_id); - if (thread != threads.end()) { return !thread->second.pool_task_ids.empty(); } + if (thread != threads.end()) { return !thread->second->pool_task_ids.empty(); } return false; } @@ -774,16 +775,16 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { std::unique_lock lock(state_mutex); if (shutting_down) { throw std::runtime_error("spark_resource_adaptor is shutting down"); } - auto const was_inserted = - threads.emplace(thread_id, full_thread_state(thread_state::THREAD_RUNNING, thread_id)); + auto const was_inserted = threads.emplace( + thread_id, std::make_shared<full_thread_state>(thread_state::THREAD_RUNNING, thread_id)); if 
(was_inserted.second == true) { - was_inserted.first->second.is_for_shuffle = is_for_shuffle; + was_inserted.first->second->is_for_shuffle = is_for_shuffle; LOG_TRANSITION(thread_id, -1, thread_state::UNKNOWN, thread_state::THREAD_RUNNING); - } else if (was_inserted.first->second.task_id != -1) { + } else if (was_inserted.first->second->task_id != -1) { throw std::invalid_argument("the thread is associated with a non-pool task already"); - } else if (was_inserted.first->second.state == thread_state::THREAD_REMOVE_THROW) { + } else if (was_inserted.first->second->state == thread_state::THREAD_REMOVE_THROW) { throw std::invalid_argument("the thread is in the process of shutting down."); - } else if (was_inserted.first->second.is_for_shuffle != is_for_shuffle) { + } else if (was_inserted.first->second->is_for_shuffle != is_for_shuffle) { if (is_for_shuffle) { throw std::invalid_argument( "the thread is marked as a non-shuffle thread, and we cannot change it while there are " @@ -798,13 +799,13 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { // save the metrics for all tasks before we add any new ones. 
checkpoint_metrics(was_inserted.first->second); - was_inserted.first->second.pool_task_ids.insert(task_ids.begin(), task_ids.end()); + was_inserted.first->second->pool_task_ids.insert(task_ids.begin(), task_ids.end()); LOG_STATUS_CONTAINER("ADD_TASKS", thread_id, -1, - was_inserted.first->second.state, + was_inserted.first->second->state, "CURRENT IDs", - was_inserted.first->second.pool_task_ids); + was_inserted.first->second->pool_task_ids); } void pool_thread_finished_for_tasks(long const thread_id, @@ -820,15 +821,15 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { // Now drop the tasks from the pool for (auto const& id : task_ids) { - thread->second.pool_task_ids.erase(id); + thread->second->pool_task_ids.erase(id); } LOG_STATUS_CONTAINER("REMOVE_TASKS", thread_id, -1, - thread->second.state, + thread->second->state, "CURRENT IDs", - thread->second.pool_task_ids); - if (thread->second.pool_task_ids.empty()) { + thread->second->pool_task_ids); + if (thread->second->pool_task_ids.empty()) { if (remove_thread_association(thread_id, -1, lock)) { wake_up_threads_after_task_finishes(lock); } @@ -876,14 +877,14 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { for (auto const& thread_id : thread_ids) { auto const thread = threads.find(thread_id); if (thread != threads.end()) { - if (thread->second.pool_task_ids.erase(task_id) != 0) { + if (thread->second->pool_task_ids.erase(task_id) != 0) { LOG_STATUS_CONTAINER("REMOVE_TASKS", thread_id, -1, - thread->second.state, + thread->second->state, "CURRENT IDs", - thread->second.pool_task_ids); - if (thread->second.pool_task_ids.empty()) { + thread->second->pool_task_ids); + if (thread->second->pool_task_ids.empty()) { run_checks = remove_thread_association(thread_id, task_id, lock) || run_checks; } } @@ -964,7 +965,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { std::unique_lock lock(state_mutex); auto const threads_at = 
threads.find(thread_id); if (threads_at != threads.end()) { - threads_at->second.retry_oom.init(num_ooms, skip_count, oom_filter); + threads_at->second->retry_oom.init(num_ooms, skip_count, oom_filter); } else { throw std::invalid_argument("the thread is not associated with any task/shuffle"); } @@ -982,7 +983,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { std::unique_lock lock(state_mutex); auto const threads_at = threads.find(thread_id); if (threads_at != threads.end()) { - threads_at->second.split_and_retry_oom.init(num_ooms, skip_count, oom_filter); + threads_at->second->split_and_retry_oom.init(num_ooms, skip_count, oom_filter); } else { throw std::invalid_argument("the thread is not associated with any task/shuffle"); } @@ -997,7 +998,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { std::unique_lock lock(state_mutex); auto const threads_at = threads.find(thread_id); if (threads_at != threads.end()) { - threads_at->second.cudf_exception_injected = num_times; + threads_at->second->cudf_exception_injected = num_times; } else { throw std::invalid_argument("the thread is not associated with any task/shuffle"); } @@ -1018,8 +1019,8 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { for (auto const thread_id : task_at->second) { auto const threads_at = threads.find(thread_id); if (threads_at != threads.end()) { - ret += (threads_at->second.metrics.*MetricPtr); - (threads_at->second.metrics.*MetricPtr) = 0; + ret += (threads_at->second->metrics.*MetricPtr); + (threads_at->second->metrics.*MetricPtr) = 0; } } } @@ -1041,7 +1042,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { if (task_at != task_to_threads.end()) { for (auto const thread_id : task_at->second) { auto const threads_at = threads.find(thread_id); - if (threads_at != threads.end()) { ret += (threads_at->second.metrics.*MetricPtr); } + if (threads_at != threads.end()) { ret 
+= (threads_at->second->metrics.*MetricPtr); } } } @@ -1109,8 +1110,8 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { for (auto const thread_id : task_at->second) { auto const threads_at = threads.find(thread_id); if (threads_at != threads.end()) { - ret += threads_at->second.currently_blocked_for(); - ret += threads_at->second.metrics.time_lost_or_blocked; + ret += threads_at->second->currently_blocked_for(); + ret += threads_at->second->metrics.time_lost_or_blocked; } } } @@ -1165,7 +1166,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { std::unique_lock lock(state_mutex); auto const tid = static_cast<long>(pthread_self()); auto const thread = threads.find(tid); - if (thread != threads.end()) { thread->second.is_in_spilling = true; } + if (thread != threads.end()) { thread->second->is_in_spilling = true; } } void spill_range_done() @@ -1173,7 +1174,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { std::unique_lock lock(state_mutex); auto const tid = static_cast<long>(pthread_self()); auto const thread = threads.find(tid); - if (thread != threads.end()) { thread->second.is_in_spilling = false; } + if (thread != threads.end()) { thread->second->is_in_spilling = false; } } /** @@ -1198,7 +1199,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { std::unique_lock lock(state_mutex); auto const threads_at = threads.find(thread_id); if (threads_at != threads.end()) { - return static_cast<int>(threads_at->second.state); + return static_cast<int>(threads_at->second->state); } else { return -1; } @@ -1212,7 +1213,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { // from an operation. 
std::mutex state_mutex; std::condition_variable task_has_woken_condition; - std::map<long, full_thread_state> threads; + std::map<long, std::shared_ptr<full_thread_state>> threads; std::map<long, std::set<long>> task_to_threads; long gpu_memory_allocated_bytes = 0; @@ -1231,11 +1232,11 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { * of setting the state directly. This will log the transition and do a little bit of * verification. */ - void transition(full_thread_state& state, thread_state const new_state) + void transition(std::shared_ptr<full_thread_state> state, thread_state const new_state) { - thread_state original = state.state; - state.transition_to(new_state); - LOG_TRANSITION(state.thread_id, state.task_id, original, new_state); + thread_state original = state->state; + state->transition_to(new_state); + LOG_TRANSITION(state->thread_id, state->task_id, original, new_state); } /** @@ -1252,7 +1253,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { std::unique_lock lock(state_mutex); auto const thread = threads.find(thread_id); long task_id = -1; - if (thread != threads.end()) { task_id = thread->second.task_id; } + if (thread != threads.end()) { task_id = thread->second->task_id; } if (task_id < 0) { std::stringstream ss; @@ -1260,24 +1261,24 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { throw std::invalid_argument(ss.str()); } - thread->second.pool_blocked = pool_blocked; + thread->second->pool_blocked = pool_blocked; } /** * Checkpoint all of the metrics for a thread. */ - void checkpoint_metrics(full_thread_state& state) + void checkpoint_metrics(std::shared_ptr<full_thread_state> state) { - if (state.task_id < 0) { + if (state->task_id < 0) { // save the metrics for all tasks before we add any new ones. 
- for (auto const task_id : state.pool_task_ids) { + for (auto const task_id : state->pool_task_ids) { auto const metrics_at = task_to_metrics.try_emplace(task_id, task_metrics()); - metrics_at.first->second.add(state.metrics); + metrics_at.first->second.add(state->metrics); } - state.metrics.clear(); + state->metrics.clear(); } else { - auto const metrics_at = task_to_metrics.try_emplace(state.task_id, task_metrics()); - metrics_at.first->second.take_from(state.metrics); + auto const metrics_at = task_to_metrics.try_emplace(state->task_id, task_metrics()); + metrics_at.first->second.take_from(state->metrics); } } @@ -1350,23 +1351,23 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { while (!done) { auto thread = threads.find(thread_id); if (thread != threads.end()) { - switch (thread->second.state) { + switch (thread->second->state) { case thread_state::THREAD_BLOCKED: // fall through case thread_state::THREAD_BUFN: - LOG_STATUS("WAITING", thread_id, thread->second.task_id, thread->second.state); - thread->second.before_block(); + LOG_STATUS("WAITING", thread_id, thread->second->task_id, thread->second->state); + thread->second->before_block(); do { - thread->second.wake_condition->wait(lock); + thread->second->wake_condition->wait(lock); thread = threads.find(thread_id); - } while (thread != threads.end() && is_blocked(thread->second.state)); - thread->second.after_block(); + } while (thread != threads.end() && is_blocked(thread->second->state)); + if (thread != threads.end()) { thread->second->after_block(); } task_has_woken_condition.notify_all(); break; case thread_state::THREAD_BUFN_THROW: transition(thread->second, thread_state::THREAD_BUFN_WAIT); - thread->second.record_failed_retry_time(); - throw_retry_oom("rollback and retry operation", thread->second, lock); + thread->second->record_failed_retry_time(); + throw_retry_oom("rollback and retry operation", *thread->second, lock); break; case thread_state::THREAD_BUFN_WAIT: 
transition(thread->second, thread_state::THREAD_BUFN); @@ -1376,33 +1377,33 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { check_and_update_for_bufn(lock); // If that caused us to transition to a new state, then we need to adjust to it // appropriately... - if (is_blocked(thread->second.state)) { - LOG_STATUS("WAITING", thread_id, thread->second.task_id, thread->second.state); - thread->second.before_block(); + if (is_blocked(thread->second->state)) { + LOG_STATUS("WAITING", thread_id, thread->second->task_id, thread->second->state); + thread->second->before_block(); do { - thread->second.wake_condition->wait(lock); + thread->second->wake_condition->wait(lock); thread = threads.find(thread_id); - } while (thread != threads.end() && is_blocked(thread->second.state)); - thread->second.after_block(); + } while (thread != threads.end() && is_blocked(thread->second->state)); + if (thread != threads.end()) { thread->second->after_block(); } task_has_woken_condition.notify_all(); } break; case thread_state::THREAD_SPLIT_THROW: transition(thread->second, thread_state::THREAD_RUNNING); - thread->second.record_failed_retry_time(); + thread->second->record_failed_retry_time(); throw_split_and_retry_oom( - "rollback, split input, and retry operation", thread->second, lock); + "rollback, split input, and retry operation", *thread->second, lock); break; case thread_state::THREAD_REMOVE_THROW: LOG_TRANSITION( - thread_id, thread->second.task_id, thread->second.state, thread_state::UNKNOWN); + thread_id, thread->second->task_id, thread->second->state, thread_state::UNKNOWN); // don't need to record failed time metric the thread is already gone... 
threads.erase(thread); task_has_woken_condition.notify_all(); throw std::runtime_error("thread removed while blocked"); default: if (!first_time) { - LOG_STATUS("DONE WAITING", thread_id, thread->second.task_id, thread->second.state); + LOG_STATUS("DONE WAITING", thread_id, thread->second->task_id, thread->second->state); } done = true; } @@ -1425,10 +1426,10 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { { bool are_any_tasks_just_blocked = false; for (auto& [thread_id, t_state] : threads) { - switch (t_state.state) { + switch (t_state->state) { case thread_state::THREAD_BLOCKED: transition(t_state, thread_state::THREAD_RUNNING); - t_state.wake_condition->notify_all(); + t_state->wake_condition->notify_all(); are_any_tasks_just_blocked = true; break; default: break; @@ -1438,14 +1439,14 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { if (!are_any_tasks_just_blocked) { // wake up all of the BUFN tasks. for (auto& [thread_id, t_state] : threads) { - switch (t_state.state) { + switch (t_state->state) { case thread_state::THREAD_BUFN: // fall through case thread_state::THREAD_BUFN_THROW: // fall through case thread_state::THREAD_BUFN_WAIT: transition(t_state, thread_state::THREAD_RUNNING); - t_state.wake_condition->notify_all(); + t_state->wake_condition->notify_all(); break; default: break; } @@ -1472,12 +1473,12 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { if (remove_task_id < 0) { thread_should_be_removed = true; } else { - auto const task_id = threads_at->second.task_id; + auto const task_id = threads_at->second->task_id; if (task_id >= 0) { if (task_id == remove_task_id) { thread_should_be_removed = true; } } else { - threads_at->second.pool_task_ids.erase(remove_task_id); - if (threads_at->second.pool_task_ids.empty()) { thread_should_be_removed = true; } + threads_at->second->pool_task_ids.erase(remove_task_id); + if (threads_at->second->pool_task_ids.empty()) 
{ thread_should_be_removed = true; } } } @@ -1492,20 +1493,20 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { if (task_at != task_to_threads.end()) { task_at->second.erase(thread_id); } } - switch (threads_at->second.state) { + switch (threads_at->second->state) { case thread_state::THREAD_BLOCKED: // fall through case thread_state::THREAD_BUFN: transition(threads_at->second, thread_state::THREAD_REMOVE_THROW); - threads_at->second.wake_condition->notify_all(); + threads_at->second->wake_condition->notify_all(); break; case thread_state::THREAD_RUNNING: ret = true; // fall through; default: LOG_TRANSITION(thread_id, - threads_at->second.task_id, - threads_at->second.state, + threads_at->second->task_id, + threads_at->second->state, thread_state::UNKNOWN); threads.erase(threads_at); } @@ -1546,7 +1547,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { { auto const thread = threads.find(thread_id); if (thread != threads.end()) { - switch (thread->second.state) { + switch (thread->second->state) { // If the thread is in one of the ALLOC or ALLOC_FREE states, we have detected a loop // likely due to spill setup required in cuDF. We will treat this allocation differently // and skip transitions. 
@@ -1559,7 +1560,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { std::stringstream ss; ss << "thread " << thread_id << " is trying to do a blocking allocate while already in the state " - << as_str(thread->second.state); + << as_str(thread->second->state); throw std::invalid_argument(ss.str()); } @@ -1568,39 +1569,39 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { default: break; } - if (thread->second.retry_oom.matches(is_for_cpu)) { - if (thread->second.retry_oom.skip_count > 0) { - thread->second.retry_oom.skip_count--; - } else if (thread->second.retry_oom.hit_count > 0) { - thread->second.retry_oom.hit_count--; - thread->second.metrics.num_times_retry_throw++; + if (thread->second->retry_oom.matches(is_for_cpu)) { + if (thread->second->retry_oom.skip_count > 0) { + thread->second->retry_oom.skip_count--; + } else if (thread->second->retry_oom.hit_count > 0) { + thread->second->retry_oom.hit_count--; + thread->second->metrics.num_times_retry_throw++; std::string const op_prefix = "INJECTED_RETRY_OOM_"; std::string const op = op_prefix + (is_for_cpu ? "CPU" : "GPU"); - LOG_STATUS(op, thread_id, thread->second.task_id, thread->second.state); - thread->second.record_failed_retry_time(); + LOG_STATUS(op, thread_id, thread->second->task_id, thread->second->state); + thread->second->record_failed_retry_time(); throw_java_exception(is_for_cpu ? 
CPU_RETRY_OOM_CLASS : GPU_RETRY_OOM_CLASS, "injected RetryOOM"); } } - if (thread->second.cudf_exception_injected > 0) { - thread->second.cudf_exception_injected--; + if (thread->second->cudf_exception_injected > 0) { + thread->second->cudf_exception_injected--; LOG_STATUS( - "INJECTED_CUDF_EXCEPTION", thread_id, thread->second.task_id, thread->second.state); - thread->second.record_failed_retry_time(); + "INJECTED_CUDF_EXCEPTION", thread_id, thread->second->task_id, thread->second->state); + thread->second->record_failed_retry_time(); throw_java_exception(cudf::jni::CUDF_EXCEPTION_CLASS, "injected CudfException"); } - if (thread->second.split_and_retry_oom.matches(is_for_cpu)) { - if (thread->second.split_and_retry_oom.skip_count > 0) { - thread->second.split_and_retry_oom.skip_count--; - } else if (thread->second.split_and_retry_oom.hit_count > 0) { - thread->second.split_and_retry_oom.hit_count--; - thread->second.metrics.num_times_split_retry_throw++; + if (thread->second->split_and_retry_oom.matches(is_for_cpu)) { + if (thread->second->split_and_retry_oom.skip_count > 0) { + thread->second->split_and_retry_oom.skip_count--; + } else if (thread->second->split_and_retry_oom.hit_count > 0) { + thread->second->split_and_retry_oom.hit_count--; + thread->second->metrics.num_times_split_retry_throw++; std::string const op_prefix = "INJECTED_SPLIT_AND_RETRY_OOM_"; std::string const op = op_prefix + (is_for_cpu ? 
"CPU" : "GPU"); - LOG_STATUS(op, thread_id, thread->second.task_id, thread->second.state); - thread->second.record_failed_retry_time(); + LOG_STATUS(op, thread_id, thread->second->task_id, thread->second->state); + thread->second->record_failed_retry_time(); if (is_for_cpu) { throw_java_exception(CPU_SPLIT_AND_RETRY_OOM_CLASS, "injected SplitAndRetryOOM"); } else { @@ -1611,15 +1612,15 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { if (blocking) { block_thread_until_ready(thread_id, lock); } - switch (thread->second.state) { + switch (thread->second->state) { case thread_state::THREAD_RUNNING: transition(thread->second, thread_state::THREAD_ALLOC); - thread->second.is_cpu_alloc = is_for_cpu; + thread->second->is_cpu_alloc = is_for_cpu; break; default: { std::stringstream ss; ss << "thread " << thread_id << " in unexpected state pre alloc " - << as_str(thread->second.state); + << as_str(thread->second->state); throw std::invalid_argument(ss.str()); } @@ -1657,41 +1658,41 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { auto const thread = threads.find(thread_id); if (!was_recursive && thread != threads.end()) { // The allocation succeeded so we are no longer doing a retry - if (thread->second.is_retry_alloc_before_bufn) { - thread->second.is_retry_alloc_before_bufn = false; + if (thread->second->is_retry_alloc_before_bufn) { + thread->second->is_retry_alloc_before_bufn = false; LOG_STATUS( "DETAIL", thread_id, - thread->second.task_id, - thread->second.state, + thread->second->task_id, + thread->second->state, "thread (id: {}) is_retry_alloc_before_bufn set to false in post_alloc_success_core", thread_id); } - switch (thread->second.state) { + switch (thread->second->state) { case thread_state::THREAD_ALLOC: // fall through case thread_state::THREAD_ALLOC_FREE: - if (thread->second.is_cpu_alloc != is_for_cpu) { + if (thread->second->is_cpu_alloc != is_for_cpu) { std::stringstream ss; ss << "thread " 
<< thread_id << " has a mismatch on CPU vs GPU post alloc " - << as_str(thread->second.state); + << as_str(thread->second->state); throw std::invalid_argument(ss.str()); } transition(thread->second, thread_state::THREAD_RUNNING); - thread->second.is_cpu_alloc = false; + thread->second->is_cpu_alloc = false; // num_bytes is likely not padded, which could cause slight inaccuracies // but for now it shouldn't matter for watermark purposes if (!is_for_cpu) { - if (!thread->second.is_in_spilling) { - thread->second.metrics.gpu_memory_active_footprint += num_bytes; - thread->second.metrics.gpu_memory_max_footprint = - std::max(thread->second.metrics.gpu_memory_active_footprint, - thread->second.metrics.gpu_memory_max_footprint); + if (!thread->second->is_in_spilling) { + thread->second->metrics.gpu_memory_active_footprint += num_bytes; + thread->second->metrics.gpu_memory_max_footprint = + std::max(thread->second->metrics.gpu_memory_active_footprint, + thread->second->metrics.gpu_memory_max_footprint); } gpu_memory_allocated_bytes += num_bytes; - thread->second.metrics.gpu_max_memory_allocated = - std::max(thread->second.metrics.gpu_max_memory_allocated, gpu_memory_allocated_bytes); + thread->second->metrics.gpu_max_memory_allocated = std::max( + thread->second->metrics.gpu_max_memory_allocated, gpu_memory_allocated_bytes); } break; default: break; @@ -1717,9 +1718,9 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { thread_priority to_wake(-1, -1); bool is_to_wake_set = false; for (auto const& [thread_d, t_state] : threads) { - thread_state const& state = t_state.state; - if (state == thread_state::THREAD_BLOCKED && is_for_cpu == t_state.is_cpu_alloc) { - thread_priority current = t_state.priority(); + thread_state const& state = t_state->state; + if (state == thread_state::THREAD_BLOCKED && is_for_cpu == t_state->is_cpu_alloc) { + thread_priority current = t_state->priority(); if (!is_to_wake_set || to_wake < current) { to_wake = current; 
is_to_wake_set = true; @@ -1731,15 +1732,15 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { if (thread_id_to_wake > 0) { auto const thread = threads.find(thread_id_to_wake); if (thread != threads.end()) { - switch (thread->second.state) { + switch (thread->second->state) { case thread_state::THREAD_BLOCKED: transition(thread->second, thread_state::THREAD_RUNNING); - thread->second.wake_condition->notify_all(); + thread->second->wake_condition->notify_all(); break; default: { std::stringstream ss; ss << "internal error expected to only wake up blocked threads " << thread_id_to_wake - << " " << as_str(thread->second.state); + << " " << as_str(thread->second->state); throw std::runtime_error(ss.str()); } } @@ -1765,10 +1766,10 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { thread_priority to_wake(-1, -1); bool is_to_wake_set = false; for (auto const& [thread_id, t_state] : threads) { - switch (t_state.state) { + switch (t_state->state) { case thread_state::THREAD_BUFN: { - if (is_for_cpu == t_state.is_cpu_alloc) { - thread_priority current = t_state.priority(); + if (is_for_cpu == t_state->is_cpu_alloc) { + thread_priority current = t_state->priority(); if (!is_to_wake_set || to_wake < current) { to_wake = current; is_to_wake_set = true; @@ -1787,10 +1788,10 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { auto const this_id = static_cast(pthread_self()); auto const thread = threads.find(thread_id_to_wake); if (thread != threads.end() && thread->first != this_id) { - switch (thread->second.state) { + switch (thread->second->state) { case thread_state::THREAD_BUFN: transition(thread->second, thread_state::THREAD_RUNNING); - thread->second.wake_condition->notify_all(); + thread->second->wake_condition->notify_all(); break; case thread_state::THREAD_BUFN_WAIT: transition(thread->second, thread_state::THREAD_RUNNING); @@ -1804,7 +1805,7 @@ class spark_resource_adaptor final 
: public rmm::mr::device_memory_resource { default: { std::stringstream ss; ss << "internal error expected to only wake up blocked threads " - << thread_id_to_wake << " " << as_str(thread->second.state); + << thread_id_to_wake << " " << as_str(thread->second->state); throw std::runtime_error(ss.str()); } } @@ -1815,13 +1816,13 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { } } - bool is_thread_bufn_or_above(JNIEnv* env, full_thread_state const& state) + bool is_thread_bufn_or_above(JNIEnv* env, std::shared_ptr<full_thread_state> state) { bool ret = false; - if (state.pool_blocked) { + if (state->pool_blocked) { ret = true; } else { - switch (state.state) { + switch (state->state) { case thread_state::THREAD_BLOCKED: ret = false; break; case thread_state::THREAD_BUFN: // empty we are looking for even a single thread that is not blocked @@ -1829,7 +1830,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { break; default: ret = env->CallStaticBooleanMethod( - ThreadStateRegistry_jclass, isThreadBlocked_method, state.thread_id); + ThreadStateRegistry_jclass, isThreadBlocked_method, state->thread_id); break; } } @@ -1894,12 +1895,12 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { // We are going to do two passes through the threads to deal with this. 
// First pass is to look at the dedicated task threads for (auto const& [thread_id, t_state] : threads) { - long const task_id = t_state.task_id; + long const task_id = t_state->task_id; if (task_id >= 0) { all_task_ids.insert(task_id); bool const is_bufn_plus = is_thread_bufn_or_above(env, t_state); if (is_bufn_plus) { bufn_task_ids.insert(task_id); } - if (is_bufn_plus || t_state.state == thread_state::THREAD_BLOCKED) { + if (is_bufn_plus || t_state->state == thread_state::THREAD_BLOCKED) { blocked_task_ids.insert(task_id); } } @@ -1907,9 +1908,9 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { // Second pass is to look at the pool threads for (auto const& [thread_id, t_state] : threads) { - long const is_pool_thread = t_state.task_id < 0; + long const is_pool_thread = t_state->task_id < 0; if (is_pool_thread) { - for (auto const& task_id : t_state.pool_task_ids) { + for (auto const& task_id : t_state->pool_task_ids) { auto const it = pool_task_thread_count.find(task_id); if (it != pool_task_thread_count.end()) { it->second += 1; @@ -1920,7 +1921,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { bool const is_bufn_plus = is_thread_bufn_or_above(env, t_state); if (is_bufn_plus) { - for (auto const& task_id : t_state.pool_task_ids) { + for (auto const& task_id : t_state->pool_task_ids) { auto const it = pool_bufn_task_thread_count.find(task_id); if (it != pool_bufn_task_thread_count.end()) { it->second += 1; @@ -1929,8 +1930,8 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { } } } - if (!is_bufn_plus && t_state.state != thread_state::THREAD_BLOCKED) { - for (auto const& task_id : t_state.pool_task_ids) { + if (!is_bufn_plus && t_state->state != thread_state::THREAD_BLOCKED) { + for (auto const& task_id : t_state->pool_task_ids) { blocked_task_ids.erase(task_id); } } @@ -1993,10 +1994,10 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { 
bool is_to_bufn_set = false; int blocked_thread_count = 0; for (auto const& [thread_id, t_state] : threads) { - switch (t_state.state) { + switch (t_state->state) { case thread_state::THREAD_BLOCKED: { blocked_thread_count++; - thread_priority const& current = t_state.priority(); + thread_priority const& current = t_state->priority(); if (!is_to_bufn_set || current < to_bufn) { to_bufn = current; is_to_bufn_set = true; @@ -2016,11 +2017,11 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { // But we are not tracking when data is made spillable // so if data was made spillable we will retry the // allocation, instead of going to BUFN. - thread->second.is_retry_alloc_before_bufn = true; + thread->second->is_retry_alloc_before_bufn = true; LOG_STATUS("DETAIL", thread_id_to_bufn, - thread->second.task_id, - thread->second.state, + thread->second->task_id, + thread->second->state, "thread (id: {}) is_retry_alloc_before_bufn set to true", thread_id_to_bufn); transition(thread->second, thread_state::THREAD_RUNNING); @@ -2028,7 +2029,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { log_all_threads_states(); transition(thread->second, thread_state::THREAD_BUFN_THROW); } - thread->second.wake_condition->notify_all(); + thread->second->wake_condition->notify_all(); } } // We now need a way to detect if we need to split the input and retry. 
@@ -2057,9 +2058,9 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { thread_priority to_wake(-1, -1); bool is_to_wake_set = false; for (auto const& [thread_id, t_state] : threads) { - switch (t_state.state) { + switch (t_state->state) { case thread_state::THREAD_BUFN: { - thread_priority const& current = t_state.priority(); + thread_priority const& current = t_state->priority(); if (!is_to_wake_set || to_wake < current) { to_wake = current; is_to_wake_set = true; @@ -2072,7 +2073,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { auto const found_thread = threads.find(thread_id); if (found_thread != threads.end()) { transition(found_thread->second, thread_state::THREAD_SPLIT_THROW); - found_thread->second.wake_condition->notify_all(); + found_thread->second->wake_condition->notify_all(); } } } @@ -2100,40 +2101,40 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { // only retry if this was due to an out of memory exception. 
bool ret = true; if (!was_recursive && thread != threads.end()) { - if (thread->second.is_cpu_alloc != is_for_cpu) { + if (thread->second->is_cpu_alloc != is_for_cpu) { std::stringstream ss; ss << "thread " << thread_id << " has a mismatch on CPU vs GPU post alloc " - << as_str(thread->second.state); + << as_str(thread->second->state); throw std::invalid_argument(ss.str()); } - switch (thread->second.state) { + switch (thread->second->state) { case thread_state::THREAD_ALLOC_FREE: transition(thread->second, thread_state::THREAD_RUNNING); break; case thread_state::THREAD_ALLOC: - if (is_oom && thread->second.is_retry_alloc_before_bufn) { - if (thread->second.is_retry_alloc_before_bufn) { - thread->second.is_retry_alloc_before_bufn = false; + if (is_oom && thread->second->is_retry_alloc_before_bufn) { + if (thread->second->is_retry_alloc_before_bufn) { + thread->second->is_retry_alloc_before_bufn = false; LOG_STATUS( "DETAIL", thread_id, - thread->second.task_id, - thread->second.state, + thread->second->task_id, + thread->second->state, "thread (id: {}) is_retry_alloc_before_bufn set to false in post_alloc_failed_core", thread_id); } transition(thread->second, thread_state::THREAD_BUFN_THROW); - thread->second.wake_condition->notify_all(); + thread->second->wake_condition->notify_all(); } else if (is_oom && blocking) { - if (thread->second.is_retry_alloc_before_bufn) { - thread->second.is_retry_alloc_before_bufn = false; + if (thread->second->is_retry_alloc_before_bufn) { + thread->second->is_retry_alloc_before_bufn = false; LOG_STATUS( "DETAIL", thread_id, - thread->second.task_id, - thread->second.state, + thread->second->task_id, + thread->second->state, "thread (id: {}) is_retry_alloc_before_bufn set to false in post_alloc_failed_core", thread_id); } @@ -2146,7 +2147,7 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { default: { std::stringstream ss; ss << "Internal error: unexpected state after alloc failed " << thread_id << " " - 
<< as_str(thread->second.state); + << as_str(thread->second->state); throw std::runtime_error(ss.str()); } } @@ -2188,10 +2189,10 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { auto const tid = static_cast<long>(pthread_self()); auto const thread = threads.find(tid); if (thread != threads.end()) { - LOG_STATUS("DEALLOC", tid, thread->second.task_id, thread->second.state); + LOG_STATUS("DEALLOC", tid, thread->second->task_id, thread->second->state); if (!is_for_cpu) { - if (!thread->second.is_in_spilling) { - thread->second.metrics.gpu_memory_active_footprint -= num_bytes; + if (!thread->second->is_in_spilling) { + thread->second->metrics.gpu_memory_active_footprint -= num_bytes; } gpu_memory_allocated_bytes -= num_bytes; } @@ -2211,10 +2212,10 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource { // By not changing our thread's state to THREAD_ALLOC_FREE, we keep the state // the same, but we still let other threads know that there was a free and they should // handle accordingly. - if (t_state.thread_id != tid) { - switch (t_state.state) { + if (t_state->thread_id != tid) { + switch (t_state->state) { case thread_state::THREAD_ALLOC: - if (is_for_cpu == t_state.is_cpu_alloc) { + if (is_for_cpu == t_state->is_cpu_alloc) { transition(t_state, thread_state::THREAD_ALLOC_FREE); } break; diff --git a/thirdparty/cudf b/thirdparty/cudf index bd2af5e9fd..3745c9f133 160000 --- a/thirdparty/cudf +++ b/thirdparty/cudf @@ -1 +1 @@ -Subproject commit bd2af5e9fd73cb861caffe482069fd714e5cd13a +Subproject commit 3745c9f133eb3d204810090c05a5701643028c1b