Skip to content

Commit 757b564

Browse files
authored
Add lint to make sure examples and tests use device=DEVICE (#929)
1 parent dbf666e commit 757b564

File tree

3 files changed

+106
-2
lines changed

3 files changed

+106
-2
lines changed

.pre-commit-config.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ repos:
3535
entry: python scripts/lint_examples_main.py
3636
language: system
3737
files: ^examples/.*\.py$
38+
- id: check-no-hardcoded-device
39+
name: disallow hard-coded device kwarg outside of DEVICE
40+
entry: python scripts/lint_no_hardcoded_device.py
41+
language: system
42+
files: ^(?:examples|test)/.*\.py$
3843

3944
- repo: https://github.com/codespell-project/codespell
4045
rev: v2.4.1
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
import pathlib
5+
import sys
6+
from typing import Iterable
7+
8+
ALLOWED_NAME = "DEVICE"
9+
10+
11+
class DeviceKwargVisitor(ast.NodeVisitor):
12+
def __init__(self, filename: str) -> None:
13+
self.filename = filename
14+
self.errors: list[tuple[int, int, str]] = []
15+
16+
def visit_Call(self, node: ast.Call) -> None:
17+
for kw in node.keywords or ():
18+
if kw.arg != "device":
19+
continue
20+
21+
# Only disallow string literals, e.g., device="cuda" or device='cpu'
22+
if isinstance(kw.value, ast.Constant) and isinstance(kw.value.value, str):
23+
self.errors.append(
24+
(
25+
getattr(kw, "lineno", node.lineno),
26+
getattr(kw, "col_offset", node.col_offset),
27+
"device must not be a string literal like 'cuda'; use DEVICE",
28+
)
29+
)
30+
31+
# Continue walking children
32+
self.generic_visit(node)
33+
34+
35+
def iter_python_files(paths: Iterable[str]) -> Iterable[pathlib.Path]:
36+
for p in paths:
37+
path = pathlib.Path(p)
38+
if path.is_dir():
39+
yield from path.rglob("*.py")
40+
elif path.suffix == ".py" and path.exists():
41+
yield path
42+
43+
44+
def should_check(path: pathlib.Path) -> bool:
45+
# Only check files under test/ or examples/
46+
try:
47+
parts = path.resolve().parts
48+
except FileNotFoundError:
49+
parts = path.parts
50+
# find directory names in the path
51+
return "test" in parts or "examples" in parts
52+
53+
54+
def check_file(path: pathlib.Path) -> list[tuple[int, int, str]]:
55+
try:
56+
source = path.read_text(encoding="utf-8")
57+
except UnicodeDecodeError:
58+
return []
59+
60+
try:
61+
tree = ast.parse(source, filename=str(path))
62+
except SyntaxError:
63+
# let other hooks catch syntax errors
64+
return []
65+
66+
visitor = DeviceKwargVisitor(str(path))
67+
visitor.visit(tree)
68+
69+
# Allow inline opt-out using marker on the same line: @ignore-device-lint
70+
lines = source.splitlines()
71+
filtered_errors: list[tuple[int, int, str]] = []
72+
for lineno, col, msg in visitor.errors:
73+
if 1 <= lineno <= len(lines):
74+
if "@ignore-device-lint" in lines[lineno - 1]:
75+
continue
76+
filtered_errors.append((lineno, col, msg))
77+
return filtered_errors
78+
79+
80+
def main(argv: list[str]) -> int:
81+
if len(argv) == 0:
82+
# pre-commit will pass file list; if not, scan the default dirs
83+
candidates = list(iter_python_files(["examples", "test"]))
84+
else:
85+
candidates = [pathlib.Path(a) for a in argv]
86+
87+
errors_found = 0
88+
for path in candidates:
89+
if not should_check(path):
90+
continue
91+
for lineno, col, msg in check_file(path):
92+
print(f"{path}:{lineno}:{col}: {msg}")
93+
errors_found += 1
94+
95+
return 1 if errors_found else 0
96+
97+
98+
if __name__ == "__main__":
99+
raise SystemExit(main(sys.argv[1:]))

test/test_type_propagation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def use_device_properties(x: torch.Tensor) -> torch.Tensor:
114114
out[idx] = x[idx]
115115
return out
116116

117-
x = torch.ones([128], device="cuda")
117+
x = torch.ones([128], device="cuda") # @ignore-device-lint
118118
output = type_propagation_report(use_device_properties, x)
119119
self.assertExpectedJournal(output)
120120

@@ -129,7 +129,7 @@ def use_unsupported_property(x: torch.Tensor) -> torch.Tensor:
129129
x[i] = unsupported
130130
return x
131131

132-
x = torch.ones([16], device="cuda")
132+
x = torch.ones([16], device="cuda") # @ignore-device-lint
133133
with self.assertRaisesRegex(
134134
exc.TypeInferenceError,
135135
r"Attribute 'total_memory' is not supported on .*test_type_propagation.py",

0 commit comments

Comments
 (0)