Skip to content

Commit

Permalink
add function ensureCollectTraceDone to wait and cleanup collectTraceT…
Browse files Browse the repository at this point in the history
…hread
  • Loading branch information
staugust committed Aug 27, 2024
1 parent f64fd39 commit 2bfb538
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions libkineto/src/CuptiActivityProfiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1113,6 +1113,22 @@ void CuptiActivityProfiler::toggleCollectionDynamic(const bool enable){
#endif
}

void CuptiActivityProfiler::ensureCollectTraceDone() {
if (collectTraceThread && collectTraceThread->joinable()) {
std::lock_guard<std::mutex> guard(mutex_);
collectTraceThread->join();
collectTraceThread.reset(nullptr);
}
#endif
#ifdef HAS_ROCTRACER
if (enable) {
cupti_.enableActivities(derivedConfig_->profileActivityTypes());
} else {
cupti_.disableActivities(derivedConfig_->profileActivityTypes());
}
#endif
}

void CuptiActivityProfiler::ensureCollectTraceDone() {
if (collectTraceThread && collectTraceThread->joinable()) {
std::lock_guard<std::mutex> guard(mutex_);
Expand Down Expand Up @@ -1204,6 +1220,8 @@ const time_point<system_clock> CuptiActivityProfiler::performRunLoopStep(
std::lock_guard<std::mutex> guard(mutex_);
stopTraceInternal(now);
resetInternal();
LOG(ERROR) << "State: Warmup stopped by CUPTI. (Buffer size configured is " << config_->activitiesMaxGpuBufferSize() / 1024 / 1024 << "MB)";
UST_LOGGER_MARK_COMPLETED(kWarmUpStage);
VLOG(0) << "Warmup -> WaitForRequest";
break;
}
Expand Down Expand Up @@ -1298,6 +1316,12 @@ const time_point<system_clock> CuptiActivityProfiler::performRunLoopStep(
return new_wakeup_time;
}

// Before processing, we should wait for collectTrace thread to be done.
if (collectTraceThread && collectTraceThread->joinable()) {
std::lock_guard<std::mutex> guard(mutex_);
collectTraceThread->join();
collectTraceThread.reset(nullptr);
}
// Before processing, we should wait for collectTrace thread to be done.
if (collectTraceThread && collectTraceThread->joinable()) {
std::lock_guard<std::mutex> guard(mutex_);
Expand Down

0 comments on commit 2bfb538

Please sign in to comment.