diff --git a/setup.py b/setup.py index ca9ceb2fdf5..60351a67909 100644 --- a/setup.py +++ b/setup.py @@ -54,6 +54,9 @@ def write_version_file(): with open(version_path, 'w') as f: f.write("__version__ = '{}'\n".format(version)) f.write("git_version = {}\n".format(repr(sha))) + f.write("from torchvision import _C\n") + f.write("if hasattr(_C, 'CUDA_VERSION'):\n") + f.write(" cuda = _C.CUDA_VERSION\n") write_version_file() diff --git a/torchvision/__init__.py b/torchvision/__init__.py index 82ba966dd5a..4af0c08a8aa 100644 --- a/torchvision/__init__.py +++ b/torchvision/__init__.py @@ -33,3 +33,31 @@ def get_image_backend(): Gets the name of the package used to load images """ return _image_backend + + +def _check_cuda_matches(): + """ + Make sure that CUDA versions match between the pytorch install and torchvision install + """ + import torch + from torchvision import _C + if hasattr(_C, "CUDA_VERSION") and torch.version.cuda is not None: + tv_version = str(_C.CUDA_VERSION) + if int(tv_version) < 10000: + tv_major = int(tv_version[0]) + tv_minor = int(tv_version[2]) + else: + tv_major = int(tv_version[0:2]) + tv_minor = int(tv_version[3]) + t_version = torch.version.cuda + t_version = t_version.split('.') + t_major = int(t_version[0]) + t_minor = int(t_version[1]) + if t_major != tv_major or t_minor != tv_minor: + raise RuntimeError("Detected that PyTorch and torchvision were compiled with different CUDA versions. " + "PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. " + "Please reinstall the torchvision that matches your PyTorch install." + .format(t_major, t_minor, tv_major, tv_minor)) + + +_check_cuda_matches() diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index ca62dcf5139..4777d70a38b 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -2,10 +2,17 @@ #include "ROIPool.h" #include "nms.h" +#ifdef WITH_CUDA +#include +#endif + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("nms", &nms, "non-maximum suppression"); m.def("roi_align_forward", &ROIAlign_forward, "ROIAlign_forward"); m.def("roi_align_backward", &ROIAlign_backward, "ROIAlign_backward"); m.def("roi_pool_forward", &ROIPool_forward, "ROIPool_forward"); m.def("roi_pool_backward", &ROIPool_backward, "ROIPool_backward"); +#ifdef WITH_CUDA + m.attr("CUDA_VERSION") = CUDA_VERSION; +#endif }