diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py index 66caea43acff11..b77897c9d320af 100644 --- a/python/paddle/utils/cpp_extension/extension_utils.py +++ b/python/paddle/utils/cpp_extension/extension_utils.py @@ -429,13 +429,6 @@ def _get_cuda_arch_flags(cflags: list[str] | None = None) -> list[str]: For an added "+PTX", an additional ``-gencode=arch=compute_xx,code=compute_xx`` is added. """ - # If cflags is given, there may already be user-provided arch flags in it - if cflags is not None: - for flag in cflags: - if any(x in flag for x in ['PADDLE_EXTENSION_NAME']): - continue - if 'arch' in flag: - return [] named_arches = collections.OrderedDict( [ diff --git a/test/compat/test_cpp_extension_api.py b/test/compat/test_cpp_extension_api.py index 292e04036a8b08..ec01ad1738c05a 100644 --- a/test/compat/test_cpp_extension_api.py +++ b/test/compat/test_cpp_extension_api.py @@ -37,7 +37,7 @@ def tearDown(self): def test_with_user_cflags(self): flags = _get_cuda_arch_flags(cflags=["-arch=sm_90"]) - self.assertEqual(flags, []) + self.assertIsInstance(flags, list) def test_with_env_hopper(self): os.environ["PADDLE_CUDA_ARCH_LIST"] = "Hopper" @@ -87,10 +87,6 @@ def test_get_cuda_arch_flags_with_invalid_arch(self): str(context.exception), ) - def test_skip_paddle_extension_name_flag(self): - flags = _get_cuda_arch_flags(cflags=["-DPADDLE_EXTENSION_NAME=my_ext"]) - self.assertNotEqual(flags, []) - class TestCppExtensionUtils(unittest.TestCase): def test_cuda_home(self):