Skip to content

Commit 6b0df7a

Browse files
committed
remove use after free in bli_l3_sup_thrinfo_update
1 parent b5d5783 commit 6b0df7a

File tree

3 files changed

+39
-12
lines changed

3 files changed

+39
-12
lines changed

frame/3/bli_l3_sup_int.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ err_t bli_gemmsup_int
139139
// new ways of parallelism value for the jc loop.
140140
rntm_t rntm_l = *rntm;
141141
bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, &rntm_l );
142-
bli_l3_sup_thrinfo_update( &rntm_l, &thread );
142+
bli_l3_sup_thrinfo_update( &rntm_l, thread );
143143
}
144144

145145

@@ -205,7 +205,7 @@ err_t bli_gemmsup_int
205205
// new ways of parallelism value for the jc loop.
206206
rntm_t rntm_l = *rntm;
207207
bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, &rntm_l );
208-
bli_l3_sup_thrinfo_update( &rntm_l, &thread );
208+
bli_l3_sup_thrinfo_update( &rntm_l, thread );
209209
}
210210

211211

@@ -315,7 +315,7 @@ err_t bli_gemmtsup_int
315315
// new ways of parallelism value for the jc loop.
316316
rntm_t rntm_l = *rntm;
317317
bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, &rntm_l );
318-
bli_l3_sup_thrinfo_update( &rntm_l, &thread );
318+
bli_l3_sup_thrinfo_update( &rntm_l, thread );
319319
}
320320

321321

@@ -385,7 +385,7 @@ err_t bli_gemmtsup_int
385385
// new ways of parallelism value for the jc loop.
386386
rntm_t rntm_l = *rntm;
387387
bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, &rntm_l );
388-
bli_l3_sup_thrinfo_update( &rntm_l, &thread );
388+
bli_l3_sup_thrinfo_update( &rntm_l, thread );
389389
}
390390

391391

frame/3/bli_l3_thrinfo.c

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -147,21 +147,48 @@ thrinfo_t* bli_l3_sup_thrinfo_create
147147
void bli_l3_sup_thrinfo_update
148148
(
149149
const rntm_t* rntm,
150-
thrinfo_t** root
150+
thrinfo_t* root
151151
)
152152
{
153-
thrcomm_t* gl_comm = bli_thrinfo_comm( *root );
154-
dim_t tid = bli_thrinfo_thread_id( *root );
155-
pool_t* sba_pool = bli_thrinfo_sba_pool( *root );
156-
dim_t nt = bli_thrinfo_num_threads( *root );
153+
dim_t nt = bli_thrinfo_num_threads( root );
157154

158155
// Return early in single-threaded execution
159156
// since the thread control tree may not have been
160157
// allocated normally
161158
if ( nt == 1 ) return;
162159

163-
bli_thrinfo_free( *root );
164-
*root = bli_l3_sup_thrinfo_create( tid, gl_comm, sba_pool, rntm );
160+
// Do not free root thrinfo_t to avoid use-after-free.
161+
162+
// Free children of root
163+
for ( dim_t i = 0; i < BLIS_MAX_SUB_NODES; i++ )
164+
{
165+
thrinfo_t* thrinfo_sub_node = bli_thrinfo_sub_node( i, root );
166+
if ( thrinfo_sub_node != NULL )
167+
bli_thrinfo_free( thrinfo_sub_node );
168+
}
169+
170+
// Rebuild the children with the new rntm.
171+
const dim_t n_way_jc = bli_rntm_ways_for( BLIS_NC, rntm );
172+
const dim_t n_way_pc = bli_rntm_ways_for( BLIS_KC, rntm );
173+
const dim_t n_way_ic = bli_rntm_ways_for( BLIS_MC, rntm );
174+
const dim_t n_way_jr = bli_rntm_ways_for( BLIS_NR, rntm );
175+
const dim_t n_way_ir = bli_rntm_ways_for( BLIS_MR, rntm );
176+
177+
thrinfo_t* thread_jc = bli_thrinfo_split( n_way_jc, root );
178+
thrinfo_t* thread_pc = bli_thrinfo_split( n_way_pc, thread_jc );
179+
thrinfo_t* thread_pb = bli_thrinfo_split( 1, thread_pc );
180+
thrinfo_t* thread_ic = bli_thrinfo_split( n_way_ic, thread_pb );
181+
thrinfo_t* thread_pa = bli_thrinfo_split( 1, thread_ic );
182+
thrinfo_t* thread_jr = bli_thrinfo_split( n_way_jr, thread_pa );
183+
thrinfo_t* thread_ir = bli_thrinfo_split( n_way_ir, thread_jr );
184+
185+
bli_thrinfo_set_sub_node( 0, thread_jc, root );
186+
bli_thrinfo_set_sub_node( 0, thread_pc, thread_jc );
187+
bli_thrinfo_set_sub_node( 0, thread_pb, thread_pc );
188+
bli_thrinfo_set_sub_node( 0, thread_ic, thread_pb );
189+
bli_thrinfo_set_sub_node( 0, thread_pa, thread_ic );
190+
bli_thrinfo_set_sub_node( 0, thread_jr, thread_pa );
191+
bli_thrinfo_set_sub_node( 0, thread_ir, thread_jr );
165192
}
166193

167194
// -----------------------------------------------------------------------------

frame/3/bli_l3_thrinfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ thrinfo_t* bli_l3_sup_thrinfo_create
9797
void bli_l3_sup_thrinfo_update
9898
(
9999
const rntm_t* rntm,
100-
thrinfo_t** root
100+
thrinfo_t* root
101101
);
102102

103103
void bli_l3_thrinfo_print_gemm_paths

0 commit comments

Comments
 (0)