1
1
import ast
2
2
import os
3
3
import re
4
- import setuptools
5
4
import shutil
5
+ import setuptools
6
6
import subprocess
7
7
import sys
8
- import urllib
9
8
import torch
10
9
import platform
10
+ import urllib
11
+ import urllib .error
12
+ import urllib .request
11
13
from setuptools import find_packages
12
14
from setuptools .command .build_py import build_py
13
- from torch .utils .cpp_extension import CUDAExtension , CUDA_HOME
14
- from pathlib import Path
15
15
from packaging .version import parse
16
+ from pathlib import Path
17
+ from torch .utils .cpp_extension import CUDAExtension , CUDA_HOME
16
18
from wheel .bdist_wheel import bdist_wheel as _bdist_wheel
17
19
18
- SKIP_CUDA_BUILD = os .getenv ("DEEP_GEMM_SKIP_CUDA_BUILD" , "FALSE" ) == "TRUE"
19
- NO_LOCAL_VERSION = os .getenv ("DEEP_GEMM_NO_LOCAL_VERSION" , "FALSE" ) == "TRUE"
20
- FORCE_BUILD = os .getenv ("DEEP_GEMM_FORCE_BUILD" , "FALSE" ) == "TRUE"
21
20
22
- BASE_WHEEL_URL = (
23
- "https://github.com/DeepSeek-AI/DeepGEMM/releases/download/{tag_name}/{wheel_name}"
24
- )
25
- PACKAGE_NAME = "deep_gemm"
26
-
27
- current_dir = os .path .dirname (os .path .realpath (__file__ ))
21
+ # Compiler flags
28
22
cxx_flags = ['-std=c++17' , '-O3' , '-fPIC' , '-Wno-psabi' , '-Wno-deprecated-declarations' ,
29
23
f'-D_GLIBCXX_USE_CXX11_ABI={ int (torch .compiled_with_cxx11_abi ())} ' ]
24
+ if int (os .environ .get ('DG_JIT_USE_RUNTIME_API' , '0' )):
25
+ cxx_flags .append ('-DDG_JIT_USE_RUNTIME_API' )
26
+
27
+ # Sources
28
+ current_dir = os .path .dirname (os .path .realpath (__file__ ))
30
29
sources = ['csrc/python_api.cpp' ]
31
30
build_include_dirs = [
32
31
f'{ CUDA_HOME } /include' ,
45
44
'third-party/cutlass/include/cutlass' ,
46
45
]
47
46
48
- # Use runtime API
49
- if int ( os . environ . get ( 'DG_JIT_USE_RUNTIME_API' , '0' )):
50
- cxx_flags . append ( '-DDG_JIT_USE_RUNTIME_API' )
47
+ # Release
48
+ base_wheel_url = 'https://github.com/DeepSeek-AI/DeepGEMM/releases/download/{tag_name}/{wheel_name}'
49
+
51
50
52
51
def get_package_version():
    """Return the package version, optionally suffixed with the git revision.

    The public version is read from the ``__version__`` assignment in
    ``deep_gemm/__init__.py`` next to this script. Unless the environment
    variable ``DG_NO_LOCAL_VERSION`` is set to a non-zero integer, the short
    hash of the current git ``HEAD`` is appended as ``+<hash>``.

    Returns:
        The version string, e.g. ``'1.0.0'`` or ``'1.0.0+abc1234'``.

    Raises:
        RuntimeError: if no ``__version__`` assignment can be found.
    """
    init_file = Path(current_dir) / 'deep_gemm' / '__init__.py'
    with open(init_file, 'r') as f:
        version_match = re.search(r'^__version__\s*=\s*(.*)$', f.read(), re.MULTILINE)
    if version_match is None:
        # Fail with a readable message instead of an AttributeError on None
        raise RuntimeError(f'Unable to find `__version__` in {init_file}')
    public_version = ast.literal_eval(version_match.group(1))
    revision = ''

    if int(os.getenv('DG_NO_LOCAL_VERSION', '0')) == 0:
        # Best effort: the revision is simply omitted when git or the repository is unavailable.
        # `Exception` (not a bare `except`) so Ctrl-C and `SystemExit` still propagate.
        # noinspection PyBroadException
        try:
            cmd = ['git', 'rev-parse', '--short', 'HEAD']
            # Query the source tree itself, not whatever the current working directory is
            revision = '+' + subprocess.check_output(cmd, cwd=current_dir).decode('ascii').rstrip()
        except Exception:
            revision = ''
    return f'{public_version}{revision}'
66
66
67
67
def get_platform():
    """Return the platform name as used in wheel file names.

    Returns:
        ``'linux_<machine>'``, where ``<machine>`` comes from ``platform.uname()``.

    Raises:
        ValueError: if ``sys.platform`` does not start with ``'linux'``.
    """
    # Guard clause: anything but Linux is rejected up front
    if not sys.platform.startswith('linux'):
        raise ValueError('Unsupported platform: {}'.format(sys.platform))

    machine = platform.uname().machine
    return f'linux_{machine}'
72
+
80
73
81
74
def get_wheel_url():
    """Build the download URL and file name of the matching pre-built wheel.

    The wheel is selected by the package version, the major CUDA version torch
    was built with, the torch ``major.minor`` version, the C++11 ABI setting
    of torch, the running Python version and the platform.

    Returns:
        A ``(wheel_url, wheel_filename)`` tuple of strings.

    Raises:
        RuntimeError: if the installed torch reports no CUDA version.
        ValueError: propagated from ``get_platform()`` on unsupported platforms.
    """
    torch_release = parse(torch.__version__)
    torch_version = f'{torch_release.major}.{torch_release.minor}'
    python_version = f'cp{sys.version_info.major}{sys.version_info.minor}'
    platform_name = get_platform()
    deep_gemm_version = get_package_version()
    # Public accessor for the ABI flag, the same call used to build `cxx_flags`
    cxx11_abi = str(torch.compiled_with_cxx11_abi()).upper()

    # Determine the version numbers that will be used to determine the correct wheel
    # We're using the CUDA version used to build torch, not the one currently installed
    if torch.version.cuda is None:
        # Torch builds without CUDA report `None`; fail with a clear message
        raise RuntimeError('The installed PyTorch was not built with CUDA, '
                           'cannot determine a matching pre-built wheel')
    cuda_version = f'{parse(torch.version.cuda).major}'

    # Determine wheel URL based on CUDA version, torch version, python version and OS
    wheel_filename = (f'deep_gemm-{deep_gemm_version}+cu{cuda_version}torch{torch_version}'
                      f'cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl')
    wheel_url = base_wheel_url.format(tag_name=f'v{deep_gemm_version}', wheel_name=wheel_filename)
    return wheel_url, wheel_filename
93
def get_ext_modules():
    """Return the extension modules to hand to ``setuptools.setup``.

    Returns:
        An empty list when the environment variable ``DG_SKIP_CUDA_BUILD`` is
        set to a non-zero integer, otherwise a one-element list holding the
        ``deep_gemm_cpp`` ``CUDAExtension``.
    """
    # `os.getenv` returns a string; convert before comparing, otherwise the
    # comparison with the integer 0 is always true and the build is always skipped
    if int(os.getenv('DG_SKIP_CUDA_BUILD', '0')) != 0:
        return []

    # Pass the same link and compile settings the inline `setup()` call used, so that
    # `cxx_flags` (including `-DDG_JIT_USE_RUNTIME_API`) actually reaches the compiler
    return [CUDAExtension(name='deep_gemm_cpp',
                          sources=sources,
                          include_dirs=build_include_dirs,
                          libraries=build_libraries,
                          library_dirs=build_library_dirs,
                          extra_compile_args=cxx_flags)]
100
+
109
101
110
102
class CustomBuildPy (build_py ):
111
103
def run (self ):
@@ -145,60 +137,31 @@ def prepare_includes(self):
145
137
# Copy the directory
146
138
shutil .copytree (src_dir , dst_dir )
147
139
148
- if not SKIP_CUDA_BUILD :
149
- ext_modules = [
150
- CUDAExtension (
151
- name = "deep_gemm_cpp" ,
152
- sources = sources ,
153
- include_dirs = build_include_dirs ,
154
- )
155
- ]
156
- else :
157
- ext_modules = []
158
140
159
141
class CachedWheelsCommand(_bdist_wheel):
    """``bdist_wheel`` command that prefers a pre-built wheel over a local build.

    Unless the environment variable ``DG_FORCE_BUILD`` is set to a non-zero
    integer, ``run`` first tries to download the wheel named by
    ``get_wheel_url()`` and places it where the regular command would have
    written its output. If the download fails, the normal build runs.
    """

    def run(self):
        if int(os.getenv('DG_FORCE_BUILD', '0')) != 0:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print(f'Try to download wheel from URL: {wheel_url}')
        try:
            # Only the download is guarded; errors while placing the file must surface
            urllib.request.urlretrieve(wheel_url, wheel_filename)
        except urllib.error.URLError:
            # `HTTPError` is a subclass of `URLError`, so both are handled here
            print('Precompiled wheel not found. Building from source...')
            # If the wheel could not be downloaded, build from source
            super().run()
            return

        # Make the archive
        os.makedirs(self.dist_dir, exist_ok=True)
        impl_tag, abi_tag, plat_tag = self.get_tag()
        archive_basename = f'{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}'
        wheel_path = os.path.join(self.dist_dir, archive_basename + '.whl')
        # `shutil.move` also works when `dist_dir` is on another filesystem,
        # where `os.rename` fails with a cross-device link error
        shutil.move(wheel_filename, wheel_path)
192
162
193
163
194
164
if __name__ == '__main__' :
195
- # noinspection PyBroadException
196
- try :
197
- cmd = ['git' , 'rev-parse' , '--short' , 'HEAD' ]
198
- revision = '+' + subprocess .check_output (cmd ).decode ('ascii' ).rstrip ()
199
- except :
200
- revision = ''
201
-
202
165
# noinspection PyTypeChecker
203
166
setuptools .setup (
204
167
name = 'deep_gemm' ,
@@ -211,14 +174,7 @@ def run(self):
211
174
'include/cutlass/**/*' ,
212
175
]
213
176
},
214
- ext_modules = [
215
- CUDAExtension (name = 'deep_gemm_cpp' ,
216
- sources = sources ,
217
- include_dirs = build_include_dirs ,
218
- libraries = build_libraries ,
219
- library_dirs = build_library_dirs ,
220
- extra_compile_args = cxx_flags )
221
- ],
177
+ ext_modules = get_ext_modules (),
222
178
zip_safe = False ,
223
179
cmdclass = {
224
180
'build_py' : CustomBuildPy ,
0 commit comments