
[ROCm] HIP BF16 adaptation: register BF16 layer_norm, bypass MIOpen BF16 softmax, skip cuDNN-only conv2d fusion passes on HIP #78711

Open

austin1997 wants to merge 3 commits into PaddlePaddle:develop from austin1997:bf16-rocm-pr

Conversation


@austin1997 austin1997 commented Apr 18, 2026

PR Category

Environment Adaptation

PR Types

New features

Description

Adds BF16 support for the Paddle framework on the ROCm/HIP backend, so that VLMs containing a SigLIP vision encoder, such as PaddleOCR-VL, can run native BF16 inference on AMD GPUs without falling the vision subgraph back to FP32. Fixes the three remaining gaps described in #78710 and complements the already-merged #78587 (HIP BF16 conv kernel registration). The three changes are mutually independent, and CUDA behavior is completely unchanged.

1) Register BF16 kernels for layer_norm / layer_norm_grad on HIP

The PADDLE_WITH_HIP registration blocks in paddle/phi/kernels/gpu/layer_norm_kernel.cu and layer_norm_grad_kernel.cu originally covered only float + phi::float16. The templated implementations (LayerNormKernel<T, GPUContext> / LayerNormGradKernel<T, GPUContext>) already support phi::bfloat16; only the registration was missing. This PR adds the registration and, following the existing FP16 treatment, promotes the mean / variance output dtype to FP32 for numerical stability.
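For intuition, here is a minimal NumPy sketch of the numeric pattern the registration exposes: BF16 activations with mean/variance accumulated in FP32. The `to_bf16` rounding helper and `layer_norm_bf16` are illustrative only, not Paddle code.

```python
import numpy as np

def to_bf16(x):
    """Round a float32 array to bfloat16 precision (round-to-nearest-even),
    keeping the values stored as float32 for convenient arithmetic."""
    u = np.asarray(x, np.float32).view(np.uint32)
    u = (u + 0x7FFF + ((u >> 16) & 1)) & np.uint32(0xFFFF0000)
    return u.view(np.float32)

def layer_norm_bf16(x, eps=1e-5):
    """LayerNorm over the last axis: BF16 input/output, FP32 statistics,
    mirroring the mean/variance promotion described above (illustrative)."""
    xb = to_bf16(x)                      # BF16 activations
    x32 = xb.astype(np.float32)          # accumulate in FP32
    mean = x32.mean(axis=-1, keepdims=True)
    var = ((x32 - mean) ** 2).mean(axis=-1, keepdims=True)
    y = (x32 - mean) / np.sqrt(var + eps)
    return to_bf16(y), mean, var         # stats stay FP32, as in the kernel
```

Keeping the statistics in FP32 avoids the catastrophic cancellation that BF16's 8-bit mantissa would otherwise introduce in the variance.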

2) Route BF16 softmax through the matrix kernel, bypassing MIOpen

On ROCm 7.x, miopenSoftmaxForward_V2 returns MIOPEN_STATUS_NOT_IMPLEMENTED for miopenBFloat16. In paddle/phi/kernels/gpudnn/softmax_gpudnn.h, once dim exceeds the warp-softmax threshold the default path calls MIOpen, which fails with an immediate runtime error. This PR adds a PADDLE_WITH_HIP + std::is_same<T, phi::bfloat16> check inside SoftmaxForwardCUDAKernelDriverImpl that forces BF16 inputs through the existing LaunchKeMatrixSoftmaxForwardKernel. It also guards the CUDNN_VERSION < 8100 BF16 fallback specialization with !defined(PADDLE_WITH_HIP), so that branch cannot fall into MIOpen on ROCm either.
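The resulting dispatch order can be modeled roughly as follows. This is a sketch; the function, string labels, and `warp_cap` value are illustrative assumptions, not Paddle's actual code.

```python
def softmax_dispatch(backend, dtype, dim, warp_cap=1024):
    """Illustrative model of the softmax kernel choice after this PR."""
    if dim <= warp_cap:
        return "warp_softmax"            # small dims: warp-level kernel
    if backend == "hip" and dtype == "bfloat16":
        # miopenSoftmaxForward_V2 is NOT_IMPLEMENTED for miopenBFloat16,
        # so BF16 is forced onto the existing matrix softmax kernel.
        return "matrix_softmax"
    return "vendor_dnn"                  # cuDNN on CUDA, MIOpen on HIP
```

Only the (hip, bfloat16, large-dim) cell changes; FP16/FP32 on HIP and all CUDA combinations keep their previous paths.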

3) Stop registering conv2d_add_fuse_pass / conv2d_add_act_fuse_pass on HIP

These two PIR passes rewrite conv2d + add[+ act] into fused_conv2d_add_act, and fused_conv2d_add_act has only a cuDNN GPUDNN kernel. This is a different operator from the conv2d / conv3d kernels registered in #78587, so #78587 does not cover this path. On HIP wheels the rewrite succeeds but kernel dispatch fails at execution time, so PaddleX currently has to call delete_pass(...) manually in its runner to make the model run.

The REGISTER_IR_PASS / USE_PIR_PASS macros and the references to both passes in the kPirGpuPasses list are now uniformly wrapped in #ifdef PADDLE_WITH_CUDA. CUDA behavior is completely unchanged; on HIP the passes no longer exist, so the delete_pass workaround on the PaddleX side becomes a no-op.

Testing and verification

Single-op level: legacy_test/test_layer_norm_op.py and test_softmax_op.py already contain BF16 cases and can be reused directly after a HIP build.

End-to-end: PaddleOCR-VL-1.5 runs full BF16 inference on test_ocr.png on an AMD MI300X (gfx942) / ROCm 7.2; the output text is semantically identical to the FP32-fallback path. See the full benchmark in the appendix below.

All three changes are guarded by #ifdef PADDLE_WITH_HIP / #ifdef PADDLE_WITH_CUDA; no CUDA codepath is touched, so CUDA behavior is fully preserved.

Companion PaddleX cleanup

On the PaddleX side, the _keep_in_fp32_modules = ["visual", "mlp_AR"] entry and the 4 delete_pass workarounds in runner.py are removed in sync by PaddlePaddle/PaddleX#5096 (partially overlapping with PaddlePaddle/PaddleX#5077).

Appendix: BF16 end-to-end benchmark (excerpted from BF16_BENCHMARK.md)

PaddleOCR-VL-1.5 on ROCm — FP32-fallback vs Native-BF16

End-to-end benchmark of the BF16 framework adaptation task described in TASK.md. Compares the status-quo PaddleX path (_keep_in_fp32_modules = ["visual", "mlp_AR"] forces the vision tower + multimodal projector to run in FP32 on ROCm) against the same pipeline with that list cleared at runtime so the entire model — vision encoder, projector, and LLM decoder — runs natively in BF16.

Environment

  • GPU: AMD Instinct MI300X VF (gfx942)
  • ROCm: 7.2.0 at /opt/rocm
  • Paddle: 0.0.0 on branch rocm7-dev (HEAD f2887a57dd) with 4 uncommitted HIP BF16 fixes applied and compiled into the installed wheel
  • PaddleX: 3.5.0 editable at PaddleX/ (release/3.5)
  • Pipeline: PaddleOCR-VL-native.yaml (batch_size=4096, native genai backend)
  • Input: test_ocr.png (Chinese boarding-pass photo)
  • Protocol: 1 warm-up run discarded, 3 timed runs per mode, paddle.device.synchronize() bracketing pipeline.predict(...), time.perf_counter wall-clock

Runs are gated by monkey-patching PaddleOCRVLForConditionalGeneration._keep_in_fp32_modules before create_pipeline(...) — no source edits in PaddleX or Paddle beyond the four already-uncommitted framework diffs.
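The timing protocol can be sketched as a small harness. `run` and `sync` are placeholders standing in for `pipeline.predict(...)` and `paddle.device.synchronize()`; the harness itself is illustrative, not the actual benchmark script.

```python
import statistics
import time

def bench(run, sync=lambda: None, warmup=1, repeats=3):
    """Discard warm-up runs, then wall-clock `repeats` runs with a device
    synchronize bracketing each timed call (sketch of the protocol above)."""
    for _ in range(warmup):
        run()                          # warm-up, not timed
    times = []
    for _ in range(repeats):
        sync()                         # drain pending GPU work
        t0 = time.perf_counter()
        run()
        sync()                         # wait for the run to finish
        times.append(time.perf_counter() - t0)
    return {"p50": statistics.median(times), "mean": statistics.fmean(times)}
```

Bracketing with synchronize on both sides matters: without it, asynchronously launched GPU work from a previous run would leak into the next measurement.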

End-to-end wall-clock

| Mode | run 1 | run 2 | run 3 | p50 | mean |
| --- | --- | --- | --- | --- | --- |
| FP32-fallback (status quo) | 3.941 s | 3.856 s | 4.395 s | 3.941 s | 4.064 s |
| Native BF16 (patched) | 3.842 s | 3.806 s | 3.780 s | 3.806 s | 3.809 s |
  • Wall-clock speedup: 1.04× (135 ms / run saved).
  • The modest wall-clock delta is dominated by CPU-side image preprocessing, layout detection (PP-DocLayoutV3 already runs in FP32 in both modes), tokenisation, and framework overhead — GPU-kernel time is only a fraction of the end-to-end pipeline.
  • Variance is lower in BF16 (stdev 0.03 s vs 0.27 s for FP32) — consistent with BF16 avoiding the one-off cost of allocating FP32 activations and casting.
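The quoted p50 speedup and per-run saving follow directly from the table; a quick check in Python:

```python
import statistics

fp32 = [3.941, 3.856, 4.395]  # seconds, FP32-fallback runs from the table
bf16 = [3.842, 3.806, 3.780]  # seconds, native-BF16 runs

p50_fp32 = statistics.median(fp32)        # 3.941 s
p50_bf16 = statistics.median(bf16)        # 3.806 s
speedup = p50_fp32 / p50_bf16             # ~1.04x on p50
saved_ms = (p50_fp32 - p50_bf16) * 1000   # ~135 ms per run
```

The speedup is computed on p50 rather than the mean, which is why the 4.395 s FP32 outlier does not inflate the headline number.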

Output equivalence

Text content is semantically identical between modes. Character-level diff:

  • Prefix match: 154 / 416 chars (37%).
  • First divergence: BF16 inserts a paragraph break between "TAIYUAN" and "身份识别ID NO.", which actually matches the original boarding-pass layout better. Remaining content is identical token-for-token; BF16 output is 417 chars vs 416 for FP32 (one extra newline).
  • All 3 FP32 runs produced the same text; all 3 BF16 runs produced the same text. Neither mode produced a run-to-run diff.

Verdict: BF16 is at least as correct as FP32-fallback on this document.
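The prefix-match figure above is a plain common-character-prefix count; a minimal helper for reproducing it (illustrative, not part of the benchmark script):

```python
def prefix_match(a: str, b: str) -> int:
    """Length of the longest common character prefix of two strings."""
    n = 0
    for ca, cb in zip(a, b):
        if ca != cb:
            break
        n += 1
    return n
```

A single inserted newline early in the output (as in the BF16 case) caps the prefix match even when the remaining text is identical, which is why 37% understates the actual agreement.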

Per-op kernel-level breakdown (rocprofv3 --kernel-trace --stats)

Totals aggregated across the 4 timed invocations per mode (3 runs + 1 warm-up).

| Category | FP32 calls | FP32 ms | BF16 calls | BF16 ms | Δ ms | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| GEMM (BF16 input) | 68 888 | 925.1 | 86 328 | 1 436.3 | −511.3 | 0.64× |
| GEMM (FP32 input) | 18 756 | 884.6 | 1 316 | 11.4 | +873.2 | 77.83× |
| Cast/Copy | 206 788 | 737.1 | 206 788 | 682.1 | +55.0 | 1.08× |
| Elementwise | 270 544 | 719.7 | 282 430 | 740.0 | −20.3 | 0.97× |
| Other | 119 188 | 502.7 | 114 780 | 414.9 | +87.8 | 1.21× |
| Concat/Split | 117 292 | 321.4 | 117 292 | 313.5 | +7.8 | 1.02× |
| Reduction | 36 004 | 144.5 | 36 004 | 142.2 | +2.3 | 1.02× |
| Softmax | 13 588 | 97.6 | 13 588 | 95.8 | +1.8 | 1.02× |
| SiLU | 11 424 | 31.0 | 11 424 | 30.8 | +0.3 | 1.01× |
| LayerNorm | 4 572 | 25.7 | 4 572 | 23.0 | +2.7 | 1.12× |
| Conv | 108 | 4.5 | 108 | 4.5 | −0.0 | 1.00× |
| …misc (MIOpen, Sort/TopK, GELU, Transpose, Scatter/Gather) | | 22.1 | | 20.1 | +2.0 | |
| Total kernel time | | 4 415.7 | | 3 915.5 | +500.2 | 1.13× |

The core finding: native BF16 eliminates ~873 ms of FP32 GEMM work (from 18 756 calls down to 1 316). Even after shifting some of that work onto BF16 GEMMs (+511 ms), the net GEMM saving is ~362 ms. Cast/Copy savings (~55 ms) come from no longer round-tripping the vision-encoder activations through FP32 at the boundary.
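The GEMM arithmetic, checked directly against the table's rounded values:

```python
# Delta values (ms) read off the kernel-level table above
fp32_gemm_removed = 884.6 - 11.4    # FP32 GEMM time eliminated
bf16_gemm_added = 1436.3 - 925.1    # extra BF16 GEMM time taken on
net_gemm_saving = fp32_gemm_removed - bf16_gemm_added
cast_copy_saving = 737.1 - 682.1    # boundary round-trip casts avoided
```

The BF16 GEMM row growing (0.64x "speedup") is expected: work that previously ran as FP32 GEMMs now runs as BF16 GEMMs, so it is a transfer, not a regression.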

Top-15 kernels by (FP32 − BF16) delta

| Δ ms | FP32 ms | BF16 ms | Category | Kernel (truncated) |
| --- | --- | --- | --- | --- |
| +250.5 | 250.5 | 0.0 | GEMM (FP32) | Cijk_Ailk_Bljk_SB_MT32x64x64_MI16x16x4x1_… |
| +232.3 | 232.3 | 0.0 | GEMM (FP32) | Cijk_Ailk_Bljk_SB_MT64x64x32_MI16x16x4x1_… |
| +109.1 | 156.9 | 47.9 | Cast/Copy | phi::funcs::VectorizedBroadcastKernel<Add…> |
| +102.3 | 102.3 | 0.0 | GEMM (FP32) | Cijk_Alik_Bljk_SB_MT64x32x64_MI16x16x4x1_… |
| +94.8 | 94.8 | 0.0 | GEMM (FP32) | Cijk_Ailk_Bljk_SB_MT128x128x32_MI16x16x4x1_… |
| +94.5 | 94.5 | 0.0 | GEMM (FP32) | Cijk_Ailk_Bljk_SB_MT64x64x32_MI16x16x4x1_… |
| +82.7 | 83.3 | 0.6 | Other | Eigen::internal::EigenMetaKernel<…> |
| +78.3 | 124.3 | 46.0 | Elementwise | phi::funcs::VectorizedElementwiseKernel<float> |
| +67.7 | 67.7 | 0.1 | GEMM (FP32) | Cijk_Ailk_Bljk_SB_MT16x16x64_MI16x16x4x1_… |
| +61.3 | 94.3 | 33.0 | Other | __amd_rocclr_copyBuffer |
| +55.1 | 82.6 | 27.5 | Cast/Copy | phi::funcs::VectorizedBroadcastKernel<Add…> |
| +27.1 | 29.0 | 1.9 | Cast/Copy | phi::funcs::VectorizedBroadcastKernel<Mul…> |
| +26.7 | 26.7 | 0.0 | Elementwise | phi::funcs::VectorizedElementwiseKernel<float> |
| +26.1 | 42.0 | 15.9 | Elementwise | phi::UnaryElementwiseKernel<ScaleFunctor<float>> |
| +25.2 | 25.2 | 0.0 | LayerNorm | phi::funcs::LayerNormForward<float, float, 512, …> |

SB = single-precision inputs (FP32); BBS = BF16 inputs. The entire column of FP32 GEMMs disappears in native-BF16 mode, replaced by the corresponding BBS variants (captured in the GEMM (BF16 input) row above).

Conclusion

The three framework changes — BF16 layer_norm registration, BF16 softmax routed through the matrix kernel (MIOpen miopenBFloat16 is NOT_IMPLEMENTED), and skipping the cuDNN-only conv2d_add[_act]_fuse_pass on ROCm — are sufficient for PaddleOCR-VL-1.5 to run natively in BF16 on MI300X with no FP32 fallbacks in the vision encoder or multimodal projector, matching the FP32-fallback path's output quality while reducing GPU-kernel time by 11%.

The wall-clock improvement on the full pipeline is smaller (4%) because layout detection, tokenisation, and CPU-side postprocessing dominate the end-to-end budget. For LLM-heavy workloads where the VLM decoder is the bottleneck, the kernel-level savings should translate more directly into throughput.

Reproducing

#### Both modes (writes /tmp/bench_both.json):

```shell
/root/workspace_paddle/.venv/bin/python /root/workspace_paddle/bench_paddleocr_vl.py \
  --json-out /tmp/bench_both.json
```

#### Per-mode with rocprof:

```shell
mkdir -p /tmp/rocprof_fp32 /tmp/rocprof_bf16
rocprofv3 --kernel-trace --stats --output-format csv --output-directory /tmp/rocprof_fp32 \
  -- /root/workspace_paddle/.venv/bin/python bench_paddleocr_vl.py --single-mode fp32
rocprofv3 --kernel-trace --stats --output-format csv --output-directory /tmp/rocprof_bf16 \
  -- /root/workspace_paddle/.venv/bin/python bench_paddleocr_vl.py --single-mode bf16
```

#### Aggregate:

```shell
/root/workspace_paddle/.venv/bin/python /tmp/bench_rocprof_postprocess.py \
  /tmp/rocprof_fp32/*/*_kernel_stats.csv \
  /tmp/rocprof_bf16/*/*_kernel_stats.csv
```
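A minimal sketch of what the aggregation step does. The actual /tmp/bench_rocprof_postprocess.py is not part of this PR, so the category patterns and the CSV column names ("Name", "TotalDurationNs") below are assumptions about the rocprofv3 kernel-stats output, not its documented schema.

```python
import csv
import re
from collections import defaultdict

# Illustrative buckets keyed off kernel-name substrings (assumed patterns;
# SB / BBS encode FP32 / BF16 GEMM inputs per the note under the table).
CATEGORIES = [
    ("GEMM (FP32 input)", re.compile(r"^Cijk_.*_SB_")),
    ("GEMM (BF16 input)", re.compile(r"^Cijk_.*_BBS_")),
    ("LayerNorm", re.compile(r"LayerNormForward")),
    ("Softmax", re.compile(r"Softmax")),
]

def categorize(name):
    for cat, pat in CATEGORIES:
        if pat.search(name):
            return cat
    return "Other"

def aggregate(path):
    """Sum total kernel time per category (ms) from one kernel-stats CSV."""
    totals = defaultdict(float)
    with open(path, newline="") as f:
        for row in csv.DictReader(f):
            totals[categorize(row["Name"])] += float(row["TotalDurationNs"]) / 1e6
    return dict(totals)
```

Run once per mode's CSV and diff the resulting dicts to obtain the Δ column.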

Does this change numerical precision?

No. All three changes are ROCm/HIP-only:

  • the layer_norm BF16 registration reuses the FP16 strategy of promoting mean / variance to FP32;
  • BF16 softmax takes the existing matrix kernel, whose semantics match what MIOpen delivers when it can run FP16 / FP32;
  • skipping the conv2d fusion passes on ROCm exists precisely to avoid a bad kernel dispatch; PaddleX already achieved the equivalent at runtime via delete_pass.

CUDA behavior is fully preserved.

Closes #78710.

Add phi::bfloat16 to the layer_norm / layer_norm_grad kernel registrations
under PADDLE_WITH_HIP so the existing templated implementation is exposed
for BF16 inputs on ROCm. Matches the FLOAT16 treatment of the mean/variance
output dtype (promoted to FLOAT32 for numerical stability).

Unblocks BF16 inference of the PaddleOCR-VL-1.5 SigLIP-style vision
encoder on MI300X (gfx942), which previously required PaddleX to keep the
whole visual + mlp_AR subgraph in FP32 via _keep_in_fp32_modules.

MIOpen (as of ROCm 7.x) returns MIOPEN_STATUS_NOT_IMPLEMENTED for
miopenSoftmaxForward_V2 with miopenBFloat16, so the gpudnn softmax path
cannot be used for BF16 on HIP. When the input dim exceeds the warp
softmax cap, route BF16 through the existing matrix softmax kernel
instead of letting the call fall into the MIOpen branch.

Also gate the CUDNN_VERSION < 8100 BF16 fallback specialization on
!defined(PADDLE_WITH_HIP) — that branch dispatched into MIOpen too and
would trip the same NOT_IMPLEMENTED failure on ROCm.

conv2d_add_fuse_pass and conv2d_add_act_fuse_pass rewrite conv2d+add[+act]
into the fused_conv2d_add_act op, which has only a cuDNN GPUDNN kernel.
On ROCm the rewrite succeeds but kernel dispatch later fails because no
HIP kernel is registered, so PaddleX currently works around this by
calling config.delete_pass("conv2d_add_act_fuse_pass") and
config.delete_pass("conv2d_add_fuse_pass") under paddle.is_compiled_with_rocm()
in paddlex/inference/models/runners/paddle_static/runner.py.

Gate both the pass registration (REGISTER_IR_PASS / USE_PIR_PASS) and the
pass-builder inclusion on PADDLE_WITH_CUDA so the rewrite never runs on
HIP builds, making the PaddleX delete_pass calls unnecessary.
@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
You have signed the CLA already but the status is still pending? Let us recheck it.


@paddle-bot

paddle-bot bot commented Apr 18, 2026

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.



Development

Successfully merging this pull request may close these issues.

[ROCm] HIP BF16 still lacks the layer_norm / softmax operators and the conv2d_add fusion passes; models such as PaddleOCR-VL cannot run full BF16 inference on AMD GPUs

5 participants