
lengmuzhaxi commented Aug 27, 2025

This PR introduces an RVV (RISC-V Vector Extension) optimized implementation of the MNNDeconvRunForUnitDepthWise function, aiming to significantly accelerate depthwise deconvolution on RISC-V platforms by leveraging SIMD vector instructions.

Optimization Strategy

We focused on vectorizing the core computation using RVV intrinsics:

Vector length (vl): Fixed at 4 floats, matching MNN's 4-channel (C4) packed data layout.

Load/store optimization: Source (src) and weight (weight) arrays are loaded as vfloat32m1_t vectors.

Fused multiply-accumulate: Each src vector is updated with vsrc = vsrc + vdst * vweight using vfmacc_vv_f32m1.

This approach replaces the previous Vec4 SIMD-based implementation for better performance on RISC-V vector hardware.

Performance Summary

Test Hardware: Banana Pi BPI-F3
Operating System: EulixOS 3.0
Test Coverage: Matrix configurations ranging from small to large sizes, with various aspect ratios and strides.
Key Findings
Across all test scenarios, the non-unrolled scheme (Scheme 0) delivered the best and most stable performance, with the advantage most pronounced on large matrices.

For large square matrices (e.g., 65536x65536, 512x512, 1024x1024), this scheme achieved a consistent speedup of approximately 2.5x over the scalar version.
Representative Data
Detailed Performance for Large Filters

The following data for large input widths and filter sizes illustrates the advantage of the RVV-optimized implementation over the original Vec4 version:
1024x1024 Matrix Multiplication (widthC4=1024, height=1024)
Scalar time: 0.007185 ms
RVV time: 0.002086 ms
Speedup: 3.44x

512x512 Matrix Multiplication (widthC4=512, height=512)
Scalar time: 0.003550 ms
RVV time: 0.000794 ms
Speedup: 4.47x

65536x65536 Matrix Multiplication (widthC4=65536, height=65536)
Scalar time: 0.495102 ms
RVV time: 0.192677 ms
Speedup: 2.57x
Future Work
This submission is part of an ongoing effort to optimize MNN functions using RVV. We will continue to optimize other core functions to comprehensively enhance MNN's inference performance on the RISC-V platform.


jxt1234 commented Sep 22, 2025

The branch has conflicts. Please update it.
