diff --git a/setup.py b/setup.py index d93b78d8..4d305e85 100644 --- a/setup.py +++ b/setup.py @@ -22,22 +22,27 @@ def get_ext(): extra_objects = [] library_dirs = library_paths() libraries = ["c10", "torch", "torch_cpu", "torch_python"] - extra_link_args = [] + for i in library_dirs: + extra_link_args = ['-Wl,-rpath,' + i] - dipu_root = _getenv_or_die("DIPU_ROOT") + dipu_path = _getenv_or_die("DIPU_PATH") diopi_path = _getenv_or_die("DIOPI_PATH") + torch_dipu_path = os.path.join(dipu_path, 'torch_dipu') + dipu_lib_path = torch_dipu_path + + extra_link_args += ['-Wl,-rpath,' + dipu_lib_path] vendor_include_dirs = os.getenv("VENDOR_INCLUDE_DIRS") nccl_include_dirs = os.getenv("NCCL_INCLUDE_DIRS") # nv所需 system_include_dirs += [ - dipu_root, - os.path.join(dipu_root, "dist/include"), - os.path.join(diopi_path, "include"), + torch_dipu_path, + os.path.join(torch_dipu_path, "dist/include"), + os.path.join(diopi_path, "proto/include"), ] if vendor_include_dirs: system_include_dirs.append(vendor_include_dirs) if nccl_include_dirs: system_include_dirs.append(nccl_include_dirs) - library_dirs += [dipu_root] + library_dirs += [dipu_lib_path] libraries += ["torch_dipu"] extra_compile_args = {"cxx": []}