41 changes: 40 additions & 1 deletion custom_ops/setup_ops.py
@@ -22,9 +22,48 @@
from pathlib import Path

import paddle
from paddle.utils.cpp_extension import CppExtension, CUDAExtension, setup
from paddle.utils.cpp_extension import (
    CppExtension,
    CUDAExtension,
    extension_utils,
    setup,
)
from setuptools import find_namespace_packages, find_packages

# Workaround for Paddle PR #78704:
# Paddle 3.5.0.dev20260418+ changed CUDAExtension behavior to auto-add gencode flags
# based on PADDLE_CUDA_ARCH_LIST even when user provides arch flags in cflags.
# This causes relocation overflow in large CUDA files (e.g., append_attention.cu).
#
# This patch suppresses Paddle's auto-gencode addition when user-provided gencode
# flags are detected, preventing duplicate architecture specifications.
_original_get_cuda_arch_flags = extension_utils._get_cuda_arch_flags


def _patched_get_cuda_arch_flags(cflags=None):
    """
    Patched version that returns empty list when user-provided gencode flags are detected.

    This prevents Paddle from auto-adding duplicate gencode flags based on
    PADDLE_CUDA_ARCH_LIST, which would cause relocation overflow errors.
    """
    if cflags:
        for flag in cflags:
            if isinstance(flag, str) and (flag.startswith("-gencode") or "compute_" in flag or "sm_" in flag):

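🟡 Suggestion: the flag-detection logic risks false matches.

`"compute_" in flag` and `"sm_" in flag` use substring matching, so they can match flags that are not gencode flags (for example, an `-I` include path that contains `sm_` or `compute_`). The probability is low in the current scenario, but as a general-purpose patch function the check could be more precise.

Consider tightening the match condition, for example by checking only for the `-gencode` and `-arch` prefixes: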

if isinstance(flag, str) and (flag.startswith("-gencode") or flag.startswith("-arch")):
    return []
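
To make the difference concrete, here is a minimal, self-contained sketch that compares the substring check from the diff with a prefix-only check. The helper names `has_user_arch_flags` and `has_user_arch_flags_substring` are hypothetical, the include path is made up for illustration, and the two long-form nvcc spellings are an assumption beyond the `-gencode` and `-arch` prefixes named above.

```python
def has_user_arch_flags(cflags):
    """Return True if any element starts with an nvcc architecture option."""
    # Assumption: these spellings are the ones worth recognising. The review
    # comment itself only names -gencode and -arch; the long forms are added
    # here for illustration.
    arch_prefixes = ("-gencode", "-arch", "--generate-code", "--gpu-architecture")
    return any(
        isinstance(flag, str) and flag.startswith(arch_prefixes)
        for flag in (cflags or [])
    )


def has_user_arch_flags_substring(cflags):
    """The substring-based check from the diff, kept only for comparison."""
    return any(
        isinstance(flag, str)
        and (flag.startswith("-gencode") or "compute_" in flag or "sm_" in flag)
        for flag in (cflags or [])
    )


# A hypothetical include path that happens to contain "sm_".
include_only = ["-O3", "-I/opt/sm_kernels/include"]
# A typical user-provided architecture specification.
with_gencode = ["-O3", "-gencode", "arch=compute_90,code=sm_90"]
```

With these inputs the substring check treats `include_only` as if it carried architecture flags, while the prefix check does not. Both checks recognise `with_gencode`, because the `-gencode` element itself is matched before its `arch=...` argument is reached.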

                return []
    return _original_get_cuda_arch_flags(cflags)


extension_utils._get_cuda_arch_flags = _patched_get_cuda_arch_flags


# Additional safeguard (important):
# Some Paddle versions may have additional internal methods that add gencode flags.
# This patch serves as a second line of defense by overriding such methods.
if hasattr(extension_utils, "CUDAExtension"):

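🟡 Suggestion: the `extension_utils` module is unlikely to have a `CUDAExtension` attribute.

`CUDAExtension` is a function/class imported from `paddle.utils.cpp_extension`, not an attribute of the `extension_utils` module. `hasattr(extension_utils, "CUDAExtension")` is therefore most likely `False`, which makes this "second line of defense" dead code that provides no protection.

Suggestions:

  1. Confirm whether the `extension_utils` module in the relevant Paddle version really has a `CUDAExtension` attribute; if it does not, remove this code so it does not mislead readers;
  2. If extra protection is genuinely needed, consider patching `paddle.utils.cpp_extension.CUDAExtension` itself.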
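
One way to keep such a guard from silently becoming dead code is to report when the patch target is missing. The sketch below is illustrative only: `patch_if_present` is a hypothetical helper, and `fake_utils` is a stand-in object rather than Paddle's real `extension_utils` module, so it says nothing about which attributes Paddle actually exposes.

```python
import types
import warnings


def patch_if_present(owner, name, replacement):
    """Replace owner.<name> with replacement; warn and return False if absent."""
    if not hasattr(owner, name):
        # Surfacing the miss makes a no-op safeguard visible at build time.
        warnings.warn(f"{owner!r} has no attribute {name!r}; patch skipped")
        return False
    setattr(owner, name, replacement)
    return True


# Stand-in for a module that has the function but not the class attribute.
fake_utils = types.SimpleNamespace(
    _get_cuda_arch_flags=lambda cflags=None: ["-gencode"]
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # The target is missing, so this returns False and records a warning.
    applied = patch_if_present(fake_utils, "CUDAExtension", object)
    # The target exists, so it is replaced and True is returned.
    found = patch_if_present(fake_utils, "_get_cuda_arch_flags", lambda cflags=None: [])
```

A plain `hasattr` guard would skip the first patch without any signal. Returning a boolean and emitting a warning lets the build script, or a reviewer reading the log, notice that the second line of defense never ran.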

    if hasattr(extension_utils.CUDAExtension, "_add_cuda_arch_flags"):
        extension_utils.CUDAExtension._add_cuda_arch_flags = lambda self, flags: flags


def load_module_from_path(module_name, path):
    """