
Commit 7d89d3c

[Feat] Add code format check (#53)
1. Add .pre-commit-config:
   i.   include ruff-check, yapf, codespell
   ii.  keep the code style checks consistent with the settings in pyproject
   iii. check code style on git commit and in ci.yml
1 parent 2935548 commit 7d89d3c


47 files changed: +969 −700 lines

.github/workflows/ci.yml

Lines changed: 38 additions & 13 deletions
@@ -2,19 +2,44 @@ name: hello-self-hosted
 
 on:
   push:
-    branches: ['**'] # 所有分支
-    tags: ['**'] # 所有标签推送也触发,可按需删掉
+    branches: ['**'] # All branches
+    tags: ['**'] # All tags push will also trigger, can be removed as needed
   pull_request:
-    branches: ['**'] # 所有分支的 PR 都触发 CI
+    branches: ['**'] # All branch PRs trigger CI
 
 jobs:
+  pre-commit:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 0 # Equivalent to GIT_STRATEGY: fetch
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.10" # At least 3.9 to avoid pre-commit-hooks reporting Python version too low
+
+      - name: Run pre-commit
+        # Official pre-commit Action, automatically:
+        # 1) Install pre-commit
+        # 2) Generate virtual environment according to .pre-commit-config.yaml configuration
+        # 3) Execute pre-commit run
+        uses: pre-commit/[email protected]
+        with:
+          # Run on all files once, and print diff on failure
+          extra_args: --all-files --show-diff-on-failure
+
   tileops_test_0-1-6-post1:
+    needs: pre-commit
     runs-on: [self-hosted, tile-ops]
     steps:
       - name: Checkout code
         uses: actions/checkout@v3
         with:
-          fetch-depth: 0 # 相当于 GIT_STRATEGY: fetch
+          fetch-depth: 0 # Equivalent to GIT_STRATEGY: fetch
 
       - name: Setup & Run tests
         run: |
@@ -30,19 +55,20 @@ jobs:
 
       - name: Upload artifacts
         uses: actions/upload-artifact@v4
-        if: always() # 相当于 when: always
+        if: always() # Equivalent to when: always
         with:
           name: tileops_test_0_1_6.log
           path: tileops_test_0_1_6.log
-          retention-days: 7 # 相当于 expire_in: 1 week
+          retention-days: 7 # Equivalent to expire_in: 1 week
 
   tileops_test_nightly:
+    needs: pre-commit
     runs-on: [self-hosted, tile-ops]
     steps:
       - name: Checkout code
         uses: actions/checkout@v3
         with:
-          fetch-depth: 0 # 相当于 GIT_STRATEGY: fetch
+          fetch-depth: 0 # Equivalent to GIT_STRATEGY: fetch
 
       - name: Setup & Run tests
         run: |
@@ -58,20 +84,20 @@ jobs:
 
       - name: Upload artifacts
         uses: actions/upload-artifact@v4
-        if: always() # 相当于 when: always
+        if: always() # Equivalent to when: always
         with:
           name: tileops_test_nightly.log
           path: tileops_test_nightly.log
-          retention-days: 7 # 相当于 expire_in: 1 week
+          retention-days: 7 # Equivalent to expire_in: 1 week
 
   tileops_profile_nightly:
-    needs: tileops_test_nightly
+    needs: [pre-commit, tileops_test_nightly]
    runs-on: [self-hosted, tile-ops]
    steps:
      - name: Checkout code
        uses: actions/checkout@v3
        with:
-          fetch-depth: 0 # 相当于 GIT_STRATEGY: fetch
+          fetch-depth: 0 # Equivalent to GIT_STRATEGY: fetch
 
      - name: Setup & Run tests
        run: |
@@ -91,5 +117,4 @@ jobs:
        with:
          name: profile_out
          path: profile_out/
-          retention-days: 7
-
+          retention-days: 7

.pre-commit-config.yaml

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+
+ci:
+  autofix_prs: false # Don't automatically push "fix" commits to PRs on pre-commit.ci
+  autofix_commit_msg: "[Lint]: [pre-commit.ci] auto fixes [...]" # Auto-generated commit message template when autofix is enabled
+  autoupdate_commit_msg: "[CI] [pre-commit.ci] autoupdate" # Commit message used when pre-commit.ci auto-updates hook versions
+  autoupdate_schedule: monthly # Auto-update hook versions once per month
+
+# Default stages that trigger these hooks
+default_stages: [pre-commit, pre-push, manual]
+
+# Globally ignored directories: files under build/ and 3rdparty/ won't run any hooks
+exclude: '^(build|3rdparty)/.*$' # exclude build and 3rdparty directories
+
+repos:
+  # Basic general checks: symlinks, file size, merge conflicts, etc.
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v6.0.0
+    hooks:
+      - id: check-symlinks # Check if symlinks in the repository point to valid paths
+      - id: destroyed-symlinks # Check for accidentally broken symlinks (e.g., symlink changed to regular file)
+      # FIXME: enable these hooks
+      # - id: trailing-whitespace # Check and clean up trailing whitespace (currently disabled)
+      # - id: end-of-file-fixer # Ensure files end with exactly one newline (currently disabled)
+      - id: check-added-large-files # Prevent adding overly large files to avoid bloating the repository
+      - id: check-merge-conflict # Check for Git merge conflict markers <<<<<<< / ======= / >>>>>>>
+        fail_fast: true # Fail immediately upon finding conflict markers
+      # FIXME: enable these hooks
+      # - id: check-executables-have-shebangs # Check if executable files have shebangs (not enabled)
+      # - id: check-shebang-scripts-are-executable # Check if scripts with shebangs are executable (not enabled)
+      - id: detect-private-key # Detect if private keys are accidentally committed (e.g., SSH/TLS private keys)
+      - id: check-yaml # Validate syntax of all YAML files
+      - id: check-toml # Validate TOML files (like pyproject.toml) syntax
+      - id: check-ast # Parse Python files with AST to check for syntax errors
+        fail_fast: true # Fail immediately upon finding syntax errors
+      - id: debug-statements # Detect debugging statements in Python, such as pdb.set_trace()/breakpoint()
+      - id: file-contents-sorter # Sort contents line by line in specified files, used here for spelling_wordlist
+        args: [--ignore-case] # Ignore case when sorting
+        files: ^docs/spelling_wordlist\.txt$ # Only apply to docs/spelling_wordlist.txt
+
+  # Use Ruff for Python static analysis / linting
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.14.3 # sync with requirements-lint.txt
+    hooks:
+      - id: ruff-check
+        # Automatically fix issues that can be fixed, and return non-zero exit code when fixes are made to remind git add
+        args: [--fix, --exit-non-zero-on-fix]
+
+  # Use yapf for Python code formatting
+  - repo: https://github.com/google/yapf
+    rev: v0.43.0 # sync with requirements-lint.txt
+    hooks:
+      - id: yapf
+        name: yapf-multiproc-bugfix
+        # yapf is not multiprocessing-safe, so we first run a "dummy" yapf on top/__init__.py
+        # to serve as a warm-up/workaround for concurrency bugs
+        args: [--in-place, top/__init__.py]
+        always_run: true # Run every time, regardless of file changes
+        pass_filenames: false # Don't pass file list as arguments, only work on top/__init__.py
+      - id: yapf
+        # Real global Python formatting: recursively format Python files in-place
+        args: [--recursive, --in-place]
+
+  # Use codespell for English spell checking
+  - repo: https://github.com/codespell-project/codespell
+    rev: v2.4.1 # sync with requirements-lint.txt
+    hooks:
+      - id: codespell
+        # Allow codespell to read configuration from pyproject.toml / setup.cfg, etc.
+        additional_dependencies: [".[toml]"]
+        # Exclude file types unsuitable for spell checking, such as C/C++/CUDA source, SVG, requirements lists, etc.
+        exclude: |
+          (?x)(
+              ^.+\.(cpp|hpp|cxx|cc|c|h|cu|cuh)$|
+              ^.+\.svg$|
+              ^.*\brequirements\b.*\.txt$
+          )
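The codespell exclude above is a verbose regex matched against repository-relative file paths. As a minimal, illustrative Python sketch (not part of the commit), the same pattern can be exercised directly to see which paths it filters out:

```python
import re

# The same verbose pattern as the codespell `exclude` above; (?x) lets it
# span multiple lines with comments and insignificant whitespace.
EXCLUDE = re.compile(r"""(?x)(
    ^.+\.(cpp|hpp|cxx|cc|c|h|cu|cuh)$|   # C/C++/CUDA sources and headers
    ^.+\.svg$|                           # SVG images
    ^.*\brequirements\b.*\.txt$          # requirements*.txt style lists
)""")

assert EXCLUDE.search("kernels/attention.cu")             # skipped by codespell
assert EXCLUDE.search("docs/logo.svg")                    # skipped
assert EXCLUDE.search("requirements-lint.txt")            # skipped
assert EXCLUDE.search("benchmarks/benchmark.py") is None  # still spell-checked
```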

benchmarks/benchmark.py

Lines changed: 3 additions & 1 deletion
@@ -96,7 +96,9 @@ def check_fn(self, fn, *inputs, atol=1e-2, rtol=1e-2, grad=True):
         elif not isinstance(outputs, tuple):
             raise ValueError(f"Unsupported output type: {type(outputs)}")
 
-        assert len(outputs) == len(outputs_ref), f"outputs: {len(outputs)} and outputs_ref: {len(outputs_ref)} have different size"
+        assert len(outputs) == len(
+            outputs_ref
+        ), f"outputs: {len(outputs)} and outputs_ref: {len(outputs_ref)} have different size"
         for i, (output, output_ref) in enumerate(zip(outputs, outputs_ref)):
             # print(f"outputs[{i}] max err: {(output - output_ref).abs().max()}")
             if output_ref is not None:  # skip checking for None placeholders in ref
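For context, check_fn (whose signature appears in the hunk header) pairs each output with its reference and skips None placeholders. A hedged sketch of that kind of check, with torch.testing.assert_close standing in for whatever elementwise comparison the benchmark actually performs:

```python
import torch

# Illustrative only: mirrors the length assert reformatted above and the
# None-placeholder skip; assert_close is an assumption, not the repo's API.
def compare_outputs(outputs, outputs_ref, atol=1e-2, rtol=1e-2):
    assert len(outputs) == len(
        outputs_ref
    ), f"outputs: {len(outputs)} and outputs_ref: {len(outputs_ref)} have different size"
    for output, output_ref in zip(outputs, outputs_ref):
        if output_ref is None:  # skip checking for None placeholders in ref
            continue
        torch.testing.assert_close(output, output_ref, atol=atol, rtol=rtol)
```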

benchmarks/flash_attn/gqa.py

Lines changed: 14 additions & 10 deletions
@@ -26,7 +26,8 @@ def total_flops(self):
 
     @property
     def total_memory(self):
-        return 2 * self.batch * self.seq_len * self.dim * (self.heads + self.heads_kv) * self.dtype.itemsize
+        return 2 * self.batch * self.seq_len * self.dim * (self.heads +
+                                                           self.heads_kv) * self.dtype.itemsize
 
     def gen_inputs(self):
         Q = torch.randn(
@@ -38,11 +39,12 @@ def gen_inputs(self):
         return Q, K, V
 
     def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
-        q_bhsd = Q.transpose(1, 2) # [B, H, S, D]
+        q_bhsd = Q.transpose(1, 2)  # [B, H, S, D]
         k_bhsd = K.transpose(1, 2)
         v_bhsd = V.transpose(1, 2)
         with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
-            output_bhsd = F.scaled_dot_product_attention(q_bhsd, k_bhsd, v_bhsd, is_causal=self.is_causal, enable_gqa=True)
+            output_bhsd = F.scaled_dot_product_attention(
+                q_bhsd, k_bhsd, v_bhsd, is_causal=self.is_causal, enable_gqa=True)
         output = output_bhsd.transpose(1, 2).contiguous()
         return output, None # do not check lse
 
@@ -68,7 +70,8 @@ def total_flops(self):
 
     @property
     def total_memory(self):
-        return self.batch * (3 * self.heads + 4 * self.heads_kv) * self.seq_len * self.dim * self.dtype.itemsize
+        return self.batch * (3 * self.heads +
+                             4 * self.heads_kv) * self.seq_len * self.dim * self.dtype.itemsize
 
     def gen_inputs(self):
         Q = torch.randn(
@@ -127,7 +130,7 @@ def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, O: torc
 
 
 class gqa_benchmark(Benchmark):
-
+
     def __init__(self, batch, heads, heads_kv, seq_len, dim, is_causal, dtype, grad=True):
         self.batch = batch
         self.heads = heads
@@ -138,8 +141,10 @@ def __init__(self, batch, heads, heads_kv, seq_len, dim, is_causal, dtype, grad=
         self.dtype = dtype
         self.grad = grad
 
-        self.gqa_fwd_bench = gqa_fwd_benchmark(batch, heads, heads_kv, seq_len, dim, is_causal, dtype)
-        self.gqa_bwd_bench = gqa_bwd_benchmark(batch, heads, heads_kv, seq_len, dim, is_causal, dtype)
+        self.gqa_fwd_bench = gqa_fwd_benchmark(batch, heads, heads_kv, seq_len, dim, is_causal,
+                                               dtype)
+        self.gqa_bwd_bench = gqa_bwd_benchmark(batch, heads, heads_kv, seq_len, dim, is_causal,
+                                               dtype)
 
     @property
     def total_flops(self):
@@ -148,14 +153,14 @@ def total_flops(self):
     @property
     def total_memory(self):
         return self.gqa_fwd_bench.total_memory + self.gqa_bwd_bench.total_memory
-
+
     def gen_inputs(self):
         if self.grad:
             Q, K, V, _, _, _ = self.gqa_bwd_bench.gen_inputs()
             return Q, K, V
         else:
             return self.gqa_fwd_bench.gen_inputs()
-
+
     def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
 
         output = self.gqa_fwd_bench.ref_program(Q, K, V)[0]
@@ -165,4 +170,3 @@ def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
         loss = output.sum()
         loss.backward()
         return output, Q.grad, K.grad, V.grad
-
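As a standalone illustration of the reference path these hunks reformat, here is a minimal sketch of calling SDPA with grouped KV heads. It assumes a CUDA device, fp16 inputs, and a PyTorch version whose scaled_dot_product_attention accepts enable_gqa; the shapes are arbitrary and not taken from the benchmark configs.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Arbitrary shapes for illustration: 8 query heads sharing 2 KV heads.
B, H, H_KV, S, D = 1, 8, 2, 128, 64
Q = torch.randn(B, S, H, D, device="cuda", dtype=torch.float16)
K = torch.randn(B, S, H_KV, D, device="cuda", dtype=torch.float16)
V = torch.randn(B, S, H_KV, D, device="cuda", dtype=torch.float16)

# SDPA expects [B, H, S, D], so swap the sequence and head axes first.
q_bhsd, k_bhsd, v_bhsd = (t.transpose(1, 2) for t in (Q, K, V))
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    out_bhsd = F.scaled_dot_product_attention(
        q_bhsd, k_bhsd, v_bhsd, is_causal=True, enable_gqa=True)
out = out_bhsd.transpose(1, 2).contiguous()  # back to [B, S, H, D]
```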

benchmarks/flash_attn/mha.py

Lines changed: 14 additions & 8 deletions
@@ -37,11 +37,12 @@ def gen_inputs(self):
         return Q, K, V
 
     def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
-        q_bhsd = Q.transpose(1, 2) # [B, H, S, D]
+        q_bhsd = Q.transpose(1, 2)  # [B, H, S, D]
         k_bhsd = K.transpose(1, 2)
         v_bhsd = V.transpose(1, 2)
         with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
-            output_bhsd = F.scaled_dot_product_attention(q_bhsd, k_bhsd, v_bhsd, is_causal=self.is_causal)
+            output_bhsd = F.scaled_dot_product_attention(
+                q_bhsd, k_bhsd, v_bhsd, is_causal=self.is_causal)
         output = output_bhsd.transpose(1, 2).contiguous()
         return output, None # do not check lse
 
@@ -104,11 +105,12 @@ def gen_inputs(self):
 
     def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, O: torch.Tensor,
                     dO: torch.Tensor, lse: torch.Tensor):
-        q_bhsd = Q.transpose(1, 2) # [B, H, S, D]
+        q_bhsd = Q.transpose(1, 2)  # [B, H, S, D]
         k_bhsd = K.transpose(1, 2)
         v_bhsd = V.transpose(1, 2)
         with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
-            output_bhsd = F.scaled_dot_product_attention(q_bhsd, k_bhsd, v_bhsd, is_causal=self.is_causal)
+            output_bhsd = F.scaled_dot_product_attention(
+                q_bhsd, k_bhsd, v_bhsd, is_causal=self.is_causal)
         output = output_bhsd.transpose(1, 2).contiguous()
 
         output.backward(dO)
@@ -165,12 +167,17 @@ def gen_inputs(self):
 
         return Q, K, V
 
-    def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, dO: torch.Tensor = None):
-        q_bhsd = Q.transpose(1, 2) # [B, H, S, D]
+    def ref_program(self,
+                    Q: torch.Tensor,
+                    K: torch.Tensor,
+                    V: torch.Tensor,
+                    dO: torch.Tensor = None):
+        q_bhsd = Q.transpose(1, 2)  # [B, H, S, D]
         k_bhsd = K.transpose(1, 2)
         v_bhsd = V.transpose(1, 2)
         with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
-            output_bhsd = F.scaled_dot_product_attention(q_bhsd, k_bhsd, v_bhsd, is_causal=self.is_causal)
+            output_bhsd = F.scaled_dot_product_attention(
+                q_bhsd, k_bhsd, v_bhsd, is_causal=self.is_causal)
         output = output_bhsd.transpose(1, 2).contiguous()
 
         if not self.grad:
@@ -179,4 +186,3 @@ def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, dO: tor
         loss = output.sum()
         loss.backward()
         return output, Q.grad, K.grad, V.grad
-
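The backward reference in this file obtains gradients by backpropagating an upstream dO through the SDPA output. A minimal sketch of that pattern, with illustrative shapes and assuming CUDA plus fp16 so the FLASH_ATTENTION backend is eligible:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

B, H, S, D = 1, 8, 128, 64
Q, K, V = (
    torch.randn(B, S, H, D, device="cuda", dtype=torch.float16, requires_grad=True)
    for _ in range(3))
dO = torch.randn(B, S, H, D, device="cuda", dtype=torch.float16)

with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
    out_bhsd = F.scaled_dot_product_attention(
        Q.transpose(1, 2), K.transpose(1, 2), V.transpose(1, 2), is_causal=True)
out = out_bhsd.transpose(1, 2).contiguous()

# Backpropagate the upstream gradient and read the leaf gradients,
# matching the (output, Q.grad, K.grad, V.grad) tuple returned above.
out.backward(dO)
grads = (Q.grad, K.grad, V.grad)
```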

benchmarks/flash_decode/gqa_decode.py

Lines changed: 15 additions & 12 deletions
@@ -19,7 +19,7 @@ def __init__(self, batch, heads, groups, seq_len_kv, dim, dtype):
 
     @property
     def total_flops(self):
-        flops_per_matmul = 2.0 * self.batch * self.heads * self.seq_len_kv * self.dim
+        flops_per_matmul = 2.0 * self.batch * self.heads * self.seq_len_kv * self.dim
         flops = flops_per_matmul * 2
         return flops
 
@@ -28,26 +28,29 @@ def total_memory(self):
         # Q: batch * 1 * heads * dim
         # K, V: batch * seq_len_kv * heads_kv * dim
         # Output: batch * 1 * heads * dim
-        return 2 * self.batch * self.dim * self.dtype.itemsize * (self.heads + self.groups * self.seq_len_kv)
+        return 2 * self.batch * self.dim * self.dtype.itemsize * (
+            self.heads + self.groups * self.seq_len_kv)
 
     def gen_inputs(self):
-        Q = torch.randn(
-            self.batch, self.heads, self.dim, device='cuda', dtype=self.dtype)
+        Q = torch.randn(self.batch, self.heads, self.dim, device='cuda', dtype=self.dtype)
         K = torch.randn(
             self.batch, self.seq_len_kv, self.groups, self.dim, device='cuda', dtype=self.dtype)
         V = torch.randn(
             self.batch, self.seq_len_kv, self.groups, self.dim, device='cuda', dtype=self.dtype)
-        mask = torch.randint(0, 2, (self.batch, self.seq_len_kv, self.groups), device='cuda', dtype=torch.uint8)
+        mask = torch.randint(
+            0, 2, (self.batch, self.seq_len_kv, self.groups), device='cuda', dtype=torch.uint8)
         return Q, K, V, mask
 
     def ref_program(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, mask: torch.Tensor):
-        q_bhsd = Q.unsqueeze(1).transpose(1, 2) # [B, H, 1, D]
-        k_bhsd = K.transpose(1, 2) # [B, H, S_kv, D]
-        v_bhsd = V.transpose(1, 2) # [B, H, S_kv, D]
+        q_bhsd = Q.unsqueeze(1).transpose(1, 2)  # [B, H, 1, D]
+        k_bhsd = K.transpose(1, 2)  # [B, H, S_kv, D]
+        v_bhsd = V.transpose(1, 2)  # [B, H, S_kv, D]
         mask = mask.to(torch.bool).transpose(1, 2).unsqueeze(2) # [B, G, 1, S_kv]
-        mask = mask.expand(self.batch, self.groups, self.heads // self.groups, self.seq_len_kv).reshape(self.batch, self.heads, self.seq_len_kv).unsqueeze(2)
+        mask = mask.expand(self.batch, self.groups, self.heads // self.groups,
+                           self.seq_len_kv).reshape(self.batch, self.heads,
+                                                    self.seq_len_kv).unsqueeze(2)
         with sdpa_kernel(backends=[SDPBackend.MATH]):
-            output_bhsd = F.scaled_dot_product_attention(q_bhsd, k_bhsd, v_bhsd, attn_mask=mask, enable_gqa=True)
+            output_bhsd = F.scaled_dot_product_attention(
+                q_bhsd, k_bhsd, v_bhsd, attn_mask=mask, enable_gqa=True)
         output = output_bhsd.transpose(1, 2).squeeze(1).contiguous()
-        return output
-
+        return output
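The mask handling reformatted above broadcasts a per-group KV mask to one row per query head before passing it to SDPA as attn_mask. A small sketch of just that reshape, using illustrative sizes and CPU tensors:

```python
import torch

B, H, G, S_kv = 2, 8, 2, 16  # 8 query heads in 2 KV groups
mask = torch.randint(0, 2, (B, S_kv, G), dtype=torch.uint8)

m = mask.to(torch.bool).transpose(1, 2).unsqueeze(2)  # [B, G, 1, S_kv]
m = m.expand(B, G, H // G, S_kv)                      # repeat for each head in the group
m = m.reshape(B, H, S_kv).unsqueeze(2)                # [B, H, 1, S_kv], ready for attn_mask
assert m.shape == (B, H, 1, S_kv)
```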
