🐛 Describe the bug
ExecuTorch Runtime Segfault for aten::prelu With Constant In-Module Weight When Using run_decompositions({})
Summary
ExecuTorch crashes with a segmentation fault (SIGSEGV) during runtime inference (method.execute()) when exporting and lowering a model that contains aten::prelu, where the PReLU weight is embedded as a constant inside the module (for example, via register_buffer) and the export pipeline explicitly uses an empty decomposition table via ExportedProgram.run_decompositions({}). The backend is XNNPACK on CPU.
Export and lowering both succeed. The crash happens only during runtime execution and terminates the Python process.
When the same PReLU weight is provided as a tensor input (non-const), the pipeline completes and inference works.
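Condensed from the controlled cases below, the only relevant difference between the crashing and working configurations is the decomposition call:

exported = torch.export.export(model, args=(x4d,), strict=True)
exported = exported.run_decompositions({})  # empty table: segfaults later at method.execute()
# exported = exported.run_decompositions()  # default table: executes successfully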
=== case_prelu_const_w ===
run_decompositions: {}
1. export: torch.export.export OK
2. export: graph
class GraphModule(torch.nn.Module):
    def forward(self, b_w: "f32[3]", x: "f32[2, 3, 3, 3]"):
        prelu: "f32[2, 3, 3, 3]" = torch.ops.aten.prelu.default(x, b_w); x = b_w = None
        return (prelu,)
3. lower: to_executorch OK
4. io: mkstemp OK (/var/folders/bl/4839w_nx4_5fnl7qn57cx2540000gn/T/prelu_f93pslf_.pte)
5. runtime: load_method OK
zsh: segmentation fault python executorch_crash_op_prelu.py
Environment
- PyTorch 2.10.0
- executorch 1.1.0
- coremltools 9.0
- Mac mini (M4 Pro), macOS on Apple Silicon
- Python 3.11
Context
This issue was discovered while running large-scale automated operator testing using an open-source backend parity framework I maintain (opdiff). The framework exports and executes individual PyTorch aten ops/modules/models across multiple backends (Torch eager, ONNX, CoreML, ExecuTorch, etc.).
The framework runs hundreds of single-op models across multiple configurations (precision, decomposition settings, backend targets) and records per-op execution status and parity results. The crash described here was consistently reproduced across scaled runs.
The full self-contained reproduction script is available in the repository:
Minimal Repro Script
Copy and paste this into your terminal and run it.
python - <<'PY'
import os
import tempfile
from pathlib import Path
import torch
import torch.nn as nn
import executorch
from executorch.exir import to_edge_transform_and_lower
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.runtime import Runtime, Verification
import warnings
warnings.simplefilter("ignore", FutureWarning)
class ConstWPrelu(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("w", torch.ones(3, dtype=torch.float32))

    def forward(self, x):
        return torch.ops.aten.prelu.default(x, self.w)
torch.manual_seed(0)
x4d = torch.randn(2, 3, 3, 3, device="cpu", dtype=torch.float32)
print("=== case_prelu_const_w ===")
print("run_decompositions: {}")
exported = torch.export.export(
    ConstWPrelu().eval().to(device="cpu", dtype=torch.float32),
    args=(x4d,),
    strict=True,
)
print("1. export: torch.export.export OK")
exported = exported.run_decompositions({})
print("2. export: graph")
exported.graph_module.print_readable()
edge_pm = to_edge_transform_and_lower(
    exported,
    partitioner=[XnnpackPartitioner()],
    compile_config=None,
)
et_pm = edge_pm.to_executorch(
    config=ExecutorchBackendConfig(extract_delegate_segments=True)
)
print("3. lower: to_executorch OK")
fd, p = tempfile.mkstemp(suffix=".pte", prefix="prelu_")
os.close(fd)
with open(p, "wb") as f:
et_pm.write_to_file(f)
print(f"4. io: mkstemp OK ({p})")
rt = Runtime.get()
program = rt.load_program(Path(p), verification=Verification.Minimal)
method = program.load_method("forward")
print("5. runtime: load_method OK")
y = method.execute((x4d,))
print("6. runtime: execute OK")
PY
Controlled Cases
See executorch_crash_op_prelu.py for the full code.
This issue is triggered by the combination of:
- aten::prelu
- constant in-module weight (buffer/attribute)
- explicit empty decomposition table: run_decompositions({})
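As a diagnostic sketch (hedged: this assumes only the EdgeProgramManager returned by to_edge_transform_and_lower in the repro above), inspecting the edge program before serialization shows whether the un-decomposed prelu node was delegated or left for the portable runtime:

ep = edge_pm.exported_program()  # default method: "forward"
for node in ep.graph_module.graph.nodes:
    if node.op == "call_function":
        # Look for a surviving aten.prelu node vs. a delegate call
        # (delegated subgraphs typically show up as executorch_call_delegate).
        print(node.target)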
Crash (constant in-module weight + empty decomposition table)
class ConstWPrelu(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("w", torch.ones(3, dtype=torch.float32))

    def forward(self, x):
        return torch.ops.aten.prelu.default(x, self.w)
Run with:
x4d = torch.randn(2, 3, 3, 3, device="cpu", dtype=torch.float32)
exported = torch.export.export(model, args=(x4d,), strict=True)
exported = exported.run_decompositions({})  # explicit empty table
...
y = method.execute((x4d,))  # segfault here
Works (constant in-module weight with default decompositions)
class ConstWPrelu(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("w", torch.ones(3, dtype=torch.float32))

    def forward(self, x):
        return torch.ops.aten.prelu.default(x, self.w)
Run with:
x4d = torch.randn(2, 3, 3, 3, device="cpu", dtype=torch.float32)
exported = torch.export.export(model, args=(x4d,), strict=True)
exported = exported.run_decompositions()  # default decompositions
...
y = method.execute((x4d,))  # executes successfully
Works (weight passed as tensor input + empty decomposition table)
class InputWPrelu(nn.Module):
    def forward(self, x, w):
        return torch.ops.aten.prelu.default(x, w)
Run with:
x4d = torch.randn(2, 3, 3, 3, device="cpu", dtype=torch.float32)
wprelu = torch.ones(3, device="cpu", dtype=torch.float32)
exported = torch.export.export(model, args=(x4d, wprelu), strict=True)
exported = exported.run_decompositions({})  # OK
...
y = method.execute((x4d, wprelu))  # executes successfully
Expected Behavior
ExecuTorch should either:
- execute successfully, or
- raise a Python-visible error indicating unsupported lowering/runtime execution.
It should not terminate the process with a segmentation fault.
Notes
- The crash happens during runtime execution (method.execute()), not during export or lowering.
- If this pattern is not supported, a Python-visible error (or at least a safe runtime failure) would make it diagnosable and safe for automated testing pipelines.
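Until a fix lands, one defensive pattern for harnesses like opdiff is to execute each case in a child process, so a native crash surfaces as a signal exit code instead of killing the parent; a minimal sketch (the script name matches the repro file referenced above):

import subprocess, sys

proc = subprocess.run(
    [sys.executable, "executorch_crash_op_prelu.py"],
    capture_output=True,
    text=True,
)
if proc.returncode < 0:
    # On POSIX, a negative returncode means the child died from a signal
    # (11 == SIGSEGV), so the harness can record the crash and continue.
    print(f"native crash: signal {-proc.returncode}")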
Versions
PyTorch version: 2.10.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.7.1 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.4.4.1)
CMake version: version 3.31.10
Libc version: N/A
Python version: 3.11.13 (main, Jun 3 2025, 18:38:25) [Clang 17.0.0 (clang-1700.0.13.3)] (64-bit runtime)
Python platform: macOS-15.7.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A
CPU:
Apple M4 Pro
Versions of relevant libraries:
[pip3] executorch==1.1.0a0+17adba1
[pip3] numpy==2.4.2
[pip3] onnx==1.20.1
[pip3] onnx-ir==0.1.16
[pip3] onnxruntime==1.24.1
[pip3] onnxscript==0.6.2
[pip3] pytorch_tokenizers==1.1.0
[pip3] torch==2.10.0
[pip3] torchao==0.15.0+git9338966da
[pip3] torchaudio==2.10.0
[pip3] torchdata==0.11.0
[pip3] torchsr==1.0.4
[pip3] torchtune==0.6.1
[pip3] torchvision==0.25.0
[conda] numpy 2.1.3 py313h7c57ca2_0
[conda] numpy-base 2.1.3 py313hb98e858_0
[conda] numpydoc 1.2 pyhd3eb1b0_0
[conda] tbb 2021.8.0 h48ca7d4_0