Skip to content

Commit 2bfb538

Browse files
committed
Add function ensureCollectTraceDone to wait for and clean up collectTraceThread
1 parent f64fd39 commit 2bfb538

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

libkineto/src/CuptiActivityProfiler.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,6 +1113,22 @@ void CuptiActivityProfiler::toggleCollectionDynamic(const bool enable){
11131113
#endif
11141114
}
11151115

1116+
// Waits for the trace-collection thread to finish and then releases it.
// If collectTraceThread is non-null and joinable, this acquires mutex_, joins
// the thread and resets collectTraceThread to null; otherwise the join is skipped.
// NOTE(review): as rendered in this diff view, the function's closing brace does
// not follow the `if` block directly. The lines after it (`#endif`,
// `#ifdef HAS_ROCTRACER`, the `enable` branches) look like the tail of
// toggleCollectionDynamic captured by the hunk — confirm against the real file.
void CuptiActivityProfiler::ensureCollectTraceDone() {
1117+
// The null/joinable() check runs before mutex_ is acquired.
if (collectTraceThread && collectTraceThread->joinable()) {
1118+
// NOTE(review): mutex_ stays held for the whole join(); verify the collect
// thread never needs mutex_ in order to finish, otherwise this would deadlock.
std::lock_guard<std::mutex> guard(mutex_);
1119+
collectTraceThread->join();
1120+
// Reset to null so a later call fails the check above and skips the join.
collectTraceThread.reset(nullptr);
1121+
}
1122+
// NOTE(review): this #endif has no matching #if/#ifdef inside this function — confirm.
#endif
1123+
#ifdef HAS_ROCTRACER
1124+
// NOTE(review): `enable` is not a parameter or local of this function; this
// branch presumably belongs to toggleCollectionDynamic(const bool enable) — verify.
if (enable) {
1125+
cupti_.enableActivities(derivedConfig_->profileActivityTypes());
1126+
} else {
1127+
cupti_.disableActivities(derivedConfig_->profileActivityTypes());
1128+
}
1129+
#endif
1130+
}
1131+
11161132
void CuptiActivityProfiler::ensureCollectTraceDone() {
11171133
if (collectTraceThread && collectTraceThread->joinable()) {
11181134
std::lock_guard<std::mutex> guard(mutex_);
@@ -1204,6 +1220,8 @@ const time_point<system_clock> CuptiActivityProfiler::performRunLoopStep(
12041220
std::lock_guard<std::mutex> guard(mutex_);
12051221
stopTraceInternal(now);
12061222
resetInternal();
1223+
LOG(ERROR) << "State: Warmup stopped by CUPTI. (Buffer size configured is " << config_->activitiesMaxGpuBufferSize() / 1024 / 1024 << "MB)";
1224+
UST_LOGGER_MARK_COMPLETED(kWarmUpStage);
12071225
VLOG(0) << "Warmup -> WaitForRequest";
12081226
break;
12091227
}
@@ -1298,6 +1316,12 @@ const time_point<system_clock> CuptiActivityProfiler::performRunLoopStep(
12981316
return new_wakeup_time;
12991317
}
13001318

1319+
// Before processing, we should wait for collectTrace thread to be done.
1320+
if (collectTraceThread && collectTraceThread->joinable()) {
1321+
std::lock_guard<std::mutex> guard(mutex_);
1322+
collectTraceThread->join();
1323+
collectTraceThread.reset(nullptr);
1324+
}
13011325
// Before processing, we should wait for collectTrace thread to be done.
13021326
if (collectTraceThread && collectTraceThread->joinable()) {
13031327
std::lock_guard<std::mutex> guard(mutex_);

0 commit comments

Comments
 (0)