resolve multi-CUDA_ARCHITECTURES compilation conflicts #241
GACLove wants to merge 1 commit into thu-ml:main
Pull Request Overview
This PR resolves CUDA compilation conflicts when building for multiple GPU architectures by isolating build artifacts and adding validation for the CUDA_ARCHITECTURES environment variable.
- Enables building for multiple CUDA architectures by reading them from the CUDA_ARCHITECTURES environment variable
- Isolates build artifacts per architecture in separate subdirectories to prevent conflicts
- Filters NVCC flags so that each extension receives only the architecture-specific flags relevant to it
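The first bullet can be illustrated with a minimal sketch of how a comma-separated CUDA_ARCHITECTURES value might be turned into `-gencode` flag pairs. The function name `gencode_flags_from_env` and the exact flag layout are assumptions for illustration and are not taken from this PR's setup.py. Suffixed targets such as `sm_90a` are deliberately not handled here.

```python
import os


def gencode_flags_from_env(env=None, var="CUDA_ARCHITECTURES"):
    """Hypothetical helper: map "8.9,9.0" to nvcc -gencode flag pairs."""
    env = os.environ if env is None else env
    raw = env.get(var)
    flags = []
    if raw is None:
        # Variable not set: leave architecture selection to the caller.
        return flags
    for arch in raw.split(","):
        arch = arch.strip()
        if not arch:
            # Tolerate stray commas such as "8.9,,9.0".
            continue
        num = arch.replace(".", "")  # "8.9" -> "89"
        flags += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
    return flags


if __name__ == "__main__":
    print(gencode_flags_from_env({"CUDA_ARCHITECTURES": "8.9, 9.0"}))
```

Passing the environment as a mapping keeps the helper testable without mutating `os.environ`.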
    if cuda_architectures is not None:
        for arch in cuda_architectures.split(","):
            arch = arch.strip()
            if arch:
The code doesn't validate that each architecture value is in the decimal format mentioned in the PR description. Consider adding a check that each arch value matches the expected pattern (e.g., a regex check for the decimal format).
Suggested change:
- if arch:
+ if arch:
+     if not re.match(r"^\d+\.\d+$", arch):
+         raise ValueError(f"Invalid architecture value '{arch}' in CUDA_ARCHITECTURES. Expected decimal format like '8.0', '8.6', etc.")
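A self-contained version of this check might look like the sketch below. The helper name `validate_arch` and its `(major, minor)` return value are assumptions for illustration. `re.fullmatch` is used instead of `re.match` with `^...$` so that a trailing newline cannot slip through.

```python
import re

# Decimal "major.minor" form, e.g. "8.0", "8.9", "12.0".
_ARCH_PATTERN = re.compile(r"\d+\.\d+")


def validate_arch(arch):
    """Hypothetical helper: return (major, minor) or raise ValueError."""
    arch = arch.strip()
    if not _ARCH_PATTERN.fullmatch(arch):
        raise ValueError(
            f"Invalid architecture value {arch!r} in CUDA_ARCHITECTURES. "
            "Expected decimal format like '8.0' or '8.6'."
        )
    major, minor = arch.split(".")
    return int(major), int(minor)


if __name__ == "__main__":
    print(validate_arch("8.9"))
    try:
        validate_arch("sm_89")
    except ValueError as exc:
        print(exc)
```

Failing early with a message that names the environment variable makes a mistyped value easier to spot than a later nvcc error would.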
        skip_next = True
elif flag not in ["-gencode"]:
    filtered_flags.append(flag)
The flag filtering logic is duplicated between sm89 and sm90 extensions. Consider extracting this into a helper function to reduce code duplication and improve maintainability.
Suggested change:
+ sm89_arch_list = ["sm_89", "compute_89", "sm_90a", "compute_90a", "sm_120", "compute_120"]
+ filtered_flags = filter_nvcc_flags(NVCC_FLAGS, sm89_arch_list)
setup.py
Outdated
            filtered_flags.append(arch_flag)
        skip_next = True
elif flag not in ["-gencode"]:
    filtered_flags.append(flag)
This is duplicate code from the sm89 extension filtering logic. The same flag filtering pattern should be extracted into a reusable function.
Suggested change:
- filtered_flags.append(flag)
+ def filter_nvcc_flags_for_arch(nvcc_flags, arch_substrings):
+     filtered_flags = []
+     skip_next = False
+     for i, flag in enumerate(nvcc_flags):
+         if skip_next:
+             skip_next = False
+             continue
+         if flag == "-gencode":
+             if i + 1 < len(nvcc_flags):
+                 arch_flag = nvcc_flags[i + 1]
+                 if any(sub in arch_flag for sub in arch_substrings):
+                     filtered_flags.append(flag)
+                     filtered_flags.append(arch_flag)
+                 skip_next = True
+         elif flag not in ["-gencode"]:
+             filtered_flags.append(flag)
+     return filtered_flags
+ filtered_flags = filter_nvcc_flags_for_arch(NVCC_FLAGS, ["sm_90a", "compute_90a"])
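To make the intended behavior of such a helper concrete, here is a compact iterator-based sketch applied to a made-up flag list. It is an equivalent formulation written for illustration, not the code in this PR, and `sample_flags` is hypothetical data.

```python
def filter_nvcc_flags_for_arch(nvcc_flags, arch_substrings):
    """Keep non-gencode flags, plus only the -gencode pairs matching an arch."""
    filtered = []
    flags = iter(nvcc_flags)
    for flag in flags:
        if flag != "-gencode":
            filtered.append(flag)
            continue
        # Consume the argument that follows -gencode, if there is one.
        arch_flag = next(flags, None)
        if arch_flag is not None and any(sub in arch_flag for sub in arch_substrings):
            filtered += [flag, arch_flag]
    return filtered


# Hypothetical flag list mixing general options and two architectures.
sample_flags = [
    "-O3",
    "-std=c++17",
    "-gencode", "arch=compute_89,code=sm_89",
    "-gencode", "arch=compute_90a,code=sm_90a",
    "--use_fast_math",
]

sm90_flags = filter_nvcc_flags_for_arch(sample_flags, ["sm_90a", "compute_90a"])

if __name__ == "__main__":
    print(sm90_flags)
```

Because matching is by substring, a pattern such as `"sm_90"` would also match `"sm_90a"`, so the substrings passed in should be as specific as the targets require.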
setup.py
Outdated
  "output_dir": os.path.join(
      kwargs["output_dir"],
-     self.thread_ext_name_map[threading.current_thread().ident]),
+     self.thread_ext_name_map.get(threading.current_thread().ident, "default")),
Using 'default' as a fallback directory name could lead to conflicts if multiple threads don't have mapped extension names. Consider using a more unique identifier like thread ID or timestamp.
Suggested change:
- self.thread_ext_name_map.get(threading.current_thread().ident, "default")),
+ self.thread_ext_name_map.get(
+     threading.current_thread().ident,
+     f"thread_{threading.current_thread().ident}"
+ )),
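The fallback discussed in this comment can be tried out in isolation with the sketch below. The function `extension_output_dir` is a hypothetical stand-in for the expression inside the build command, and `thread_ext_name_map` is assumed to be a plain dict keyed by thread ident.

```python
import os
import threading


def extension_output_dir(base_dir, thread_ext_name_map):
    """Hypothetical helper: pick a per-extension subdirectory for this thread."""
    ident = threading.current_thread().ident
    # An unmapped thread gets a directory derived from its own ident, so two
    # unmapped threads alive at the same time do not share one directory.
    subdir = thread_ext_name_map.get(ident, f"thread_{ident}")
    return os.path.join(base_dir, subdir)


if __name__ == "__main__":
    ident = threading.current_thread().ident
    print(extension_output_dir("build", {ident: "sm89_ext"}))
    print(extension_output_dir("build", {}))
```

Thread idents can be reused after a thread exits, so this fallback only separates threads that are alive concurrently.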
Testing:
pip install -v --no-cache-dir .
CUDA_ARCHITECTURES="9.0,12.0" pip install -v --no-cache-dir .
CUDA_ARCHITECTURES="8.9,9.0" pip install -v --no-cache-dir .
…d architectures via environment variable. Refactor GPU capability checks and streamline NVCC flags for SM89 and SM90 extensions. Improve build process by creating separate output directories for extensions.
fix(build): resolve multi-CUDA_ARCHITECTURES compilation conflicts
build/sm_{arch}/ subdirectories