diff --git a/setup.py b/setup.py index 3df847dd5b..c6a3a7a51d 100644 --- a/setup.py +++ b/setup.py @@ -158,9 +158,11 @@ def get_extensions(): op_files.remove('./mmcv/ops/csrc/pytorch/cuda/bias_act_cuda.cu') cuda_args = os.getenv('MMCV_CUDA_ARGS') extra_compile_args = { - 'nvcc': [cuda_args, '-std=c++14'] if cuda_args else ['-std=c++14'], + 'nvcc': ['-std=c++14'], 'cxx': ['-std=c++14'], } + if cuda_args: + extra_compile_args['nvcc'].extend(cuda_args.split()) if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': define_macros += [('MMCV_WITH_CUDA', None)] extra_compile_args['nvcc'] += [ @@ -260,7 +262,10 @@ def get_extensions(): define_macros += [('MMCV_WITH_HIP', None)] define_macros += [('MMCV_WITH_CUDA', None)] cuda_args = os.getenv('MMCV_CUDA_ARGS') - extra_compile_args['nvcc'] = [cuda_args] if cuda_args else [] + if cuda_args: + extra_compile_args['nvcc'] = cuda_args.split() + else: + extra_compile_args['nvcc'] = [] op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \ glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \ glob.glob('./mmcv/ops/csrc/pytorch/cuda/*.cu') + \