Skip to content

Commit bffc64b

Browse files
committed
Updates and lint
1 parent bd86ac7 commit bffc64b

File tree

4 files changed

+56
-100
lines changed

4 files changed

+56
-100
lines changed

.github/workflows/_build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ jobs:
170170
export MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] && echo 1 || echo 2)
171171
export NVCC_THREADS=2
172172
export TORCH_CUDA_ARCH_LIST="7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX"
173-
export DEEP_GEMM_NO_LOCAL_VERSION=${{ inputs.use-local-version && 'FALSE' || 'TRUE' }}
173+
export DG_NO_LOCAL_VERSION=${{ inputs.use-local-version && '0' || '1' }}
174174
175175
# 5h timeout since GH allows max 6h and we want some buffer
176176
EXIT_CODE=0

.github/workflows/publish.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ jobs:
4141
# Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the
4242
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
4343
os: [ubuntu-22.04]
44-
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
44+
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
4545
torch-version: ["2.4.0", "2.5.1", "2.6.0", "2.7.1", "2.8.0"]
4646
cuda-version: ["12.9.1"]
4747
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
@@ -83,8 +83,8 @@ jobs:
8383
pip install torch --index-url https://download.pytorch.org/whl/cpu
8484
- name: Build core package
8585
env:
86-
DEEP_GEMM_NO_LOCAL_VERSION: "TRUE"
87-
DEEP_GEMM_SKIP_CUDA_BUILD: "TRUE"
86+
DG_NO_LOCAL_VERSION: "1"
87+
DG_SKIP_CUDA_BUILD: "1"
8888
run: |
8989
python setup.py sdist --dist-dir=dist
9090
- name: Deploy

deep_gemm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,4 @@ def _find_cuda_home() -> str:
7373
_find_cuda_home() # CUDA home
7474
)
7575

76-
__version__ = "2.0.0"
76+
__version__ = '2.0.0'

setup.py

Lines changed: 51 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,31 @@
11
import ast
22
import os
33
import re
4-
import setuptools
54
import shutil
5+
import setuptools
66
import subprocess
77
import sys
8-
import urllib
98
import torch
109
import platform
10+
import urllib
11+
import urllib.error
12+
import urllib.request
1113
from setuptools import find_packages
1214
from setuptools.command.build_py import build_py
13-
from torch.utils.cpp_extension import CUDAExtension, CUDA_HOME
14-
from pathlib import Path
1515
from packaging.version import parse
16+
from pathlib import Path
17+
from torch.utils.cpp_extension import CUDAExtension, CUDA_HOME
1618
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
1719

18-
SKIP_CUDA_BUILD = os.getenv("DEEP_GEMM_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
19-
NO_LOCAL_VERSION = os.getenv("DEEP_GEMM_NO_LOCAL_VERSION", "FALSE") == "TRUE"
20-
FORCE_BUILD = os.getenv("DEEP_GEMM_FORCE_BUILD", "FALSE") == "TRUE"
2120

22-
BASE_WHEEL_URL = (
23-
"https://github.com/DeepSeek-AI/DeepGEMM/releases/download/{tag_name}/{wheel_name}"
24-
)
25-
PACKAGE_NAME = "deep_gemm"
26-
27-
current_dir = os.path.dirname(os.path.realpath(__file__))
21+
# Compiler flags
2822
cxx_flags = ['-std=c++17', '-O3', '-fPIC', '-Wno-psabi', '-Wno-deprecated-declarations',
2923
f'-D_GLIBCXX_USE_CXX11_ABI={int(torch.compiled_with_cxx11_abi())}']
24+
if int(os.environ.get('DG_JIT_USE_RUNTIME_API', '0')):
25+
cxx_flags.append('-DDG_JIT_USE_RUNTIME_API')
26+
27+
# Sources
28+
current_dir = os.path.dirname(os.path.realpath(__file__))
3029
sources = ['csrc/python_api.cpp']
3130
build_include_dirs = [
3231
f'{CUDA_HOME}/include',
@@ -45,67 +44,60 @@
4544
'third-party/cutlass/include/cutlass',
4645
]
4746

48-
# Use runtime API
49-
if int(os.environ.get('DG_JIT_USE_RUNTIME_API', '0')):
50-
cxx_flags.append('-DDG_JIT_USE_RUNTIME_API')
47+
# Release
48+
base_wheel_url = 'https://github.com/DeepSeek-AI/DeepGEMM/releases/download/{tag_name}/{wheel_name}'
49+
5150

5251
def get_package_version():
53-
with open(Path(current_dir) / "deep_gemm" / "__init__.py", "r") as f:
54-
version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
52+
with open(Path(current_dir) / 'deep_gemm' / '__init__.py', 'r') as f:
53+
version_match = re.search(r'^__version__\s*=\s*(.*)$', f.read(), re.MULTILINE)
5554
public_version = ast.literal_eval(version_match.group(1))
56-
revision = ""
55+
revision = ''
5756

58-
if not NO_LOCAL_VERSION:
57+
if int(os.getenv('DG_NO_LOCAL_VERSION', '0')) == 0:
58+
# noinspection PyBroadException
5959
try:
60-
cmd = ["git", "rev-parse", "--short", "HEAD"]
61-
revision = "+" + subprocess.check_output(cmd).decode("ascii").rstrip()
60+
cmd = ['git', 'rev-parse', '--short', 'HEAD']
61+
revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
6262
except:
63-
revision = ""
63+
revision = ''
64+
return f'{public_version}{revision}'
6465

65-
return f"{public_version}{revision}"
6666

6767
def get_platform():
68-
"""
69-
Returns the platform name as used in wheel filenames.
70-
"""
71-
if sys.platform.startswith("linux"):
72-
return f"linux_{platform.uname().machine}"
73-
elif sys.platform == "darwin":
74-
mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
75-
return f"macosx_{mac_version}_x86_64"
76-
elif sys.platform == "win32":
77-
return "win_amd64"
68+
if sys.platform.startswith('linux'):
69+
return f'linux_{platform.uname().machine}'
7870
else:
79-
raise ValueError("Unsupported platform: {}".format(sys.platform))
71+
raise ValueError('Unsupported platform: {}'.format(sys.platform))
72+
8073

8174
def get_wheel_url():
82-
torch_version_raw = parse(torch.__version__)
83-
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
75+
torch_version = parse(torch.__version__)
76+
torch_version = f'{torch_version.major}.{torch_version.minor}'
77+
python_version = f'cp{sys.version_info.major}{sys.version_info.minor}'
8478
platform_name = get_platform()
85-
grouped_gemm_version = get_package_version()
86-
torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
79+
deep_gemm_version = get_package_version()
8780
cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
8881

8982
# Determine the version numbers that will be used to determine the correct wheel
9083
# We're using the CUDA version used to build torch, not the one currently installed
91-
# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
92-
torch_cuda_version = parse(torch.version.cuda)
93-
# For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
94-
# to save CI time. Minor versions should be compatible.
95-
torch_cuda_version = (
96-
parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
97-
)
98-
# cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
99-
cuda_version = f"{torch_cuda_version.major}"
84+
cuda_version = parse(torch.version.cuda)
85+
cuda_version = f'{cuda_version.major}'
10086

10187
# Determine wheel URL based on CUDA version, torch version, python version and OS
102-
wheel_filename = f"{PACKAGE_NAME}-{grouped_gemm_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
88+
wheel_filename = f'deep_gemm-{deep_gemm_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl'
89+
wheel_url = base_wheel_url.format(tag_name=f'v{deep_gemm_version}', wheel_name=wheel_filename)
90+
return wheel_url, wheel_filename
10391

104-
wheel_url = BASE_WHEEL_URL.format(
105-
tag_name=f"v{grouped_gemm_version}", wheel_name=wheel_filename
106-
)
10792

108-
return wheel_url, wheel_filename
93+
def get_ext_modules():
94+
if int(os.getenv('DG_SKIP_CUDA_BUILD', '0')) != 0:
95+
return []
96+
97+
return [CUDAExtension(name='deep_gemm_cpp',
98+
sources=sources,
99+
include_dirs=build_include_dirs,
                         extra_compile_args=cxx_flags)]
100+
109101

110102
class CustomBuildPy(build_py):
111103
def run(self):
@@ -145,60 +137,31 @@ def prepare_includes(self):
145137
# Copy the directory
146138
shutil.copytree(src_dir, dst_dir)
147139

148-
if not SKIP_CUDA_BUILD:
149-
ext_modules = [
150-
CUDAExtension(
151-
name="deep_gemm_cpp",
152-
sources=sources,
153-
include_dirs=build_include_dirs,
154-
)
155-
]
156-
else:
157-
ext_modules = []
158140

159141
class CachedWheelsCommand(_bdist_wheel):
160-
"""
161-
The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
162-
find an existing wheel (which is currently the case for all grouped gemm installs). We use
163-
the environment parameters to detect whether there is already a pre-built version of a compatible
164-
wheel available and short-circuits the standard full build pipeline.
165-
"""
166-
167142
def run(self):
168-
if FORCE_BUILD:
143+
if int(os.getenv('DG_FORCE_BUILD', '0')) != 0:
169144
return super().run()
170145

171146
wheel_url, wheel_filename = get_wheel_url()
172-
print("Guessing wheel URL: ", wheel_url)
147+
print(f'Try to download wheel from URL: {wheel_url}')
173148
try:
174149
urllib.request.urlretrieve(wheel_url, wheel_filename)
175150

176151
# Make the archive
177-
# Lifted from the root wheel processing command
178-
# https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
179152
if not os.path.exists(self.dist_dir):
180153
os.makedirs(self.dist_dir)
181-
182154
impl_tag, abi_tag, plat_tag = self.get_tag()
183-
archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
184-
185-
wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
186-
print("Raw wheel path", wheel_path)
155+
archive_basename = f'{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}'
156+
wheel_path = os.path.join(self.dist_dir, archive_basename + '.whl')
187157
os.rename(wheel_filename, wheel_path)
188158
except (urllib.error.HTTPError, urllib.error.URLError):
189-
print("Precompiled wheel not found. Building from source...")
159+
print('Precompiled wheel not found. Building from source...')
190160
# If the wheel could not be downloaded, build from source
191161
super().run()
192162

193163

194164
if __name__ == '__main__':
195-
# noinspection PyBroadException
196-
try:
197-
cmd = ['git', 'rev-parse', '--short', 'HEAD']
198-
revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
199-
except:
200-
revision = ''
201-
202165
# noinspection PyTypeChecker
203166
setuptools.setup(
204167
name='deep_gemm',
@@ -211,14 +174,7 @@ def run(self):
211174
'include/cutlass/**/*',
212175
]
213176
},
214-
ext_modules=[
215-
CUDAExtension(name='deep_gemm_cpp',
216-
sources=sources,
217-
include_dirs=build_include_dirs,
218-
libraries=build_libraries,
219-
library_dirs=build_library_dirs,
220-
extra_compile_args=cxx_flags)
221-
],
177+
ext_modules=get_ext_modules(),
222178
zip_safe=False,
223179
cmdclass={
224180
'build_py': CustomBuildPy,

0 commit comments

Comments
 (0)