Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions frame/3/bli_l3_sup_int.c
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ err_t bli_gemmsup_int
// new ways of parallelism value for the jc loop.
rntm_t rntm_l = *rntm;
bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, &rntm_l );
bli_l3_sup_thrinfo_update( &rntm_l, &thread );
bli_l3_sup_thrinfo_update( &rntm_l, thread );
}


Expand Down Expand Up @@ -205,7 +205,7 @@ err_t bli_gemmsup_int
// new ways of parallelism value for the jc loop.
rntm_t rntm_l = *rntm;
bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, &rntm_l );
bli_l3_sup_thrinfo_update( &rntm_l, &thread );
bli_l3_sup_thrinfo_update( &rntm_l, thread );
}


Expand Down Expand Up @@ -315,7 +315,7 @@ err_t bli_gemmtsup_int
// new ways of parallelism value for the jc loop.
rntm_t rntm_l = *rntm;
bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, &rntm_l );
bli_l3_sup_thrinfo_update( &rntm_l, &thread );
bli_l3_sup_thrinfo_update( &rntm_l, thread );
}


Expand Down Expand Up @@ -385,7 +385,7 @@ err_t bli_gemmtsup_int
// new ways of parallelism value for the jc loop.
rntm_t rntm_l = *rntm;
bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, &rntm_l );
bli_l3_sup_thrinfo_update( &rntm_l, &thread );
bli_l3_sup_thrinfo_update( &rntm_l, thread );
}


Expand Down
44 changes: 37 additions & 7 deletions frame/3/bli_l3_thrinfo.c
Original file line number Diff line number Diff line change
Expand Up @@ -147,21 +147,51 @@ thrinfo_t* bli_l3_sup_thrinfo_create
void bli_l3_sup_thrinfo_update
(
const rntm_t* rntm,
thrinfo_t** root
thrinfo_t* root
)
{
thrcomm_t* gl_comm = bli_thrinfo_comm( *root );
dim_t tid = bli_thrinfo_thread_id( *root );
pool_t* sba_pool = bli_thrinfo_sba_pool( *root );
dim_t nt = bli_thrinfo_num_threads( *root );
dim_t nt = bli_thrinfo_num_threads( root );

// Return early in single-threaded execution
// since the thread control tree may not have been
// allocated normally
if ( nt == 1 ) return;

bli_thrinfo_free( *root );
*root = bli_l3_sup_thrinfo_create( tid, gl_comm, sba_pool, rntm );
// Do not free root thrinfo_t to avoid use-after-free.

// Free children of root
for ( dim_t i = 0; i < BLIS_MAX_SUB_NODES; i++ )
{
thrinfo_t* thrinfo_sub_node = bli_thrinfo_sub_node( i, root );
if ( thrinfo_sub_node != NULL )
bli_thrinfo_free( thrinfo_sub_node );

// replace freed children with NULL.
bli_thrinfo_set_sub_node( i, NULL, root );
}

// Rebuild the children with the new rntm.
const dim_t n_way_jc = bli_rntm_ways_for( BLIS_NC, rntm );
const dim_t n_way_pc = bli_rntm_ways_for( BLIS_KC, rntm );
const dim_t n_way_ic = bli_rntm_ways_for( BLIS_MC, rntm );
const dim_t n_way_jr = bli_rntm_ways_for( BLIS_NR, rntm );
const dim_t n_way_ir = bli_rntm_ways_for( BLIS_MR, rntm );

thrinfo_t* thread_jc = bli_thrinfo_split( n_way_jc, root );
thrinfo_t* thread_pc = bli_thrinfo_split( n_way_pc, thread_jc );
thrinfo_t* thread_pb = bli_thrinfo_split( 1, thread_pc );
thrinfo_t* thread_ic = bli_thrinfo_split( n_way_ic, thread_pb );
thrinfo_t* thread_pa = bli_thrinfo_split( 1, thread_ic );
thrinfo_t* thread_jr = bli_thrinfo_split( n_way_jr, thread_pa );
thrinfo_t* thread_ir = bli_thrinfo_split( n_way_ir, thread_jr );

bli_thrinfo_set_sub_node( 0, thread_jc, root );
bli_thrinfo_set_sub_node( 0, thread_pc, thread_jc );
bli_thrinfo_set_sub_node( 0, thread_pb, thread_pc );
bli_thrinfo_set_sub_node( 0, thread_ic, thread_pb );
bli_thrinfo_set_sub_node( 0, thread_pa, thread_ic );
bli_thrinfo_set_sub_node( 0, thread_jr, thread_pa );
bli_thrinfo_set_sub_node( 0, thread_ir, thread_jr );
}

// -----------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion frame/3/bli_l3_thrinfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ thrinfo_t* bli_l3_sup_thrinfo_create
void bli_l3_sup_thrinfo_update
(
const rntm_t* rntm,
thrinfo_t** root
thrinfo_t* root
);

void bli_l3_thrinfo_print_gemm_paths
Expand Down