Skip to content

Commit 29729c8

Browse files
authored
Enable mypy for scripts (#2710)
Enable mypy for scripts and fix the existing issues. Fixes #2709.
1 parent 5cb1ebd commit 29729c8

File tree

4 files changed

+14
-5
lines changed

4 files changed

+14
-5
lines changed

.pre-commit-config.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,13 @@ repos:
137137
- --disable=fixme
138138
stages: [pre-commit, pre-push, manual]
139139

140+
- repo: https://github.com/pre-commit/mirrors-mypy
141+
rev: v1.13.0
142+
hooks:
143+
- id: mypy
144+
name: mypy for scripts
145+
stages: [pre-commit, pre-push, manual]
146+
files: '^scripts/.*'
140147

141148
exclude: |
142149
(?x)(

scripts/compare-ci-runs/compare_runs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,9 @@ def merge_triton_xetla_reports_data(config: str, triton_file: Path, xetla_file:
111111
return pd.DataFrame()
112112

113113

114-
def build_triton_benchmark_reports_path(directory: Path, report_name: str) -> str:
114+
def build_triton_benchmark_reports_path(directory: Path, report_name: str) -> Path:
115115
"""Construct the full file path for a given report name."""
116-
return os.path.join(directory, "benchmark-reports", f"{report_name}-report.csv")
116+
return directory / "benchmark-reports" / f"{report_name}-report.csv"
117117

118118

119119
def parse_triton_benchmark_data(config: str, directory: Path) -> pd.DataFrame:

scripts/pass_rate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,13 @@ def find_stats(stats: List[ReportStats], name: str) -> ReportStats:
135135
raise ValueError(f'{name} not found')
136136

137137

138-
def parse_junit_reports(args: argparse.ArgumentParser) -> List[ReportStats]:
138+
def parse_junit_reports(args: argparse.Namespace) -> List[ReportStats]:
139139
"""Parses junit report in the specified directory."""
140140
reports_path = pathlib.Path(args.reports)
141141
return [parse_report(report, args.skiplist_dir) for report in reports_path.glob('*.xml')]
142142

143143

144-
def parse_tutorials_reports(args: argparse.ArgumentParser) -> List[ReportStats]:
144+
def parse_tutorials_reports(args: argparse.Namespace) -> List[ReportStats]:
145145
"""Parses tutorials reports in the specified directory."""
146146
reports_path = pathlib.Path(args.reports)
147147
stats = ReportStats(name='tutorials')
@@ -157,7 +157,7 @@ def parse_tutorials_reports(args: argparse.ArgumentParser) -> List[ReportStats]:
157157
return [stats]
158158

159159

160-
def parse_reports(args: argparse.ArgumentParser) -> List[ReportStats]:
160+
def parse_reports(args: argparse.Namespace) -> List[ReportStats]:
161161
"""Parses all report in the specified directory."""
162162
return parse_junit_reports(args) + parse_tutorials_reports(args)
163163

scripts/run_tutorial.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def create_argument_parser() -> argparse.ArgumentParser:
4141
def run_tutorial(path: pathlib.Path):
4242
"""Runs """
4343
spec = importlib.util.spec_from_file_location('__main__', path)
44+
if not spec or not spec.loader:
45+
raise AssertionError(f'Failed to load module from {path}')
4446
module = importlib.util.module_from_spec(spec)
4547
# set __file__ to the absolute name, a workaround for 10i-experimental-block-pointer, which
4648
# uses dirname of its location to find 10-experimental-block-pointer.

0 commit comments

Comments
 (0)