diff --git a/frame/3/bli_l3_sup_int.c b/frame/3/bli_l3_sup_int.c index ffba1d6612..af48d2ef81 100644 --- a/frame/3/bli_l3_sup_int.c +++ b/frame/3/bli_l3_sup_int.c @@ -139,7 +139,7 @@ err_t bli_gemmsup_int // new ways of parallelism value for the jc loop. rntm_t rntm_l = *rntm; bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, &rntm_l ); - bli_l3_sup_thrinfo_update( &rntm_l, &thread ); + bli_l3_sup_thrinfo_update( &rntm_l, thread ); } @@ -205,7 +205,7 @@ err_t bli_gemmsup_int // new ways of parallelism value for the jc loop. rntm_t rntm_l = *rntm; bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, &rntm_l ); - bli_l3_sup_thrinfo_update( &rntm_l, &thread ); + bli_l3_sup_thrinfo_update( &rntm_l, thread ); } @@ -315,7 +315,7 @@ err_t bli_gemmtsup_int // new ways of parallelism value for the jc loop. rntm_t rntm_l = *rntm; bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, &rntm_l ); - bli_l3_sup_thrinfo_update( &rntm_l, &thread ); + bli_l3_sup_thrinfo_update( &rntm_l, thread ); } @@ -385,7 +385,7 @@ err_t bli_gemmtsup_int // new ways of parallelism value for the jc loop. rntm_t rntm_l = *rntm; bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, &rntm_l ); - bli_l3_sup_thrinfo_update( &rntm_l, &thread ); + bli_l3_sup_thrinfo_update( &rntm_l, thread ); } diff --git a/frame/3/bli_l3_thrinfo.c b/frame/3/bli_l3_thrinfo.c index 7689686023..8f2feac58a 100644 --- a/frame/3/bli_l3_thrinfo.c +++ b/frame/3/bli_l3_thrinfo.c @@ -147,21 +147,51 @@ thrinfo_t* bli_l3_sup_thrinfo_create void bli_l3_sup_thrinfo_update ( const rntm_t* rntm, - thrinfo_t** root + thrinfo_t* root ) { - thrcomm_t* gl_comm = bli_thrinfo_comm( *root ); - dim_t tid = bli_thrinfo_thread_id( *root ); - pool_t* sba_pool = bli_thrinfo_sba_pool( *root ); - dim_t nt = bli_thrinfo_num_threads( *root ); + dim_t nt = bli_thrinfo_num_threads( root ); // Return early in single-threaded execution // since the thread control tree may not have been // allocated normally if ( nt == 1 ) return; - bli_thrinfo_free( *root ); - *root = bli_l3_sup_thrinfo_create( tid, gl_comm, sba_pool, rntm ); + // Do not free root thrinfo_t to avoid use-after-free. + + // Free children of root + for ( dim_t i = 0; i < BLIS_MAX_SUB_NODES; i++ ) + { + thrinfo_t* thrinfo_sub_node = bli_thrinfo_sub_node( i, root ); + if ( thrinfo_sub_node != NULL ) + bli_thrinfo_free( thrinfo_sub_node ); + + // replace freed children with NULL. + bli_thrinfo_set_sub_node( i, NULL, root ); + } + + // Rebuild the children with the new rntm. + const dim_t n_way_jc = bli_rntm_ways_for( BLIS_NC, rntm ); + const dim_t n_way_pc = bli_rntm_ways_for( BLIS_KC, rntm ); + const dim_t n_way_ic = bli_rntm_ways_for( BLIS_MC, rntm ); + const dim_t n_way_jr = bli_rntm_ways_for( BLIS_NR, rntm ); + const dim_t n_way_ir = bli_rntm_ways_for( BLIS_MR, rntm ); + + thrinfo_t* thread_jc = bli_thrinfo_split( n_way_jc, root ); + thrinfo_t* thread_pc = bli_thrinfo_split( n_way_pc, thread_jc ); + thrinfo_t* thread_pb = bli_thrinfo_split( 1, thread_pc ); + thrinfo_t* thread_ic = bli_thrinfo_split( n_way_ic, thread_pb ); + thrinfo_t* thread_pa = bli_thrinfo_split( 1, thread_ic ); + thrinfo_t* thread_jr = bli_thrinfo_split( n_way_jr, thread_pa ); + thrinfo_t* thread_ir = bli_thrinfo_split( n_way_ir, thread_jr ); + + bli_thrinfo_set_sub_node( 0, thread_jc, root ); + bli_thrinfo_set_sub_node( 0, thread_pc, thread_jc ); + bli_thrinfo_set_sub_node( 0, thread_pb, thread_pc ); + bli_thrinfo_set_sub_node( 0, thread_ic, thread_pb ); + bli_thrinfo_set_sub_node( 0, thread_pa, thread_ic ); + bli_thrinfo_set_sub_node( 0, thread_jr, thread_pa ); + bli_thrinfo_set_sub_node( 0, thread_ir, thread_jr ); } // ----------------------------------------------------------------------------- diff --git a/frame/3/bli_l3_thrinfo.h b/frame/3/bli_l3_thrinfo.h index b041ac993c..4e52b44287 100644 --- a/frame/3/bli_l3_thrinfo.h +++ b/frame/3/bli_l3_thrinfo.h @@ -97,7 +97,7 @@ thrinfo_t* bli_l3_sup_thrinfo_create void bli_l3_sup_thrinfo_update ( const rntm_t* rntm, - thrinfo_t** root + thrinfo_t* root ); void bli_l3_thrinfo_print_gemm_paths