@@ -203,6 +203,16 @@ static Task3DTile2DIndex Delinearize(size_t task_index, size_t range_i,
203203 return {task_i, offset_j, offset_k, extent_j, extent_k};
204204}
205205
206+ // XNNPACK tends to choose too small tile sizes that create too many tasks. For
207+ // dynamic versions of parallel loops we can choose tile size to be any multiple
208+ // of the original tile size. This function ensures that the tile size is at
209+ // least `min_tile_size`.
210+ static size_t AdjustTileSize (size_t tile_size, size_t min_tile_size) {
211+ size_t adjusted_tile_size = tile_size;
212+ while (adjusted_tile_size < min_tile_size) adjusted_tile_size += tile_size;
213+ return adjusted_tile_size;
214+ }
215+
206216// In the `Parallelize` implementations below:
207217//
208218// (1) If done event is already available, execute the task immediately in the
@@ -280,7 +290,7 @@ void ParallelLoopRunner::Parallelize(size_t range, size_t tile,
280290
281291void ParallelLoopRunner::ParallelizeDynamic (size_t range, size_t tile,
282292 Task1DTile1DDynamic task) {
283- Parallelize (range, tile, std::move (task));
293+ Parallelize (range, AdjustTileSize ( tile, 128 ) , std::move (task));
284294}
285295
286296struct ParallelLoopRunner ::ParallelTask2DTile1D {
@@ -320,7 +330,7 @@ void ParallelLoopRunner::Parallelize(size_t range_i, size_t range_j,
320330void ParallelLoopRunner::ParallelizeDynamic (size_t range_i, size_t range_j,
321331 size_t tile_j,
322332 Task2DTile1DDynamic task) {
323- Parallelize (range_i, range_j, tile_j, std::move (task));
333+ Parallelize (range_i, range_j, AdjustTileSize ( tile_j, 128 ) , std::move (task));
324334}
325335
326336struct ParallelLoopRunner ::ParallelTask3DTile2D {
@@ -366,7 +376,8 @@ void ParallelLoopRunner::ParallelizeDynamic(size_t range_i, size_t range_j,
366376 size_t range_k, size_t tile_j,
367377 size_t tile_k,
368378 Task3DTile2DDynamic task) {
369- Parallelize (range_i, range_j, range_k, tile_j, tile_k, std::move (task));
379+ Parallelize (range_i, range_j, range_k, AdjustTileSize (tile_j, 128 ),
380+ AdjustTileSize (tile_k, 128 ), std::move (task));
370381}
371382
372383} // namespace xla::cpu
0 commit comments