Skip to content

Commit 558cfb0

Browse files
authored
Merge branch 'main' into ph-mypy-ops-root-dir
2 parents fdf2dc1 + a4e7475 commit 558cfb0

File tree

213 files changed

+1769
-1664
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

213 files changed

+1769
-1664
lines changed

.ci/scripts/setup-windows-msvc.ps1

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
conda create --yes --quiet -n et python=3.12
2+
conda activate et
3+
4+
# Install cmake
5+
conda install -y cmake
6+
7+
# Activate the VS environment - this is required for MSVC to work
8+
# MSVC depends on a number of environment variables that the VS developer shell sets up.
9+
# See https://learn.microsoft.com/en-us/cpp/build/building-on-the-command-line.
10+
& "C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\Common7\Tools\Launch-VsDevShell.ps1" -Arch amd64
11+
12+
# Install CI requirements
13+
pip install -r .ci/docker/requirements-ci.txt
14+
15+
# Create build directory
16+
$buildDir = "cmake-out-msvc"
17+
if (Test-Path -Path $buildDir) {
18+
Remove-Item -Path $buildDir -Recurse -Force
19+
}
20+
New-Item -Path $buildDir -ItemType Directory
21+
22+
# Configure CMake with MSVC (not ClangCL) and disable custom/quantized ops
23+
cmake -S . -B $buildDir `
24+
-DCMAKE_BUILD_TYPE=Release `
25+
-DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON `
26+
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON `
27+
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON `
28+
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON `
29+
-DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON `
30+
-DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON `
31+
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON `
32+
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=OFF `
33+
-DEXECUTORCH_BUILD_KERNELS_CUSTOM_AOT=OFF `
34+
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=OFF `
35+
-DEXECUTORCH_BUILD_XNNPACK=ON `
36+
-DEXECUTORCH_BUILD_EXTENSION_LLM=ON `
37+
-DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON
38+
39+
if ($LASTEXITCODE -ne 0) {
40+
Write-Host "CMake configuration failed. Exit code: $LASTEXITCODE."
41+
exit $LASTEXITCODE
42+
}
43+
44+
# Build with MSVC
45+
cmake --build $buildDir --config Release -j16
46+
47+
if ($LASTEXITCODE -ne 0) {
48+
Write-Host "Build failed. Exit code: $LASTEXITCODE."
49+
exit $LASTEXITCODE
50+
}
51+
52+
Write-Host "MSVC build completed successfully!"

.github/workflows/cuda.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ jobs:
8989
9090
export-voxtral-cuda-artifact:
9191
name: export-voxtral-cuda-${{ matrix.quant.name }}
92+
# Skip this job if the pull request is from a fork (HuggingFace secrets are not available)
93+
if: github.event.pull_request.head.repo.full_name == github.repository || github.event_name != 'pull_request'
9294
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
9395
permissions:
9496
id-token: write
@@ -166,6 +168,8 @@ jobs:
166168
167169
export-gemma3-cuda-artifact:
168170
name: export-gemma3-cuda-${{ matrix.quant.name }}
171+
# Skip this job if the pull request is from a fork (HuggingFace secrets are not available)
172+
if: github.event.pull_request.head.repo.full_name == github.repository || github.event_name != 'pull_request'
169173
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
170174
permissions:
171175
id-token: write

.github/workflows/windows-msvc.yml

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
name: Windows MSVC Build
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
- release/*
8+
tags:
9+
- ciflow/trunk/*
10+
pull_request:
11+
paths:
12+
- .ci/docker/ci_commit_pins/pytorch.txt
13+
- .ci/scripts/**
14+
workflow_dispatch:
15+
16+
concurrency:
17+
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
18+
cancel-in-progress: true
19+
20+
jobs:
21+
build-windows-msvc:
22+
name: build-windows-msvc
23+
uses: pytorch/test-infra/.github/workflows/windows_job.yml@main
24+
with:
25+
submodules: 'recursive'
26+
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
27+
timeout: 60
28+
script: |
29+
conda init powershell
30+
powershell -Command "& {
31+
Set-PSDebug -Trace 1
32+
\$ErrorActionPreference = 'Stop'
33+
\$PSNativeCommandUseErrorActionPreference = \$true
34+
.ci/scripts/setup-windows-msvc.ps1
35+
}"

.mypy.ini

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ ignore_missing_imports = True
8383
[mypy-tosa_tools.*]
8484
ignore_missing_imports = True
8585

86+
[mypy-tosa_serializer]
87+
ignore_missing_imports = True
88+
89+
[mypy-tosa_serializer.*]
90+
ignore_missing_imports = True
91+
8692
[mypy-setuptools.*]
8793
ignore_missing_imports = True
8894

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,7 @@
8888
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
8989
from .remove_noop_pass import RemoveNoopPass # noqa
9090
from .replace_scalar_with_tensor_pass import ( # noqa
91-
ReplaceScalarWithTensorArgPassTOSABI,
92-
ReplaceScalarWithTensorArgPassTOSAMI,
91+
ReplaceScalarWithTensorByProfilePass,
9392
)
9493
from .rewrite_conv2d_pass import RewriteConv2dPass # noqa
9594
from .rewrite_matmul import RewriteMatmulPass # noqa

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
76

87
import itertools
98
import operator
@@ -52,7 +51,7 @@ def _match_partition_to_node(
5251
raise RuntimeError(f"Cannot find an input node which matches, {node}.")
5352

5453
def call(self, graph_module: GraphModule) -> PassResult:
55-
matmul_partitions = get_source_partitions(
54+
matmul_partitions_map = get_source_partitions(
5655
graph_module.graph,
5756
[
5857
torch.matmul,
@@ -61,7 +60,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
6160
None,
6261
)
6362
matmul_partitions = list(
64-
itertools.chain.from_iterable(matmul_partitions.values())
63+
itertools.chain.from_iterable(matmul_partitions_map.values())
6564
)
6665
matmul_targets = {
6766
exir_ops.edge.aten.bmm.default,
@@ -89,7 +88,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
8988
# Create new dq-node before matmul
9089
dq_node = create_node(
9190
graph=graph_module.graph,
92-
op_target=cast(EdgeOpOverload, input_node.target), # type: ignore[arg-type]
91+
op_target=cast(EdgeOpOverload, input_node.target),
9392
)
9493
dq_node.args = (node, *input_node.args[1:])
9594
matmul_node.replace_input_with(node, dq_node)
@@ -110,7 +109,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
110109
# Create q-node after matmul
111110
q_node = create_node(
112111
graph=graph_module.graph,
113-
op_target=cast(EdgeOpOverload, partition_output.target), # type: ignore[arg-type]
112+
op_target=cast(EdgeOpOverload, partition_output.target),
114113
)
115114
matmul_node.replace_all_uses_with(q_node)
116115
q_node.args = (matmul_node, *partition_output.args[1:])

backends/arm/_passes/arm_pass.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
76

87
import traceback
98
from abc import abstractmethod

backends/arm/_passes/arm_pass_manager.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
# pyre-unsafe
9-
108

119
from collections import defaultdict
1210

@@ -89,8 +87,7 @@
8987
QuantizeOperatorArguments,
9088
RemoveNoopPass,
9189
ReplaceInfValues,
92-
ReplaceScalarWithTensorArgPassTOSABI,
93-
ReplaceScalarWithTensorArgPassTOSAMI,
90+
ReplaceScalarWithTensorByProfilePass,
9491
RetraceFoldedDtypesPass,
9592
RewriteConv2dPass,
9693
RewriteMatmulPass,
@@ -156,15 +153,15 @@ def _transform(self, graph_module: GraphModule):
156153
with TosaLoweringContext(self.tosa_spec):
157154
return self(graph_module).graph_module
158155

159-
def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
156+
def _tosa_INT_pipeline(
157+
self, exported_program: ExportedProgram, graph_module: GraphModule
158+
) -> GraphModule:
160159
self.add_pass(AnnotateOutputDimOrderPass())
161160
self.add_pass(FuseQuantizedActivationPass())
162161
self.add_pass(RemoveGetItemPass())
163162
self.add_pass(ConvertSplitToSlicePass())
164163
self.add_pass(ConvertMmToBmmPass())
165-
self.add_pass(
166-
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
167-
)
164+
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
168165
self.add_pass(ConvertFullLikeToFullPass())
169166
self.add_pass(ConvertToClampPass())
170167
self.add_pass(ConvertMinMaxPass())
@@ -174,7 +171,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
174171
self.add_pass(CastToInt32Pass())
175172

176173
self.add_pass(CastBoolToInt8Pass())
177-
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
174+
self.add_pass(ReplaceScalarWithTensorByProfilePass())
178175
self.add_pass(AnnotateDecomposedMatmulPass())
179176
self.add_pass(QuantizeOperatorArguments())
180177
self.add_pass(ConvertELUParamsPass())
@@ -194,7 +191,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
194191
self.add_pass(ConvertExpandCopyToRepeatPass())
195192
self.add_pass(UnsqueezeBeforeRepeatPass())
196193
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
197-
self.add_pass(DecomposeSumPass())
198194
self.add_pass(DecomposeCumsumPass(exported_program))
199195
self.add_pass(Conv1dUnsqueezePass())
200196
self.add_pass(DecomposeMaxPool2DPass())
@@ -215,15 +211,18 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
215211
self.add_pass(RewriteMatmulPass())
216212
self.add_pass(RewriteUpsamplePass())
217213
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
214+
self.add_pass(InsertRescaleInt32Pass())
215+
self.add_pass(DecomposeSumPass())
218216
self.add_pass(ToTosaMemoryFormatPass(exported_program))
219217
self.add_pass(RemoveNoopPass())
220218
self.add_pass(InsertRescalePass())
221-
self.add_pass(InsertRescaleInt32Pass())
222219

223220
self.validate_constraints_mandatory()
224-
return self._transform(exported_program.graph_module)
221+
return self._transform(graph_module)
225222

226-
def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
223+
def _tosa_FP_pipeline(
224+
self, exported_program: ExportedProgram, graph_module: GraphModule
225+
) -> GraphModule:
227226
self.add_pass(AnnotateOutputDimOrderPass())
228227
self.add_pass(DecomposeExpm1Pass())
229228
self.add_pass(DecomposeLogitPass())
@@ -244,7 +243,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
244243
self.add_pass(DecomposeSinhPass())
245244
self.add_pass(DecomposeSignPass())
246245
self.add_pass(DecomposeDivTensorModePass())
247-
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
246+
self.add_pass(ReplaceScalarWithTensorByProfilePass())
248247
self.add_pass(DecomposeEmbeddingPass())
249248
self.add_pass(FuseQuantizedActivationPass())
250249
self.add_pass(RemoveGetItemPass())
@@ -258,9 +257,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
258257
self.add_pass(DecomposeLayerNormPass())
259258
self.add_pass(DecomposeBatchNormNoStatsPass())
260259
self.add_pass(DecomposeVarPass())
261-
self.add_pass(
262-
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
263-
)
260+
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
264261
self.add_pass(DecomposeNotEqualPass())
265262
self.add_pass(DecomposeDivPass())
266263
self.add_pass(DecomposeAddSubAlphaPass())
@@ -308,14 +305,16 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
308305
self.add_pass(InsertRescalePass())
309306

310307
self.validate_constraints_mandatory()
311-
return self._transform(exported_program.graph_module)
308+
return self._transform(graph_module)
312309

313-
def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
310+
def transform_to_backend_pipeline(
311+
self, exported_program: ExportedProgram, graph_module: GraphModule
312+
):
314313
"""Apply passes before transforming program to backend"""
315314
if self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"):
316-
return self._tosa_FP_pipeline(exported_program)
315+
return self._tosa_FP_pipeline(exported_program, graph_module)
317316
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+INT"):
318-
return self._tosa_INT_pipeline(exported_program)
317+
return self._tosa_INT_pipeline(exported_program, graph_module)
319318
else:
320319
raise NotImplementedError(
321320
f"No pass pipeline implemented for {self.tosa_spec=}"
@@ -337,7 +336,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
337336
self.add_pass(DecomposeAddmmPass())
338337
self.add_pass(DecomposeDivTensorModePass())
339338
self.add_pass(DecomposeAddSubAlphaPass())
340-
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
339+
self.add_pass(ReplaceScalarWithTensorByProfilePass())
341340
self.add_pass(ScalarsToAttributePass())
342341
self.add_pass(DecomposeGroupNormPass())
343342
self.add_pass(DecomposeLayerNormPass())
@@ -361,7 +360,6 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
361360

362361
self.add_pass(ConvertMinMaxPass())
363362
self.add_pass(ReplaceInfValues())
364-
self.add_pass(DecomposeSumPass())
365363

366364
if not self.tosa_spec.is_U55_subset:
367365
# Uses where which is not supported on Ethos-U55

backends/arm/_passes/arm_pass_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
# pyre-unsafe
98

109
import traceback
1110
from inspect import isclass
@@ -14,8 +13,10 @@
1413
import torch
1514
import torch.fx
1615
from executorch.backends.arm.common.debug import get_node_debug_info
16+
from executorch.backends.arm.common.type import ensure_type
1717
from executorch.exir import ExportedProgram
1818
from executorch.exir.dialects._ops import ops as exir_ops
19+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
1920

2021
from torch._export.utils import (
2122
get_buffer,
@@ -82,17 +83,18 @@ def get_param_tensor(
8283
elif is_lifted_tensor_constant(exp_prog, node):
8384
return get_lifted_tensor_constant(exp_prog, node)
8485
elif is_get_attr_node(node):
86+
target_node = ensure_type(str, node.target)
8587
# This is a hack to support both lifted and unlifted graphs
8688
try:
87-
return getattr(node.graph.owning_module, node.target) # type: ignore[arg-type]
89+
return getattr(node.graph.owning_module, target_node)
8890
except AttributeError:
89-
return getattr(exp_prog.graph_module, node.target) # type: ignore[arg-type]
91+
return getattr(exp_prog.graph_module, target_node)
9092
raise RuntimeError(f"unsupported param type, {node.op}.")
9193

9294

9395
def create_node(
9496
graph: torch.fx.Graph,
95-
op_target: OpOverload,
97+
op_target: OpOverload | EdgeOpOverload,
9698
args: tuple = (),
9799
kwargs: Optional[dict] = None,
98100
quantize: bool = False,

backends/arm/_passes/cast_int64_pass.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
76

87
import logging
98
from typing import Set, Type

0 commit comments

Comments
 (0)