Skip to content

Commit

Permalink
Make install work on CUDA < 11
Browse files Browse the repository at this point in the history
  • Loading branch information
sxyu committed Oct 20, 2021
1 parent eccfb9f commit 578320d
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from setuptools import setup
import os
import os.path as osp
import warnings

from torch.utils.cpp_extension import BuildExtension, CUDAExtension

Expand All @@ -10,6 +12,26 @@

CUDA_FLAGS = []
INSTALL_REQUIREMENTS = []
include_dirs = [osp.join(ROOT_DIR, "svox2", "csrc", "include")]

# From PyTorch3D
cub_home = os.environ.get("CUB_HOME", None)
if cub_home is None:
prefix = os.environ.get("CONDA_PREFIX", None)
if prefix is not None and os.path.isdir(prefix + "/include/cub"):
cub_home = prefix + "/include"

if cub_home is None:
warnings.warn(
"The environment variable `CUB_HOME` was not found."
"Installation will fail if your system CUDA toolkit version is less than 11."
"NVIDIA CUB can be downloaded "
"from `https://github.com/NVIDIA/cub/releases`. You can unpack "
"it to a location of your choice and set the environment variable "
"`CUB_HOME` to the folder containing the `CMakeListst.txt` file."
)
else:
include_dirs.append(os.path.realpath(cub_home).replace("\\ ", " "))

try:
ext_modules = [
Expand All @@ -20,8 +42,8 @@
'svox2/csrc/misc_kernel.cu',
'svox2/csrc/loss_kernel.cu',
'svox2/csrc/optim_kernel.cu',
], include_dirs=[osp.join(ROOT_DIR, "svox2", "csrc", "include"),],
optional=True),
], include_dirs=include_dirs,
optional=False),
]
except:
import warnings
Expand Down

0 comments on commit 578320d

Please sign in to comment.