@@ -553,15 +553,29 @@ void TileTransposeTactic::TileCacheBlock(ir::IRSchedule* sch,
553553
554554 // Step 3. Do inner-block transpose.
555555 int offset = high_axis_.size ();
556- sch->Split (shared_cache_block_id, offset + 1 , {-1 , 4 , 8 });
556+ #ifdef CINN_WITH_CUSTOM_DEVICE
557+ sch->Split (
558+ shared_cache_block_id,
559+ offset + 1 ,
560+ {-1 , static_cast <int >(context_->config .tile_config .warp_size / 8 ), 8 });
557561 sch->Split (shared_cache_block_id,
558562 offset,
559563 {-1 , static_cast <int >(context_->config .tile_config .warp_size )});
560564
561565 sch->Split (local_cache_block_id,
562566 offset + 1 ,
563567 {-1 , static_cast <int >(context_->config .tile_config .warp_size )});
568+ sch->Split (
569+ local_cache_block_id,
570+ offset,
571+ {-1 , static_cast <int >(context_->config .tile_config .warp_size / 8 ), 8 });
572+ #else // CINN_WITH_CUDA
573+ sch->Split (shared_cache_block_id, offset + 1 , {-1 , 4 , 8 });
574+ sch->Split (shared_cache_block_id, offset, {-1 , 32 });
575+
576+ sch->Split (local_cache_block_id, offset + 1 , {-1 , 32 });
564577 sch->Split (local_cache_block_id, offset, {-1 , 4 , 8 });
578+ #endif
565579
566580 sch->Reorder (shared_cache_block_id, OffsetVec ({0 , 2 , 3 , 4 , 1 }, offset));
567581 sch->Reorder (local_cache_block_id, OffsetVec ({0 , 3 , 1 , 2 , 4 }, offset));
@@ -576,10 +590,18 @@ void TileTransposeTactic::TileBlock(ir::IRSchedule* sch,
576590 CanonicalizeLayout (sch, block_id);
577591
578592 int offset = high_axis_.size ();
593+ #ifdef CINN_WITH_CUSTOM_DEVICE
579594 sch->Split (block_id,
580595 offset + 1 ,
581596 {-1 , static_cast <int >(context_->config .tile_config .warp_size )});
597+ sch->Split (
598+ block_id,
599+ offset,
600+ {-1 , static_cast <int >(context_->config .tile_config .warp_size / 8 ), 8 });
601+ #else // CINN_WITH_CUDA
602+ sch->Split (block_id, offset + 1 , {-1 , 32 });
582603 sch->Split (block_id, offset, {-1 , 4 , 8 });
604+ #endif
583605
584606 sch->Reorder (block_id, OffsetVec ({0 , 3 , 1 , 2 , 4 }, offset));
585607
0 commit comments