From 17204c6dccf6210753bc8c0ca4c92278b60719c9 Mon Sep 17 00:00:00 2001 From: Raul Date: Tue, 14 Nov 2023 17:32:59 +0100 Subject: [PATCH] Set C++17 for latest pytorch versions. Add flags for CUDA 12 and 11.8 (#641) * Set C++17 for latest pytorch versions. Add flags for CUDA 12 and 11.8 * Update setup.py * remove import subprocess * more robust way to compare version --------- Co-authored-by: Jinze Xue --- setup.py | 41 ++++++++++++----------------------------- 1 file changed, 12 insertions(+), 29 deletions(-) diff --git a/setup.py b/setup.py index ac919404a..108cd730f 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,5 @@ import os -import subprocess +from packaging import version from setuptools import setup, find_packages from distutils import log import sys @@ -24,32 +24,6 @@ long_description = fh.read() -def maybe_download_cub(): - import torch - dirs = torch.utils.cpp_extension.include_paths(cuda=True) - for d in dirs: - cubdir = os.path.join(d, 'cub') - log.info(f'Searching for cub at {cubdir}...') - if os.path.isdir(cubdir): - log.info(f'Found cub in {cubdir}') - return [] - # if no cub, download it to include dir from github - if not os.path.isdir('./include/cub'): - if not os.path.exists('./include'): - os.makedirs('include') - commands = """ - echo "Downloading CUB library"; - wget -q https://github.com/NVIDIA/cub/archive/refs/tags/1.11.0.zip; - unzip -q 1.11.0.zip -d include; - mv include/cub-1.11.0/cub include; - echo "Removing unnecessary files"; - rm 1.11.0.zip; - rm -rf include/cub-1.11.0; - """ - subprocess.run(commands, shell=True, check=True, universal_newlines=True) - return [os.path.abspath("./include")] - - def cuda_extension(build_all=False): import torch from torch.utils.cpp_extension import CUDAExtension @@ -87,15 +61,24 @@ def cuda_extension(build_all=False): nvcc_args.append("-gencode=arch=compute_80,code=sm_80") if cuda_version >= 11.1: nvcc_args.append("-gencode=arch=compute_86,code=sm_86") + if cuda_version >= 11.8: + nvcc_args.append("-gencode=arch=compute_89,code=sm_89") + if cuda_version >= 12.0: + nvcc_args.append("-gencode=arch=compute_90,code=sm_90") + print("nvcc_args: ", nvcc_args) print('-' * 75) - include_dirs = [*maybe_download_cub(), os.path.abspath("torchani/cuaev/")] + include_dirs = [os.path.abspath("torchani/cuaev/")] + # Update C++ standard based on PyTorch version + pytorch_version = version.parse(torch.__version__) + cxx_args = ['-std=c++17'] if pytorch_version >= version.parse("2.1.0") else ['-std=c++14'] + return CUDAExtension( name='torchani.cuaev', pkg='torchani.cuaev', sources=["torchani/cuaev/cuaev.cpp", "torchani/cuaev/aev.cu"], include_dirs=include_dirs, - extra_compile_args={'cxx': ['-std=c++14'], 'nvcc': nvcc_args}) + extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) def cuaev_kwargs():