Skip to content

Commit 06d605d

Browse files
author
chenqian
committed
[Pass] RISCVDotprodSplitter Passed
1 parent bc36edd commit 06d605d

12 files changed

+1508
-62
lines changed

llvm/lib/Target/RISCV/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ add_llvm_target(RISCVCodeGen
4646
RISCVEsp32P4MemIntrin.cpp
4747
RISCVIndirectBranchTracking.cpp
4848
RISCVIntLoopUnrollAndRemainder.cpp
49+
RISCVDotprodSplitter.cpp
4950
RISCVInsertReadWriteCSR.cpp
5051
RISCVInsertVSETVLI.cpp
5152
RISCVInsertWriteVXRM.cpp

llvm/lib/Target/RISCV/RISCVDotprodSplitter.cpp

Lines changed: 1070 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
//===-- RISCVDotprodSplitter.h - RISC-V Dotprod Splitter Pass -*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares the RISCVDotprodSplitterPass class.
10+
// This pass identifies a specific pattern often associated with calls to inner
11+
// dot product computation functions, where the result is passed via a pointer
12+
// argument (typically an alloca in the caller). The pattern involves a
13+
// sequence of lifetime start, the call instruction, a load from the result
14+
// pointer, and lifetime end, all within the same basic block.
15+
//
16+
// If this unique pattern is found, the pass restructures the control flow
17+
// graph (CFG) to create specialized paths for common constant "step" or
18+
// "stride" values (specifically 1, 2, and 3) passed as arguments to the
19+
// inner call.
20+
//
21+
// It introduces conditional branches based on the runtime values of the step
22+
// arguments. If the steps match one of the specialized constant pairs (e.g.,
23+
// image_step=1 and filter_step=1), control flows to a duplicated version of
24+
// the call sequence where the step arguments are replaced with constants.
25+
// Otherwise, control flows to the original call sequence (generic path).
26+
//
27+
// A PHI node merges the results from the specialized and generic paths.
28+
// This transformation aims to enable further optimizations like constant
29+
// propagation and function specialization for the common step values within
30+
// the called function, potentially improving performance on targets like
31+
// RISC-V with specific dot product acceleration capabilities.
32+
// The pass is controlled by the `-riscv-dotprod-splitter` command-line option.
33+
//===----------------------------------------------------------------------===//
34+
35+
#ifndef LLVM_LIB_TARGET_RISCV_RISCVDOTPRODSPLITTER_H
36+
#define LLVM_LIB_TARGET_RISCV_RISCVDOTPRODSPLITTER_H
37+
38+
#include "llvm/IR/PassManager.h"
39+
#include "llvm/Support/CommandLine.h"
40+
41+
namespace llvm {
42+
43+
class Function;
44+
class Module;
45+
46+
extern cl::opt<bool> EnableRISCVDotprodSplitter;
47+
48+
/// Pass that specializes dot product calls for common step values.
49+
///
50+
/// This pass identifies patterns where inner dot product functions are called
51+
/// with runtime step parameters and creates specialized call paths for common
52+
/// constant step values (1, 2, 3) to enable better optimization.
53+
struct RISCVDotprodSplitterPass
54+
: public PassInfoMixin<RISCVDotprodSplitterPass> {
55+
56+
/// New-pass-manager entry point for a single function \p F.
/// Returns the set of analyses that are still valid after the pass ran.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
57+
58+
/// Always returns true.
/// NOTE(review): in the new pass manager this marks the pass as required,
/// so it is presumably not skipped by pass instrumentation (e.g. optnone);
/// confirm this is intended for a pass that is also gated by a command-line
/// option.
static bool isRequired() { return true; }
59+
60+
/// Check if the function contains patterns that can be processed by this
61+
/// pass.
/// Declared static, so it can be called without a pass instance.
62+
static bool hasProcessablePattern(Function &F);
63+
};
64+
65+
/// Conditional LoopExtractor Pass that only runs when dotprod patterns exist.
66+
///
67+
/// This pass runs LoopExtractor only on modules that contain processable
68+
/// dot product patterns, avoiding unnecessary loop extraction on modules
69+
/// that won't benefit from the dotprod splitter optimization.
70+
struct RISCVConditionalLoopExtractorPass
71+
: public PassInfoMixin<RISCVConditionalLoopExtractorPass> {
72+
73+
/// New-pass-manager entry point for a whole module \p M.
/// Returns the set of analyses that are still valid after the pass ran.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
74+
/// Always returns true.
/// NOTE(review): presumably keeps the pass from being skipped by pass
/// instrumentation; confirm this is intended.
static bool isRequired() { return true; }
75+
76+
private:
77+
/// Module-level query over \p M; the definition is not visible in this
/// header. Presumably the check that decides whether loop extraction runs
/// at all -- TODO confirm against the implementation.
bool moduleHasProcessablePatterns(Module &M);
78+
};
79+
80+
} // namespace llvm
81+
82+
#endif // LLVM_LIB_TARGET_RISCV_RISCVDOTPRODSPLITTER_H

llvm/lib/Target/RISCV/RISCVTargetMachine.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "RISCVCustomLICM.h"
1717
#include "RISCVEsp32P4MemIntrin.h"
1818
#include "RISCVIntLoopUnrollAndRemainder.h"
19+
#include "RISCVDotprodSplitter.h"
1920
#include "RISCVLoopUnrollAndRemainder.h"
2021
#include "RISCVMachineFunctionInfo.h"
2122
#include "RISCVTargetObjectFile.h"
@@ -681,6 +682,10 @@ void RISCVTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
681682
FPM.addPass(RISCVIntLoopUnrollAndRemainderPass());
682683
return true;
683684
}
685+
if (Name == "riscv-dotprod-splitter") {
686+
FPM.addPass(RISCVDotprodSplitterPass());
687+
return true;
688+
}
684689
return false;
685690
});
686691

llvm/test/CodeGen/RISCV/RISCVDotprodSplitter/dspi_dotprod_off_s16_ansi.ll

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2-
; RUN: opt -S -mtriple=riscv32-esp-unknown-elf < %s | FileCheck %s
2+
; RUN: opt -S -mtriple=riscv32-esp-unknown-elf -passes=riscv-dotprod-splitter -riscv-dotprod-splitter=true < %s | FileCheck %s
33

44
; Function Attrs: nofree norecurse nosync nounwind memory(read, argmem: readwrite, inaccessiblemem: none)
55
define dso_local range(i32 0, 458756) i32 @dspi_dotprod_off_s16_ansi(ptr nocapture noundef readonly %in_image, ptr nocapture noundef readonly %filter, ptr nocapture noundef writeonly %out_value, i32 noundef %count_x, i32 noundef %count_y, i32 noundef %shift, i16 noundef signext %offset) local_unnamed_addr {
66
; CHECK-LABEL: define dso_local range(i32 0, 458756) i32 @dspi_dotprod_off_s16_ansi(
7-
; CHECK-SAME: ptr nocapture noundef readonly [[IN_IMAGE:%.*]], ptr nocapture noundef readonly [[FILTER:%.*]], ptr nocapture noundef writeonly [[OUT_VALUE:%.*]], i32 noundef [[COUNT_X:%.*]], i32 noundef [[COUNT_Y:%.*]], i32 noundef [[SHIFT:%.*]], i16 noundef signext [[OFFSET:%.*]]) local_unnamed_addr {
7+
; CHECK-SAME: ptr noundef readonly captures(none) [[IN_IMAGE:%.*]], ptr noundef readonly captures(none) [[FILTER:%.*]], ptr noundef writeonly captures(none) [[OUT_VALUE:%.*]], i32 noundef [[COUNT_X:%.*]], i32 noundef [[COUNT_Y:%.*]], i32 noundef [[SHIFT:%.*]], i16 noundef signext [[OFFSET:%.*]]) local_unnamed_addr {
88
; CHECK-NEXT: [[ENTRY:.*]]:
99
; CHECK-NEXT: [[ADD38_US_LOC:%.*]] = alloca i64, align 8
1010
; CHECK-NEXT: [[STEP_X:%.*]] = getelementptr inbounds i8, ptr [[IN_IMAGE]], i32 4
@@ -53,15 +53,51 @@ define dso_local range(i32 0, 458756) i32 @dspi_dotprod_off_s16_ansi(ptr nocaptu
5353
; CHECK-NEXT: br label %[[FOR_COND25_PREHEADER_US:.*]]
5454
; CHECK: [[FOR_COND25_PREHEADER_US]]:
5555
; CHECK-NEXT: [[Y_082_US:%.*]] = phi i32 [ [[INC41_US:%.*]], %[[FOR_COND25_FOR_COND_CLEANUP27_CRIT_EDGE_US:.*]] ], [ 0, %[[FOR_COND25_PREHEADER_US_PREHEADER]] ]
56-
; CHECK-NEXT: [[ACC_081_US:%.*]] = phi i64 [ [[ADD38_US_RELOAD_GENERIC:%.*]], %[[FOR_COND25_FOR_COND_CLEANUP27_CRIT_EDGE_US]] ], [ 0, %[[FOR_COND25_PREHEADER_US_PREHEADER]] ]
56+
; CHECK-NEXT: [[ACC_081_US:%.*]] = phi i64 [ [[ADD_US_RELOAD:%.*]], %[[FOR_COND25_FOR_COND_CLEANUP27_CRIT_EDGE_US]] ], [ 0, %[[FOR_COND25_PREHEADER_US_PREHEADER]] ]
5757
; CHECK-NEXT: [[I_DATA_080_US:%.*]] = phi ptr [ [[ADD_PTR_US:%.*]], %[[FOR_COND25_FOR_COND_CLEANUP27_CRIT_EDGE_US]] ], [ [[TMP8]], %[[FOR_COND25_PREHEADER_US_PREHEADER]] ]
5858
; CHECK-NEXT: [[F_DATA_079_US:%.*]] = phi ptr [ [[ADD_PTR39_US:%.*]], %[[FOR_COND25_FOR_COND_CLEANUP27_CRIT_EDGE_US]] ], [ [[TMP9]], %[[FOR_COND25_PREHEADER_US_PREHEADER]] ]
59-
; CHECK-NEXT: br label %[[CODEREPL:.*]]
60-
; CHECK: [[CODEREPL]]:
59+
; CHECK-NEXT: br label %[[CODEREPL_ENTRY:.*]]
60+
; CHECK: [[CODEREPL_ENTRY]]:
61+
; CHECK-NEXT: [[CMP_IMG_STEP1:%.*]] = icmp eq i32 [[TMP0]], 1
62+
; CHECK-NEXT: [[CMP_FILT_STEP1:%.*]] = icmp eq i32 [[TMP4]], 1
63+
; CHECK-NEXT: [[COND_STEP1:%.*]] = and i1 [[CMP_IMG_STEP1]], [[CMP_FILT_STEP1]]
64+
; CHECK-NEXT: br i1 [[COND_STEP1]], label %[[CALL_STEP1:.*]], label %[[CHECK_STEP2:.*]]
65+
; CHECK: [[CHECK_STEP2]]:
66+
; CHECK-NEXT: [[CMP_IMG_STEP2:%.*]] = icmp eq i32 [[TMP0]], 2
67+
; CHECK-NEXT: [[CMP_FILT_STEP2:%.*]] = icmp eq i32 [[TMP4]], 2
68+
; CHECK-NEXT: [[COND_STEP2:%.*]] = and i1 [[CMP_IMG_STEP2]], [[CMP_FILT_STEP2]]
69+
; CHECK-NEXT: br i1 [[COND_STEP2]], label %[[CALL_STEP2:.*]], label %[[CHECK_STEP3:.*]]
70+
; CHECK: [[CHECK_STEP3]]:
71+
; CHECK-NEXT: [[CMP_IMG_STEP3:%.*]] = icmp eq i32 [[TMP0]], 3
72+
; CHECK-NEXT: [[CMP_FILT_STEP3:%.*]] = icmp eq i32 [[TMP4]], 3
73+
; CHECK-NEXT: [[COND_STEP3:%.*]] = and i1 [[CMP_IMG_STEP3]], [[CMP_FILT_STEP3]]
74+
; CHECK-NEXT: br i1 [[COND_STEP3]], label %[[CALL_STEP3:.*]], label %[[CALL_GENERIC:.*]]
75+
; CHECK: [[CALL_STEP1]]:
76+
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 -1, ptr [[ADD38_US_LOC]])
77+
; CHECK-NEXT: call void @dspi_dotprod_off_s16_ansi.for.body28.us(i64 [[ACC_081_US]], i32 1, ptr [[I_DATA_080_US]], i32 1, ptr [[F_DATA_079_US]], i32 [[CONV35]], i32 [[COUNT_X]], ptr [[ADD38_US_LOC]])
78+
; CHECK-NEXT: [[TMP10:%.*]] = load i64, ptr [[ADD38_US_LOC]], align 8
79+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 -1, ptr [[ADD38_US_LOC]])
80+
; CHECK-NEXT: br label %[[CALL_MERGE:.*]]
81+
; CHECK: [[CALL_STEP2]]:
82+
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 -1, ptr [[ADD38_US_LOC]])
83+
; CHECK-NEXT: call void @dspi_dotprod_off_s16_ansi.for.body28.us(i64 [[ACC_081_US]], i32 2, ptr [[I_DATA_080_US]], i32 2, ptr [[F_DATA_079_US]], i32 [[CONV35]], i32 [[COUNT_X]], ptr [[ADD38_US_LOC]])
84+
; CHECK-NEXT: [[TMP11:%.*]] = load i64, ptr [[ADD38_US_LOC]], align 8
85+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 -1, ptr [[ADD38_US_LOC]])
86+
; CHECK-NEXT: br label %[[CALL_MERGE]]
87+
; CHECK: [[CALL_STEP3]]:
88+
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 -1, ptr [[ADD38_US_LOC]])
89+
; CHECK-NEXT: call void @dspi_dotprod_off_s16_ansi.for.body28.us(i64 [[ACC_081_US]], i32 3, ptr [[I_DATA_080_US]], i32 3, ptr [[F_DATA_079_US]], i32 [[CONV35]], i32 [[COUNT_X]], ptr [[ADD38_US_LOC]])
90+
; CHECK-NEXT: [[TMP12:%.*]] = load i64, ptr [[ADD38_US_LOC]], align 8
91+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 -1, ptr [[ADD38_US_LOC]])
92+
; CHECK-NEXT: br label %[[CALL_MERGE]]
93+
; CHECK: [[CALL_GENERIC]]:
6194
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 -1, ptr [[ADD38_US_LOC]])
6295
; CHECK-NEXT: call void @dspi_dotprod_off_s16_ansi.for.body28.us(i64 [[ACC_081_US]], i32 [[TMP0]], ptr [[I_DATA_080_US]], i32 [[TMP4]], ptr [[F_DATA_079_US]], i32 [[CONV35]], i32 [[COUNT_X]], ptr [[ADD38_US_LOC]])
63-
; CHECK-NEXT: [[ADD38_US_RELOAD_GENERIC]] = load i64, ptr [[ADD38_US_LOC]], align 8
96+
; CHECK-NEXT: [[ADD38_US_RELOAD_GENERIC:%.*]] = load i64, ptr [[ADD38_US_LOC]], align 8
6497
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 -1, ptr [[ADD38_US_LOC]])
98+
; CHECK-NEXT: br label %[[CALL_MERGE]]
99+
; CHECK: [[CALL_MERGE]]:
100+
; CHECK-NEXT: [[ADD_US_RELOAD]] = phi i64 [ [[TMP10]], %[[CALL_STEP1]] ], [ [[TMP11]], %[[CALL_STEP2]] ], [ [[TMP12]], %[[CALL_STEP3]] ], [ [[ADD38_US_RELOAD_GENERIC]], %[[CALL_GENERIC]] ]
65101
; CHECK-NEXT: br label %[[FOR_COND25_FOR_COND_CLEANUP27_CRIT_EDGE_US]]
66102
; CHECK: [[FOR_COND25_FOR_COND_CLEANUP27_CRIT_EDGE_US]]:
67103
; CHECK-NEXT: [[ADD_PTR_US]] = getelementptr inbounds i16, ptr [[I_DATA_080_US]], i32 [[MUL20]]
@@ -70,7 +106,7 @@ define dso_local range(i32 0, 458756) i32 @dspi_dotprod_off_s16_ansi(ptr nocaptu
70106
; CHECK-NEXT: [[EXITCOND85_NOT:%.*]] = icmp eq i32 [[INC41_US]], [[COUNT_Y]]
71107
; CHECK-NEXT: br i1 [[EXITCOND85_NOT]], label %[[FOR_COND_CLEANUP]], label %[[FOR_COND25_PREHEADER_US]]
72108
; CHECK: [[FOR_COND_CLEANUP]]:
73-
; CHECK-NEXT: [[ACC_0_LCSSA:%.*]] = phi i64 [ 0, %[[IF_END16]] ], [ 0, %[[FOR_COND25_PREHEADER_LR_PH]] ], [ [[ADD38_US_RELOAD_GENERIC]], %[[FOR_COND25_FOR_COND_CLEANUP27_CRIT_EDGE_US]] ]
109+
; CHECK-NEXT: [[ACC_0_LCSSA:%.*]] = phi i64 [ 0, %[[IF_END16]] ], [ 0, %[[FOR_COND25_PREHEADER_LR_PH]] ], [ [[ADD_US_RELOAD]], %[[FOR_COND25_FOR_COND_CLEANUP27_CRIT_EDGE_US]] ]
74110
; CHECK-NEXT: [[SUB:%.*]] = add nsw i32 [[SHIFT]], -1
75111
; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 1, [[SUB]]
76112
; CHECK-NEXT: [[CONV43:%.*]] = sext i32 [[SHL]] to i64
@@ -176,19 +212,19 @@ return: ; preds = %for.cond.cleanup, %
176212
}
177213

178214
; Function Attrs: nofree norecurse nounwind
179-
define internal void @dspi_dotprod_off_s16_ansi.for.body28.us(i64 %acc.081.us, i32 %0, ptr %i_data.080.us, i32 %1, ptr %f_data.079.us, i32 %conv35, i32 %count_x, ptr %add38.us.out) {
215+
define internal void @dspi_dotprod_off_s16_ansi.for.body28.us(i64 %acc.081.us, i32 %img_step, ptr %i_data.080.us, i32 %filt_step, ptr %f_data.079.us, i32 %conv35, i32 %count_x, ptr %add38.us.out) {
180216
; CHECK-LABEL: define internal void @dspi_dotprod_off_s16_ansi.for.body28.us(
181-
; CHECK-SAME: i64 [[ACC_081_US:%.*]], i32 [[TMP0:%.*]], ptr [[I_DATA_080_US:%.*]], i32 [[TMP1:%.*]], ptr [[F_DATA_079_US:%.*]], i32 [[CONV35:%.*]], i32 [[COUNT_X:%.*]], ptr [[ADD38_US_OUT:%.*]]) {
217+
; CHECK-SAME: i64 [[ACC_081_US:%.*]], i32 [[IMG_STEP:%.*]], ptr [[I_DATA_080_US:%.*]], i32 [[FILT_STEP:%.*]], ptr [[F_DATA_079_US:%.*]], i32 [[CONV35:%.*]], i32 [[COUNT_X:%.*]], ptr [[ADD38_US_OUT:%.*]]) {
182218
; CHECK-NEXT: [[NEWFUNCROOT:.*]]:
183219
; CHECK-NEXT: br label %[[FOR_BODY28_US:.*]]
184220
; CHECK: [[FOR_BODY28_US]]:
185221
; CHECK-NEXT: [[X_077_US:%.*]] = phi i32 [ 0, %[[NEWFUNCROOT]] ], [ [[INC_US:%.*]], %[[FOR_BODY28_US]] ]
186222
; CHECK-NEXT: [[ACC_176_US:%.*]] = phi i64 [ [[ACC_081_US]], %[[NEWFUNCROOT]] ], [ [[ADD38_US:%.*]], %[[FOR_BODY28_US]] ]
187-
; CHECK-NEXT: [[MUL30_US:%.*]] = mul nsw i32 [[X_077_US]], [[TMP0]]
223+
; CHECK-NEXT: [[MUL30_US:%.*]] = mul nsw i32 [[X_077_US]], [[IMG_STEP]]
188224
; CHECK-NEXT: [[ARRAYIDX_US:%.*]] = getelementptr inbounds i16, ptr [[I_DATA_080_US]], i32 [[MUL30_US]]
189225
; CHECK-NEXT: [[TMP2:%.*]] = load i16, ptr [[ARRAYIDX_US]], align 2
190226
; CHECK-NEXT: [[CONV_US:%.*]] = sext i16 [[TMP2]] to i32
191-
; CHECK-NEXT: [[MUL32_US:%.*]] = mul nsw i32 [[X_077_US]], [[TMP1]]
227+
; CHECK-NEXT: [[MUL32_US:%.*]] = mul nsw i32 [[X_077_US]], [[FILT_STEP]]
192228
; CHECK-NEXT: [[ARRAYIDX33_US:%.*]] = getelementptr inbounds i16, ptr [[F_DATA_079_US]], i32 [[MUL32_US]]
193229
; CHECK-NEXT: [[TMP3:%.*]] = load i16, ptr [[ARRAYIDX33_US]], align 2
194230
; CHECK-NEXT: [[CONV34_US:%.*]] = sext i16 [[TMP3]] to i32
@@ -209,11 +245,11 @@ newFuncRoot:
209245
for.body28.us: ; preds = %newFuncRoot, %for.body28.us
210246
%x.077.us = phi i32 [ 0, %newFuncRoot ], [ %inc.us, %for.body28.us ]
211247
%acc.176.us = phi i64 [ %acc.081.us, %newFuncRoot ], [ %add38.us, %for.body28.us ]
212-
%mul30.us = mul nsw i32 %x.077.us, %0
248+
%mul30.us = mul nsw i32 %x.077.us, %img_step
213249
%arrayidx.us = getelementptr inbounds i16, ptr %i_data.080.us, i32 %mul30.us
214250
%2 = load i16, ptr %arrayidx.us, align 2
215251
%conv.us = sext i16 %2 to i32
216-
%mul32.us = mul nsw i32 %x.077.us, %1
252+
%mul32.us = mul nsw i32 %x.077.us, %filt_step
217253
%arrayidx33.us = getelementptr inbounds i16, ptr %f_data.079.us, i32 %mul32.us
218254
%3 = load i16, ptr %arrayidx33.us, align 2
219255
%conv34.us = sext i16 %3 to i32

0 commit comments

Comments
 (0)