
ExecuTorch CoreML Delegate Segfault for aten::where (1-input) and aten::nonzero_numpy #17537

@0xShug0

Description


πŸ› Describe the bug

ExecuTorch CoreML Delegate Segfault for aten::where (1-input) and aten::nonzero_numpy

Summary

When using the ExecuTorch CoreML delegate, models containing:

  • aten::where (single-input form: where(x))
  • aten::nonzero_numpy

segfault at runtime during method.execute().

Export and lowering both succeed. The crash happens only during runtime
execution and terminates the process with exit code -11 (SIGSEGV).
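For context, in eager PyTorch the single-input form of where is documented to be identical to nonzero(..., as_tuple=True): both return one dynamic-length 1D int64 index tensor per input dimension. A quick sketch with illustrative values:

```python
import torch

x = torch.tensor([[0.0, 1.5], [2.0, 0.0]])
cond = x > 0

# Single-input where(condition) matches nonzero(condition, as_tuple=True):
# one dynamic-length 1D int64 index tensor per input dimension.
w = torch.where(cond)
nz = torch.nonzero(cond, as_tuple=True)

for a, b in zip(w, nz):
    assert torch.equal(a, b)
    assert a.dtype == torch.int64 and a.dim() == 1
```

These 1D int64 outputs are exactly the shape/dtype combination implicated below.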

...
[ETCoreMLModelManager.mm:478] Cache Miss: Compiled Model with identifier=executorch_d9f1ea99-04db-4da0-b04b-44b21aa2d7cb_all was not found in the models cache.
[ETCoreMLModelManager.mm:457] The model in the pte file is not pre-compiled.  Compiling with a 5 min timeout.
[ETCoreMLModelManager.mm:490] Successfully got compiled model with identifier=executorch_d9f1ea99-04db-4da0-b04b-44b21aa2d7cb_all.  Transferring ownership to assetManager.
zsh: segmentation fault  python executorch_crash_op_where_coreml.py

Environment

  • PyTorch 2.10.0
  • executorch 1.1.0
  • coremltools 9.0
  • Mac mini (M4 Pro), macOS, Apple Silicon
  • Python 3.11

Context

This issue was discovered while running large-scale automated operator testing using an open-source backend parity framework I maintain (opdiff). The framework exports and executes individual PyTorch aten ops/modules/models across multiple backends (Torch eager, ONNX, CoreML, ExecuTorch, etc.).

The framework runs hundreds of single-op models across multiple configurations (precision, decomposition settings, backend targets) and records per-op execution status and parity results. The crash described here was consistently reproduced across scaled runs.

The full self-contained reproduction script is available in the repository:

executorch_crash_op_where_coreml.py.


Minimal Repro Cases

Crash

import torch
import torch.nn as nn

class Where1Input(nn.Module):
    def forward(self, x):
        return torch.ops.aten.where.default(x)

class NonzeroNumpy(nn.Module):
    def forward(self, x):
        return torch.ops.aten.nonzero_numpy.default(x)

Both export and lower successfully, but method.execute() crashes with
SIGSEGV.


Exported Graph Pattern (where / nonzero_numpy)

Both operators lower to a graph equivalent to:

  1. nonzero = aten.nonzero(x) → i64[u0, 2]
  2. slice columns → i64[u0, 1]
  3. squeeze → i64[u0]
  4. Return dynamic 1D int64 tensor(s)

So they ultimately return dynamic-length 1D int64 tensors of shape
i64[u0].
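The pattern can be replayed in eager mode (with an illustrative input) to see the shape at each step:

```python
import torch

x = torch.tensor([[0, 3], [5, 0]])

nz = torch.ops.aten.nonzero.default(x)           # i64[u0, 2]
col = torch.ops.aten.slice.Tensor(nz, 1, 0, 1)   # i64[u0, 1]
idx = torch.ops.aten.squeeze.dims(col, [1])      # i64[u0]

assert nz.shape == (2, 2) and col.shape == (2, 1) and idx.shape == (2,)
assert idx.dtype == torch.int64
```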


Debugging / Narrowing Down Root Cause

We progressively simplified the graph to isolate the trigger.

Works

class Nonzero(nn.Module):
    def forward(self, x):
        return torch.ops.aten.nonzero.default(x)

Return shape: i64[u0, 2]


class NonzeroSplitNoSqueeze(nn.Module):
    def forward(self, x):
        nz = torch.ops.aten.nonzero.default(x)
        r = torch.ops.aten.slice.Tensor(nz, 1, 0, 1)
        return r

Return shape: i64[u0, 1]


Crashes

class NonzeroSplitReturnRow(nn.Module):
    def forward(self, x):
        nz = torch.ops.aten.nonzero.default(x)
        r = torch.ops.aten.slice.Tensor(nz, 1, 0, 1)
        r = torch.ops.aten.squeeze.dims(r, [1])
        return r

Return shape: i64[u0]

The only difference between pass and crash is:

i64[u0,1] → squeeze → i64[u0]

This indicates that producing a dynamic-length 1D int64 tensor triggers
the crash.
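Given that i64[u0, 1] executes while i64[u0] crashes, one possible workaround (a sketch, verified in eager mode only, not against the delegate beyond the experiments above) is to keep the trailing unit dimension inside the lowered module and squeeze it on the host:

```python
import torch
import torch.nn as nn

class NonzeroSplitKeep2D(nn.Module):
    # Hypothetical workaround: return i64[u0, 1] (observed to execute)
    # instead of i64[u0]; the caller squeezes the unit dim off-delegate.
    def forward(self, x):
        nz = torch.ops.aten.nonzero.default(x)
        return torch.ops.aten.slice.Tensor(nz, 1, 0, 1)

m = NonzeroSplitKeep2D()
x = torch.tensor([[0.0, 2.0], [3.0, 0.0]])
out2d = m(x)             # i64[u0, 1]
rows = out2d.squeeze(1)  # host-side squeeze -> i64[u0]
```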


Additional Confirmation (Non-nonzero Path)

class IndexBoolUnsqueezeSqueeze(nn.Module):
    def forward(self, x):
        m = torch.ops.aten.gt.Scalar(x, 0.0)
        y = torch.ops.aten.index.Tensor(x, [m])
        yi = torch.ops.aten._to_copy.default(y, dtype=torch.int64)
        yi2 = torch.ops.aten.unsqueeze.default(yi, 0)
        return torch.ops.aten.squeeze.dims(yi2, [0])

Return shape: i64[u0] → Also crashes.


Float Version (Does NOT Crash)

class IndexBoolUnsqueezeSqueezeF32(nn.Module):
    def forward(self, x):
        m = torch.ops.aten.gt.Scalar(x, 0.0)
        y = torch.ops.aten.index.Tensor(x, [m])
        y2 = torch.ops.aten.unsqueeze.default(y, 0)
        return torch.ops.aten.squeeze.dims(y2, [0])

Return shape: f32[u0] → Works.
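In eager mode the int64 and float32 variants of this graph differ only in dtype, never in shape, which points at the delegate's dtype handling for dynamic 1D outputs (illustrative values):

```python
import torch

x = torch.tensor([-1.0, 2.0, 3.0])
m = torch.ops.aten.gt.Scalar(x, 0.0)
y = torch.ops.aten.index.Tensor(x, [m])                     # f32[u0]: works on the delegate
yi = torch.ops.aten._to_copy.default(y, dtype=torch.int64)  # i64[u0]: crashes on the delegate

assert y.shape == yi.shape == (2,)
assert y.dtype == torch.float32 and yi.dtype == torch.int64
```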


Observed Pattern

Output Shape   Dtype     Result
i64[u0, 2]     int64     Works
i64[u0, 1]     int64     Works
i64[u0]        int64     Crash
f32[u0]        float32   Works

The crash is strongly correlated with producing a dynamic-length 1D
int64 tensor (i64[u0]), typically via aten.squeeze.dims.


Expected Behavior

ExecuTorch CoreML delegate should either:

  • Execute successfully, or
  • Raise a Python-visible error if unsupported

It should not terminate the process with a segmentation fault.


Impact

Any model using:

  • aten::where (1-input form),
  • aten::nonzero_numpy, or
  • any pattern producing dynamic 1D int64 tensors,

is vulnerable to runtime crashes when executed through the ExecuTorch
CoreML delegate.

Versions

Collecting environment information...
PyTorch version: 2.10.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.7.1 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.4.4.1)
CMake version: version 3.31.10
Libc version: N/A

Python version: 3.11.13 (main, Jun 3 2025, 18:38:25) [Clang 17.0.0 (clang-1700.0.13.3)] (64-bit runtime)
Python platform: macOS-15.7.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Apple M4 Pro

Versions of relevant libraries:
[pip3] executorch==1.1.0
[pip3] numpy==2.4.2
[pip3] onnx==1.20.1
[pip3] onnx-ir==0.1.16
[pip3] onnxruntime==1.24.1
[pip3] onnxscript==0.6.2
[pip3] pytorch_tokenizers==1.1.0
[pip3] torch==2.10.0
[pip3] torchao==0.15.0
[conda] numpy 2.1.3 py313h7c57ca2_0
[conda] numpy-base 2.1.3 py313hb98e858_0
[conda] numpydoc 1.2 pyhd3eb1b0_0
[conda] tbb 2021.8.0 h48ca7d4_0

cc @kimishpatel @YifanShenSZ @cymbalrush @metascroy

Metadata

Labels

module: coreml (Issues related to Apple's Core ML delegation and code under backends/apple/coreml/)