Updated setup.py and pyproject.toml #279
Conversation
Changing authors and maintainers to reflect the project. Adding an email now that the domain supports forwarding. Updating classifiers. Version matching for torch and cmake.
This is the setup.py file from PR SHI-Labs#273, updated with the minor change needed for Blackwell Ultra. Minor change to src/natten/profiler: some f-strings reused " inside the replacement field (e.g. f"{", ".join(...)}."). This is not supported in all Python versions (reusing the outer quote inside an f-string expression requires Python 3.12+) and is bad practice, as it is less readable.
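For reference, a minimal illustration of the f-string issue; the list and message below are made up, only the quoting pattern matters:

```python
# Reusing the outer quote character inside an f-string expression only parses
# on Python 3.12+; switching to single quotes (or joining outside the f-string)
# keeps it portable and easier to read.
backends = ["fna", "fmha"]  # illustrative list, not NATTEN's actual output

# Only valid on Python 3.12+:
# msg = f"Supported backends: {", ".join(backends)}."

# Portable form:
msg = f"Supported backends: {', '.join(backends)}."
print(msg)
```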
```toml
readme = {file = "docs/README_pypi.md", content-type = "text/markdown"}

dependencies = [
    "torch>=2.8.0",
```
| "torch>=2.8.0", | |
| "torch", |
Can we relax this for now and just make it any version of torch?
Some pypi setups are stupid and will error out if the system constraints have locked the torch version. The best part is that this happens even when the version requirements are actually met.
But anyway, we do technically support 2.7, even 2.6.
If folks aren't using Flex, we can probably go even as far back as 2.0. They'd have to compile it themselves of course.
Added to my local stage, and I'm placing comments for additional clarity.
| "setuptools >= 64", | ||
| "torch >= 2.7", | ||
| "cmake >= 4.0", | ||
| "torch >= 2.8", |
| "torch >= 2.8", | |
| "torch", |
Same here. Let's accept any version of torch, and just raise warnings/errors in setup/cmake if it's unsupported... I really don't want to trust pypi to do the right thing here.
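A rough sketch of what that setup-time guard could look like; the threshold constants and the packaging-based parsing are placeholders, not NATTEN's actual checks:

```python
import warnings

import torch
from packaging.version import Version

# Illustrative thresholds only; the real minimums are still being decided in this PR.
RECOMMENDED_TORCH = Version("2.8")
MIN_BUILDABLE_TORCH = Version("2.0")


def check_torch_version() -> None:
    # Strip any local build tag like "+cu128" before comparing.
    installed = Version(torch.__version__.split("+")[0])
    if installed < MIN_BUILDABLE_TORCH:
        raise RuntimeError(
            f"torch {installed} is too old to build NATTEN (need >= {MIN_BUILDABLE_TORCH})."
        )
    if installed < RECOMMENDED_TORCH:
        warnings.warn(
            f"torch {installed} is older than {RECOMMENDED_TORCH}; "
            "some backends (e.g. Flex) may be unavailable."
        )
```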
```python
MIN_TORCH_VERSION : float = 2.6
MIN_CUDA_VERSION : float = 12.0
MIN_SM : int = 30
SUPPORTED_GPU_ARCH : list[int] = [90, 100, 103]
```
Can we rename SUPPORTED_GPU_ARCH to something like SMS_WITH_ARCH_SPECIFIC_FEATS?
I'm also very supportive of just using ARCH instead of SM in this entire file (SMS is hard to parse, even for me).
This is actually slightly more complicated.
SM90 and SM100/103 have backends that are specific to them.
But it is possible to get backends that not all of our supported architectures can run. For instance #278 will add new backends that will run on SM80 and later.
The thing that's special about the SM90 and SM100/103 backends is that they can only support one architecture (or arch family, but we're not using that concept directly here), and therefore their arch tags will have an `a` appended (e.g. `90a`) to enable the arch-specific ISA.
I'm thinking maybe we could have a list of backends, each with the architectures it supports.
What gets tricky is that we don't want to list out all the arches in backends like fna/fmha because it would be everything.
We could do a min and max arch... but I don't want to assume it's always contiguous.
I.e., it's unclear to me what SM101 and SM102 are, and whether they can support that backend -- and these numbers changed between CTK 12.8 and 13.0 too...
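One possible shape for that, purely as a sketch: backend names follow the `_AG_POLICY_*` keys, and the arch sets and minimums are illustrative, not a claim about what NATTEN actually supports.

```python
# Backends tied to a single arch (or arch family) get explicit SM sets, since
# their kernels are built with the arch-specific "a" tag (e.g. 90a).
ARCH_SPECIFIC_BACKENDS: dict[str, set[int]] = {
    "hopper-fna": {90},
    "blackwell-fna": {100, 103},
}

# Generic backends store only a minimum SM to avoid listing everything, at the
# cost of assuming the supported range is contiguous (the open question above).
GENERIC_BACKEND_MIN_SM: dict[str, int] = {
    "fna": 50,   # illustrative minimum
    "fmha": 50,  # illustrative minimum
}


def supports(backend: str, sm: int) -> bool:
    # Arch-specific backends require an exact match; generic ones a minimum.
    if backend in ARCH_SPECIFIC_BACKENDS:
        return sm in ARCH_SPECIFIC_BACKENDS[backend]
    return sm >= GENERIC_BACKEND_MIN_SM[backend]
```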
```python
_AG_POLICY_TUNABLES = {
    "reference": 2,
    "fna": 64,
    "fmha": 6,
    "hopper-fna": 8,
    "hopper-fna-bwd": 4,
    "blackwell-fna": 28,
    "blackwell-fna-bwd": 14,
}

if not HAS_CUDA_ARCH:
    HAS_CUDA = torch.cuda.is_available()

_AG_POLICIES_CONSTS = {
    "hopper-fmha": 1,
    "hopper-fmha-bwd": 1,
    "blackwell-fmha": 1,
    "blackwell-fmha-bwd": 1,
}

if HAS_CUDA:
    cuda_device = torch.cuda.get_device_properties(torch.cuda.current_device())
    sm = cuda_device.major + cuda_device.minor * 0.1
    CUDA_ARCH = f"{sm}"
    print(
        "`NATTEN_CUDA_ARCH` not set, but detected CUDA driver with PyTorch. "
        f"Building for {CUDA_ARCH=}."
    )

AG_POLICY_DEFAULT = _AG_POLICIES_CONSTS | _AG_POLICY_TUNABLES

assert torch.version.cuda is not None
TORCH_CUDA_VERSION = [x for x in torch.version.cuda.split(".")[:2]]
CUDA_TAG = "".join([x for x in TORCH_CUDA_VERSION])
CUDA_VERSION = [int(x) for x in TORCH_CUDA_VERSION]

assert CUDA_VERSION >= [12, 0], "NATTEN only supports CUDA 12.0 and above."
if CUDA_VERSION >= [12, 0] and IS_WINDOWS:
    print(
        "WARNING: Torch cmake will likely fail on Windows with CUDA 12.X. "
        "Please refer to NATTEN documentation to read more about the issue "
        "and how to get around it until the issue is fixed in torch."
    )

# Now this is more explicit
AG_POLICY_FINE = AG_POLICY_DEFAULT | _tune_ag_policy(_AG_POLICY_TUNABLES, 2)
AG_POLICY_COARSE = AG_POLICY_DEFAULT | _tune_ag_policy(_AG_POLICY_TUNABLES, 0.5)
```
I love this refactor, but can we eliminate the concept of consts? Today we're doing 1 file for the FMHAs, but this can change (and very soon will). Why don't we just merge them? I'm okay with them being set to 2 in the FINE.
Yeah, honestly I think _tune_ag_policy can handle that entirely. Not just to get rid of the consts; we can probably do even better than AG_POLICY_{FINE,COARSE} as well.
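A rough sketch of the merge under discussion, with illustrative values; the real dict has more backends, and whether the FMHA entries land at 1 or 2 is still being decided here:

```python
# Drop _AG_POLICIES_CONSTS entirely: the formerly-constant FMHA entries live in
# the tunable dict, and _tune_ag_policy clamps so scaling never drives them to 0.
_AG_POLICY_TUNABLES = {
    "reference": 2,
    "fna": 64,
    "fmha": 6,
    "hopper-fmha": 1,      # formerly a "const"; FINE scaling (x2) yields 2
    "blackwell-fmha": 1,   # formerly a "const"
    # remaining backends elided for brevity
}


def _tune_ag_policy(policy: dict, scale: float) -> dict:
    # Return a new dict so the shared defaults are never mutated in place.
    return {key: max(1, int(value * scale)) for key, value in policy.items()}


AG_POLICY_DEFAULT = dict(_AG_POLICY_TUNABLES)
AG_POLICY_FINE = _tune_ag_policy(_AG_POLICY_TUNABLES, 2)
AG_POLICY_COARSE = _tune_ag_policy(_AG_POLICY_TUNABLES, 0.5)
```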
```python
# Note: Union operator means last key wins.
def _tune_ag_policy(policy: dict, scale : float) -> dict:
    for key in policy:
        policy[key] = int(policy[key] * scale)
```
Suggested change:

```diff
-        policy[key] = int(policy[key] * scale)
+        policy[key] = max(1, int(policy[key] * scale))
```
Needs a clip so the 1s don't end up 0s when scale is < 1?
Added to local stage. Good catch
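A quick check of why the clamp matters, with toy values mirroring the policy dict:

```python
policy = {"reference": 2, "hopper-fmha": 1}

# Without the clamp, halving floors the 1 down to 0.
assert {k: int(v * 0.5) for k, v in policy.items()} == {"reference": 1, "hopper-fmha": 0}

# With max(1, ...), every entry stays at least 1.
assert {k: max(1, int(v * 0.5)) for k, v in policy.items()} == {"reference": 1, "hopper-fmha": 1}
```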
```python
# Also because we want CMake to build everything elsewhere, otherwise pypi will package
# build files.
build_dir = self.build_lib if NATTEN_BUILD_DIR is None else NATTEN_BUILD_DIR
if env['BUILD_DIR'] is not None:
```
Note: Current local stage has a context manager here. Commenting to remember to discuss. self.build_lib never actually gets called?
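For discussion, a minimal sketch of the context-manager idea; names like `build_directory` and the `env` lookup are placeholders, not the actual setup.py code:

```python
import os
import tempfile
from contextlib import contextmanager


@contextmanager
def build_directory(user_build_dir):
    # Use the user-provided directory if set; otherwise build in a temporary
    # directory so no build artifacts end up inside the packaged wheel.
    if user_build_dir is not None:
        os.makedirs(user_build_dir, exist_ok=True)
        yield user_build_dir
    else:
        with tempfile.TemporaryDirectory(prefix="natten-build-") as tmp_dir:
            yield tmp_dir


# Sketch of usage inside BuildExtension.build_extensions:
# with build_directory(env["BUILD_DIR"]) as build_dir:
#     configure_and_run_cmake(build_dir)
```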
```python
so_path_final = f"{self.build_lib}/{output_binary_name}"
if not os.path.exists(so_dir):
    os.makedirs(so_dir)
so_dir_final = os.path.join(self.build_lib,
```
I think so_dir_final should be removed, and the os.makedirs line should just be os.makedirs(os.path.dirname(so_path_final)). I think this is clearer.
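A small sketch of what that simplification could look like, pulled out as a standalone function for illustration; the real code works with `self.build_lib` and `output_binary_name` directly:

```python
import os


def ensure_output_path(build_lib: str, output_binary_name: str) -> str:
    # Derive the directory from the final .so path rather than keeping a
    # separate so_dir_final variable.
    so_path_final = os.path.join(build_lib, output_binary_name)
    os.makedirs(os.path.dirname(so_path_final), exist_ok=True)
    return so_path_final
```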
```python
##################
# Helper functions
##################
def _get_torch_cuda_version() -> float:
```
Note: Handle via cmake? Currently we rely on torch but we don't need to.
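For context, roughly what a torch-based helper like this ends up doing; the cmake-based detection being floated would replace it. The parsing below is a guess at the shape, not the actual implementation:

```python
import torch


def _get_torch_cuda_version() -> float:
    # torch.version.cuda is a string like "12.8" (or None on CPU-only builds).
    assert torch.version.cuda is not None, "This PyTorch build was compiled without CUDA."
    major, minor = torch.version.cuda.split(".")[:2]
    return float(f"{major}.{minor}")
```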
| "torch >= 2.7", | ||
| "cmake >= 4.0", | ||
| "torch >= 2.8", | ||
| "cmake >= 3.2", |
Note to review our versioning here
As discussed, we're splitting PR #273 up.
This PR is isolated to the changes in `pyproject.toml` and `setup.py` that do not involve `uv` (which is most of that PR).

Tested on EndeavourOS with a 4080S (sm89).
Changes
- `pyproject.toml`:
  - `torch` noted in the dependencies (not exactly needed, at least for now ¯\_(ツ)_/¯)
  - `torch` version (NEEDS CHECK)
- `src/natten/profiler.py`: f-string fix described above.
- `setup.py`:
  - `env` created to handle user environment variables.
  - `SUPPORTED_GPU_ARCH` and `CAT2ARCH` globals that are at the top of the file.
    - `SUPPORTED_GPU_ARCH`: called by `_arch_list_to_cmake_tags` to define `arch_list` as before.
    - `CAT2ARCH`: called in `autogen_kernel_instantiations`, where we now loop through these keys. Previously we defined the category options in this function with hard-coded values. We now have the function `_category_generator`, which takes in the arch and generates the `fna` and `fmha` variations of the forward and backward functions with the associated architecture name. If this 4-entry dict pattern stays, then we don't need to do anything except edit `CAT2ARCH` to add those architectures (in the `setup.py` file at least...). See the sketch after this list.
  - `NUM_SPLITS` modified to avoid redundancy. The function `_tune_ag_policy` was added to help generate this. The `_AG_POLICY_TUNABLES` globals are intended to be tuned; currently we either double or halve these values, giving `AG_POLICY_FINE` and `AG_POLICY_COARSE` respectively. The `AG_POLICY_*` dictionaries are built using the union operator to ensure the proper policy is implemented and dupes are avoided.
  - `if-raise` pattern
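A hedged sketch of the `CAT2ARCH` / `_category_generator` pattern described above; the arch values, key names, and return format are illustrative, and the actual globals and generated category strings in `setup.py` may differ:

```python
# Maps a kernel "category family" prefix to the arch that gets arch-specific
# instantiations; extending support should only require editing this dict.
CAT2ARCH: dict[str, int] = {
    "hopper": 90,
    "blackwell": 100,
}


def _category_generator(name: str, arch: int) -> dict[str, int]:
    # Generate the fna/fmha forward and backward categories for one arch family,
    # following the 4-entry pattern mentioned above.
    return {
        f"{name}-fna": arch,
        f"{name}-fna-bwd": arch,
        f"{name}-fmha": arch,
        f"{name}-fmha-bwd": arch,
    }


def autogen_kernel_instantiations() -> dict[str, int]:
    # Loop over CAT2ARCH instead of hard-coding category options.
    categories: dict[str, int] = {}
    for name, arch in CAT2ARCH.items():
        categories |= _category_generator(name, arch)
    return categories
```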
TODO:

- `torch` version
  Has some inconsistencies that we need to check. In my `setup.py` we have a minimum version of `2.6`, but in the current setup.py the version is `2.5`. This also doesn't match `pyproject.toml`. Which version do we support?
- CUDA version & detection
  We had discussed doing a different process for detecting this through `cmake`. Do we want that in this version? I think we save that for a different PR, since we are probably ready to merge this now and that will help us avoid merge conflicts.
- Minor (I can do this in the next PR)
  - `tmp_dir` near the top. I think we should place a context manager under `BuildExtension.build_extensions` to handle this. It will be a bit cleaner and will clarify the scope of this directory.
  - `NATTEN_BUILD_DIR` will never be `None`, so why don't we just wrap this in a context manager and get rid of the possibility of using `self.build_lib` as `BUILD_DIR`?