@@ -238,128 +238,110 @@ struct CPUStreamsExecutor::Impl {
238238 // it's only a workaround for ticket CVS-111490, please be carefully when need to modify
239239 // CustomeThreadLocal::local(), especially like operations that will affect the count of
240240 // CustomThreadLocal::ThreadId
241- class CustomThreadLocal : public ThreadLocal <std::shared_ptr<Stream>> {
242- class ThreadTracker {
243- public:
244- explicit ThreadTracker (const std::thread::id& id)
245- : _id(id),
246- _count_ptr(std::make_shared<std::atomic_int>(1 )) {}
247- ~ThreadTracker () {
248- _count_ptr->fetch_sub (1 );
249- }
250- std::shared_ptr<ThreadTracker> fetch () {
251- auto new_ptr = std::shared_ptr<ThreadTracker>(new ThreadTracker (*this ));
252- auto pre_valule = new_ptr.get ()->_count_ptr ->fetch_add (1 );
253- OPENVINO_ASSERT (pre_valule == 1 , " this value must be 1, please check code CustomThreadLocal::local()" );
254- return new_ptr;
241+ class CustomThreadLocal : public ThreadLocal <std::shared_ptr<Stream>>,
242+ public std::enable_shared_from_this<CustomThreadLocal> {
243+ public:
244+ CustomThreadLocal (std::function<std::shared_ptr<Stream>()> callback_construct, Impl* impl)
245+ : ThreadLocal<std::shared_ptr<Stream>>(std::move(callback_construct)),
246+ _impl (impl) {
247+ _executor_thread_id = std::this_thread::get_id ();
248+ }
249+ void cleanup (std::thread::id thread_id) {
250+ std::lock_guard<std::mutex> guard (_stream_map_mutex);
251+ auto item = _stream_map.find (thread_id);
252+ if (item != _stream_map.end ()) {
253+ _stream_map.erase (item);
255254 }
256- const std::thread::id& get_id () const {
257- return _id;
255+ }
256+ struct ThreadCleaner {
257+ struct ResourceKeeper {
258+ // The crt call the atexit released the global variable before thread exit
259+ // Need to call the map's clear to set the size to 0
260+ // Or the thread exit callback will double free the map's content
261+ ~ResourceKeeper () {
262+ std::lock_guard<std::mutex> lock (_mutex);
263+ _mapdata.clear ();
264+ }
265+ std::mutex _mutex;
266+ std::multimap<std::thread::id, std::weak_ptr<CustomThreadLocal>> _mapdata;
267+ };
268+ static ResourceKeeper global_resource_holder;
269+
270+ ThreadCleaner (std::thread::id thread_id) : _thread_id(thread_id) {}
271+ void add (std::weak_ptr<CustomThreadLocal> parent) {
272+ std::lock_guard<std::mutex> lock (global_resource_holder._mutex );
273+ global_resource_holder._mapdata .insert ({_thread_id, parent});
258274 }
259- int count () const {
260- return *(_count_ptr.get ());
275+ ~ThreadCleaner () {
276+ std::vector<std::shared_ptr<CustomThreadLocal>> parent_ptr;
277+ {
278+ // Releaese the global lock as soon as possible
279+ std::lock_guard<std::mutex> lock (global_resource_holder._mutex );
280+ auto range = global_resource_holder._mapdata .equal_range (_thread_id);
281+ for (auto it = range.first ; it != range.second ; ++it) {
282+ parent_ptr.push_back (it->second .lock ());
283+ }
284+ if (!parent_ptr.empty ()) {
285+ global_resource_holder._mapdata .erase (range.first , range.second );
286+ }
287+ }
288+ for (auto & parent : parent_ptr) {
289+ if (parent) {
290+ parent->cleanup (_thread_id);
291+ }
292+ }
261293 }
262294
263295 private:
264- // disable all copy and move semantics, user only can use fetch()
265- // to create a new instance with a shared count num;
266- ThreadTracker (const ThreadTracker&) = default ;
267- ThreadTracker (ThreadTracker&&) = delete ;
268- ThreadTracker& operator =(const ThreadTracker&) = delete ;
269- ThreadTracker& operator =(ThreadTracker&&) = delete ;
270- std::thread::id _id;
271- std::shared_ptr<std::atomic_int> _count_ptr;
296+ std::thread::id _thread_id;
272297 };
273-
274- public:
275- CustomThreadLocal (std::function<std::shared_ptr<Stream>()> callback_construct, Impl* impl)
276- : ThreadLocal<std::shared_ptr<Stream>>(std::move(callback_construct)),
277- _impl (impl) {}
278298 std::shared_ptr<Stream> local () {
279- // maybe there are two CPUStreamsExecutors in the same thread.
280- static thread_local std::map<void *, std::shared_ptr<CustomThreadLocal::ThreadTracker>> t_stream_count_map;
281- // fix the memory leak issue that CPUStreamsExecutor is already released,
282- // but still exists CustomThreadLocal::ThreadTracker in t_stream_count_map
283- for (auto it = t_stream_count_map.begin (); it != t_stream_count_map.end ();) {
284- if (this != it->first && it->second ->count () == 1 ) {
285- t_stream_count_map.erase (it++);
286- } else {
287- it++;
288- }
289- }
290299 auto id = std::this_thread::get_id ();
291- auto search = _thread_ids.find (id);
292- if (search != _thread_ids.end ()) {
300+ if (id == _executor_thread_id) {
293301 return ThreadLocal<std::shared_ptr<Stream>>::local ();
294302 }
295- std::lock_guard<std::mutex> guard (_stream_map_mutex);
296- for (auto & item : _stream_map) {
297- if (item.first ->get_id () == id) {
298- // check if the ThreadTracker of this stream is already in t_stream_count_map
299- // if not, then create ThreadTracker for it
300- auto iter = t_stream_count_map.find ((void *)this );
301- if (iter == t_stream_count_map.end ()) {
302- t_stream_count_map[(void *)this ] = item.first ->fetch ();
303- }
304- return item.second ;
305- }
306- }
307- std::shared_ptr<Impl::Stream> stream = nullptr ;
308- for (auto it = _stream_map.begin (); it != _stream_map.end ();) {
309- if (it->first ->count () == 1 ) {
310- if (stream == nullptr ) {
311- stream = it->second ;
312- }
313- _stream_map.erase (it++);
314- } else {
315- it++;
303+
304+ // ensure ThreadCleaner is created only once per thread exit
305+ thread_local ThreadCleaner t_cleaner (id);
306+ {
307+ std::lock_guard<std::mutex> guard (_stream_map_mutex);
308+ auto search = _stream_map.find (id);
309+ if (search != _stream_map.end ()) {
310+ return search->second ;
316311 }
317312 }
318- if (stream == nullptr ) {
319- stream = std::make_shared<Impl::Stream>(_impl);
313+ std::shared_ptr<Impl::Stream> stream = std::make_shared<Impl::Stream>(_impl);
314+ t_cleaner.add (this ->shared_from_this ());
315+ {
316+ std::lock_guard<std::mutex> guard (_stream_map_mutex);
317+ _stream_map[id] = stream;
320318 }
321- auto tracker_ptr = std::make_shared<CustomThreadLocal::ThreadTracker>(id);
322- t_stream_count_map[(void *)this ] = tracker_ptr;
323- auto new_tracker_ptr = tracker_ptr->fetch ();
324- _stream_map[new_tracker_ptr] = stream;
325319 return stream;
326320 }
327321
328- void set_thread_ids_map (std::vector<std::thread>& threads) {
329- for (auto & thread : threads) {
330- _thread_ids.insert (thread.get_id ());
331- }
332- }
333-
334322 bool find_thread_id () {
335323 auto id = std::this_thread::get_id ();
336- auto search = _thread_ids.find (id);
337- if (search != _thread_ids.end ()) {
324+ if (id == _executor_thread_id) {
338325 return true ;
339326 }
340327 std::lock_guard<std::mutex> guard (_stream_map_mutex);
341- for (auto & item : _stream_map) {
342- if (item.first ->get_id () == id) {
343- return true ;
344- }
345- }
346- return false ;
328+ auto item = _stream_map.find (id);
329+ return item != _stream_map.end ();
347330 }
348331
349332 private:
350- std::set<std::thread::id> _thread_ids;
351333 Impl* _impl;
352- std::map<std::shared_ptr<CustomThreadLocal::ThreadTracker> , std::shared_ptr<Impl::Stream>> _stream_map;
334+ std::map<std::thread::id , std::shared_ptr<Impl::Stream>> _stream_map;
353335 std::mutex _stream_map_mutex;
336+ std::thread::id _executor_thread_id;
354337 };
355338
356- explicit Impl (const Config& config)
357- : _config{config},
358- _streams (
359- [this ] {
360- return std::make_shared<Impl::Stream>(this );
361- },
362- this ) {
339+ explicit Impl (const Config& config) : _config{config} {
340+ _streams = std::make_shared<CustomThreadLocal>(
341+ [this ] {
342+ return std::make_shared<Impl::Stream>(this );
343+ },
344+ this );
363345 auto numaNodes = get_available_numa_nodes ();
364346 int streams_num = _config.get_streams ();
365347 auto processor_ids = _config.get_stream_processor_ids ();
@@ -390,12 +372,11 @@ struct CPUStreamsExecutor::Impl {
390372 }
391373 }
392374 if (task) {
393- Execute (task, *(_streams. local ()));
375+ Execute (task, *(_streams-> local ()));
394376 }
395377 }
396378 });
397379 }
398- _streams.set_thread_ids_map (_threads);
399380 }
400381
401382 void Enqueue (Task task) {
@@ -424,7 +405,7 @@ struct CPUStreamsExecutor::Impl {
424405 void pin_stream_to_cpus () {
425406#if OV_THREAD == OV_THREAD_SEQ
426407 if (_config.get_cpu_pinning ()) {
427- auto stream = _streams. local ();
408+ auto stream = _streams-> local ();
428409 auto proc_type_table = get_org_proc_type_table ();
429410 std::tie (stream->_mask , stream->_ncpus ) = get_process_mask ();
430411 if (get_num_numa_nodes () > 1 ) {
@@ -443,15 +424,15 @@ struct CPUStreamsExecutor::Impl {
443424
444425 void unpin_stream_to_cpus () {
445426#if OV_THREAD == OV_THREAD_SEQ
446- auto stream = _streams. local ();
427+ auto stream = _streams-> local ();
447428 if (stream->_mask ) {
448429 pin_current_thread_by_mask (stream->_ncpus , stream->_mask );
449430 }
450431#endif
451432 }
452433
453434 void Defer (Task task) {
454- auto & stream = *(_streams. local ());
435+ auto & stream = *(_streams-> local ());
455436 stream._taskQueue .push (std::move (task));
456437 if (!stream._execute ) {
457438 stream._execute = true ;
@@ -476,17 +457,20 @@ struct CPUStreamsExecutor::Impl {
476457 std::queue<Task> _taskQueue;
477458 bool _isStopped = false ;
478459 std::vector<int > _usedNumaNodes;
479- CustomThreadLocal _streams;
460+ std::shared_ptr< CustomThreadLocal> _streams;
480461 bool _isExit = false ;
481462 std::vector<int > _cpu_ids_all;
482463 std::mutex _cpu_ids_mutex;
483464};
484465
466+ CPUStreamsExecutor::Impl::CustomThreadLocal::ThreadCleaner::ResourceKeeper
467+ CPUStreamsExecutor::Impl::CustomThreadLocal::ThreadCleaner::global_resource_holder;
468+
485469int CPUStreamsExecutor::get_stream_id () {
486- if (!_impl->_streams . find_thread_id ()) {
470+ if (!_impl->_streams -> find_thread_id ()) {
487471 return 0 ;
488472 }
489- auto stream = _impl->_streams . local ();
473+ auto stream = _impl->_streams -> local ();
490474 return stream->_streamId ;
491475}
492476
@@ -495,23 +479,23 @@ int CPUStreamsExecutor::get_streams_num() {
495479}
496480
497481int CPUStreamsExecutor::get_numa_node_id () {
498- if (!_impl->_streams . find_thread_id ()) {
482+ if (!_impl->_streams -> find_thread_id ()) {
499483 return 0 ;
500484 }
501- auto stream = _impl->_streams . local ();
485+ auto stream = _impl->_streams -> local ();
502486 return stream->_numaNodeId ;
503487}
504488
505489int CPUStreamsExecutor::get_socket_id () {
506- if (!_impl->_streams . find_thread_id ()) {
490+ if (!_impl->_streams -> find_thread_id ()) {
507491 return 0 ;
508492 }
509- auto stream = _impl->_streams . local ();
493+ auto stream = _impl->_streams -> local ();
510494 return stream->_socketId ;
511495}
512496
513497std::vector<int > CPUStreamsExecutor::get_rank () {
514- auto stream = _impl->_streams . local ();
498+ auto stream = _impl->_streams -> local ();
515499 return stream->_rank ;
516500}
517501
0 commit comments