Skip to content

Commit 8dadc7d

Browse files
authored
CINN CustomDevice. Bugfix tile_transpose_tactic.cc analysis_predictor.cc (#78512)
1 parent 962d636 commit 8dadc7d

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

paddle/cinn/ir/group_schedule/tactic/tile_transpose_tactic.cc

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,15 +553,29 @@ void TileTransposeTactic::TileCacheBlock(ir::IRSchedule* sch,
553553

554554
// Step 3. Do inner-block transpose.
555555
int offset = high_axis_.size();
556-
sch->Split(shared_cache_block_id, offset + 1, {-1, 4, 8});
556+
#ifdef CINN_WITH_CUSTOM_DEVICE
557+
sch->Split(
558+
shared_cache_block_id,
559+
offset + 1,
560+
{-1, static_cast<int>(context_->config.tile_config.warp_size / 8), 8});
557561
sch->Split(shared_cache_block_id,
558562
offset,
559563
{-1, static_cast<int>(context_->config.tile_config.warp_size)});
560564

561565
sch->Split(local_cache_block_id,
562566
offset + 1,
563567
{-1, static_cast<int>(context_->config.tile_config.warp_size)});
568+
sch->Split(
569+
local_cache_block_id,
570+
offset,
571+
{-1, static_cast<int>(context_->config.tile_config.warp_size / 8), 8});
572+
#else // CINN_WITH_CUDA
573+
sch->Split(shared_cache_block_id, offset + 1, {-1, 4, 8});
574+
sch->Split(shared_cache_block_id, offset, {-1, 32});
575+
576+
sch->Split(local_cache_block_id, offset + 1, {-1, 32});
564577
sch->Split(local_cache_block_id, offset, {-1, 4, 8});
578+
#endif
565579

566580
sch->Reorder(shared_cache_block_id, OffsetVec({0, 2, 3, 4, 1}, offset));
567581
sch->Reorder(local_cache_block_id, OffsetVec({0, 3, 1, 2, 4}, offset));
@@ -576,10 +590,18 @@ void TileTransposeTactic::TileBlock(ir::IRSchedule* sch,
576590
CanonicalizeLayout(sch, block_id);
577591

578592
int offset = high_axis_.size();
593+
#ifdef CINN_WITH_CUSTOM_DEVICE
579594
sch->Split(block_id,
580595
offset + 1,
581596
{-1, static_cast<int>(context_->config.tile_config.warp_size)});
597+
sch->Split(
598+
block_id,
599+
offset,
600+
{-1, static_cast<int>(context_->config.tile_config.warp_size / 8), 8});
601+
#else // CINN_WITH_CUDA
602+
sch->Split(block_id, offset + 1, {-1, 32});
582603
sch->Split(block_id, offset, {-1, 4, 8});
604+
#endif
583605

584606
sch->Reorder(block_id, OffsetVec({0, 3, 1, 2, 4}, offset));
585607

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -924,7 +924,8 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
924924
delete_assert_op_pm.Run(pir_program_.get());
925925
}
926926

927-
if (config_.use_gpu() && config_.cinn_enabled()) {
927+
if ((config_.use_gpu() || config_.use_custom_device()) &&
928+
config_.cinn_enabled()) {
928929
if (!config_.custom_pass_only_) {
929930
::pir::PassManager fused_op_pm(::pir::IrContext::Instance(),
930931
config_.pm_opt_level_);

0 commit comments

Comments
 (0)