Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions exir/passes/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -439,3 +439,17 @@ fbcode_target(_kind = runtime.python_library,
"//caffe2:torch",
],
)

# Pass that stamps partitioner-assigned device info (taken from the
# "target_device" CompileSpec of each lowered delegate) onto TensorSpecs
# after to_backend. See propagate_device_pass.py for details.
fbcode_target(_kind = runtime.python_library,
    name = "propagate_device_pass",
    srcs = [
        "propagate_device_pass.py",
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/exir:delegate",
        "//executorch/exir:lowered_backend_module",
        "//executorch/exir:schema",
        "//executorch/exir:tensor",
    ],
)
163 changes: 163 additions & 0 deletions exir/passes/propagate_device_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
from typing import Optional

import executorch.exir.schema as schema

import torch
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.lowered_backend_module import LoweredBackendModule
from executorch.exir.tensor import TensorSpec
from torch.fx.passes.infra.pass_base import PassBase, PassResult

logger: logging.Logger = logging.getLogger(__name__)

# CompileSpec key convention for specifying the target device.
# Partitioners that target a specific device should include a CompileSpec entry
# with this key and a value encoding the device string (e.g., b"cuda:0").
TARGET_DEVICE_COMPILE_SPEC_KEY = "target_device"

# Mapping from torch.device type strings to schema.DeviceType.
# Device strings whose type is not listed here fall back to CPU in
# _parse_device_spec_value.
_DEVICE_STR_TO_ET_DEVICE: dict[str, schema.DeviceType] = {
    "cpu": schema.DeviceType.CPU,
    "cuda": schema.DeviceType.CUDA,
}


def _parse_device_spec_value(value: bytes) -> tuple[schema.DeviceType, int]:
    """
    Parse a target_device CompileSpec value (e.g., b"cuda:0") into
    (DeviceType, device_index).

    Unknown device types fall back to DeviceType.CPU; a warning is logged
    so the fallback is visible instead of silently placing tensors on CPU.
    """
    device_str = value.decode("utf-8")
    torch_device = torch.device(device_str)
    device_type = _DEVICE_STR_TO_ET_DEVICE.get(torch_device.type)
    if device_type is None:
        logger.warning(
            "PropagateDevicePass: unsupported target device type '%s' "
            "(from %r); falling back to CPU.",
            torch_device.type,
            device_str,
        )
        device_type = schema.DeviceType.CPU
    # torch.device("cuda") carries index None; normalize that to device 0.
    device_index = torch_device.index if torch_device.index is not None else 0
    return device_type, device_index


def _get_lowered_module(
    graph_module: torch.fx.GraphModule,
    delegate_call_node: torch.fx.Node,
) -> Optional[LoweredBackendModule]:
    """
    Given an executorch_call_delegate node, retrieve the associated
    LoweredBackendModule from the graph module.

    The first argument to executorch_call_delegate is a get_attr node whose
    target is the qualified name of the LoweredBackendModule attribute.
    Returns None if the node shape does not match or the attribute is not a
    LoweredBackendModule.
    """
    if len(delegate_call_node.args) < 1:
        return None
    lowered_node = delegate_call_node.args[0]
    if not isinstance(lowered_node, torch.fx.Node) or lowered_node.op != "get_attr":
        return None
    # get_attr targets may be dotted qualified names (e.g. "sub.lowered_0"),
    # so walk each attribute atom rather than doing a single flat getattr.
    attr: object = graph_module
    for atom in str(lowered_node.target).split("."):
        attr = getattr(attr, atom, None)
        if attr is None:
            return None
    if isinstance(attr, LoweredBackendModule):
        return attr
    return None


def _get_target_device_from_compile_specs(
    lowered_module: LoweredBackendModule,
) -> Optional[tuple[schema.DeviceType, int]]:
    """
    Scan the lowered module's CompileSpecs for an entry keyed
    TARGET_DEVICE_COMPILE_SPEC_KEY and parse its value into
    (DeviceType, device_index). Returns None when no such entry exists.
    """
    matching_values = (
        spec.value
        for spec in lowered_module.compile_specs
        if spec.key == TARGET_DEVICE_COMPILE_SPEC_KEY
    )
    raw_value = next(matching_values, None)
    if raw_value is None:
        return None
    return _parse_device_spec_value(raw_value)


def _set_device_on_spec(
    spec: TensorSpec,
    device_type: schema.DeviceType,
) -> None:
    """Set the device attribute on a TensorSpec.

    Kept as a helper so every device write in this pass funnels through a
    single place, which makes future auditing or logging of device changes
    trivial.
    """
    spec.device = device_type


class PropagateDevicePass(PassBase):
    """
    After to_backend, walk the graph and set device metadata on TensorSpecs
    based on partitioner-assigned delegation info.

    Rules:
    1. Delegated nodes: Output tensors of a delegate call are marked with the
       target device derived from the delegate's CompileSpec (key="target_device").
    2. Non-delegated nodes: Remain on CPU (default).
    3. Getitem nodes that extract from a delegate call inherit the device from
       the delegate call's output spec at the corresponding index.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Run the pass.

        Returns a PassResult whose `modified` flag is True iff at least one
        TensorSpec device was updated.
        """
        changed = False
        # First pass: tag delegate-call output specs with the target device.
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == executorch_call_delegate:
                changed |= self._mark_delegate_outputs(graph_module, node)
        # Second pass: propagate device through getitem nodes that extract
        # individual outputs from a delegate call. Use getattr because not
        # every call_function target is guaranteed to expose __name__.
        for node in graph_module.graph.nodes:
            if (
                node.op == "call_function"
                and getattr(node.target, "__name__", None) == "getitem"
            ):
                changed |= self._propagate_through_getitem(node)
        return PassResult(graph_module, changed)

    def _mark_delegate_outputs(
        self,
        graph_module: torch.fx.GraphModule,
        node: torch.fx.Node,
    ) -> bool:
        """Set the target device on all output TensorSpecs of one delegate
        call node. Returns True iff any spec was modified.
        """
        lowered_module = _get_lowered_module(graph_module, node)
        if lowered_module is None:
            return False

        result = _get_target_device_from_compile_specs(lowered_module)
        if result is None:
            # No target_device CompileSpec: outputs stay on CPU (default).
            return False
        target_device_type, _device_index = result

        # The node's "spec" meta may be a single TensorSpec or a
        # tuple/list of them (multi-output delegates).
        specs = node.meta.get("spec")
        if specs is None:
            return False

        changed = False
        if isinstance(specs, TensorSpec):
            _set_device_on_spec(specs, target_device_type)
            changed = True
        elif isinstance(specs, (tuple, list)):
            for s in specs:
                if isinstance(s, TensorSpec):
                    _set_device_on_spec(s, target_device_type)
                    changed = True

        if changed:
            logger.debug(
                "PropagateDevicePass: set device=%s on delegate node %s "
                "(backend=%s)",
                target_device_type,
                node.name,
                lowered_module.backend_id,
            )
        return changed

    def _propagate_through_getitem(self, node: torch.fx.Node) -> bool:
        """Copy the device from a delegate call's output spec onto the
        TensorSpec of a getitem node that extracts one of its outputs.
        Returns True iff the spec was modified.
        """
        source_node = node.args[0]
        if not (
            isinstance(source_node, torch.fx.Node)
            and source_node.op == "call_function"
            and source_node.target == executorch_call_delegate
        ):
            return False

        spec = node.meta.get("spec")
        source_specs = source_node.meta.get("spec")
        idx = node.args[1]
        # Require a well-formed, in-range integer index; rejecting negative
        # indices avoids silently aliasing outputs from the tail of the list.
        if not (
            isinstance(spec, TensorSpec)
            and isinstance(source_specs, (tuple, list))
            and isinstance(idx, int)
            and 0 <= idx < len(source_specs)
        ):
            return False

        source_spec = source_specs[idx]
        if not isinstance(source_spec, TensorSpec):
            return False
        _set_device_on_spec(spec, source_spec.device)
        return True
1 change: 1 addition & 0 deletions exir/passes/replace_view_copy_with_view_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None:
"mem_offset",
"dtype", # property
"extra_tensor_info", # property
"device",
]

# Make sure _self_fields and _base_fields are disjoint
Expand Down
1 change: 1 addition & 0 deletions exir/program/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ fbcode_target(_kind = runtime.python_library,
"//executorch/exir/passes:insert_write_back_for_buffers_pass",
"//executorch/exir/passes:lib",
"//executorch/exir/passes:normalize_view_copy_base_pass",
"//executorch/exir/passes:propagate_device_pass",
"//executorch/exir/passes:remove_graph_asserts_pass",
"//executorch/exir/passes:remove_mixed_type_operators",
"//executorch/exir/passes:replace_aten_with_edge_pass",
Expand Down
5 changes: 5 additions & 0 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from executorch.exir.passes.normalize_view_copy_base_pass import (
NormalizeViewCopyBasePass,
)
from executorch.exir.passes.propagate_device_pass import PropagateDevicePass
from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass
from executorch.exir.passes.reinplace import reinplace_pass
from executorch.exir.passes.remove_graph_asserts_pass import (
Expand Down Expand Up @@ -848,6 +849,10 @@ def edge_to_executorch_passes(
# there exists an unbacked symint operation.
*config.passes,
SpecPropPass(),
# Propagate device metadata (e.g., CUDA) from delegate CompileSpecs onto
# TensorSpecs. Must run after SpecPropPass so specs are freshly created
# with correct shapes.
PropagateDevicePass(),
EdgeToBackendOpsPass(),
RemoveGraphAssertsPass(),
] + pre_memory_planning_passes(config, name)
Expand Down
3 changes: 3 additions & 0 deletions exir/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ def __init__(
self.init_mem_planning_fields()
self.shape_dynamism: TensorShapeDynamism = determine_tensor_dynanism(self.shape)
self.extra_tensor_info = extra_tensor_info
# device type will be only updated during PropagateDevicePass.
self.device: schema.DeviceType = schema.DeviceType.CPU

@property
def allocated_memory(self) -> int:
Expand Down Expand Up @@ -254,6 +256,7 @@ def __repr__(self) -> str:
+ f", is_sparse={self.is_sparse}"
+ f", shape_dynamism={self.shape_dynamism}"
+ f", const={self.const}, requires_grad={self.requires_grad}"
+ f", device={self.device.name}"
+ ")"
)

Expand Down
20 changes: 20 additions & 0 deletions exir/tests/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -516,3 +516,23 @@ python_unittest(
"//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
],
)

# Unit test for PropagateDevicePass: exercises device propagation from
# delegate CompileSpecs onto TensorSpecs using the demo backend.
python_unittest(
    name = "propagate_device_pass",
    srcs = [
        "test_propagate_device_pass.py",
    ],
    deps = [
        "//caffe2:torch",
        "//executorch/exir:lib",
        "//executorch/exir:schema",
        "//executorch/exir:tensor",
        "//executorch/exir/backend:backend_api",
        "//executorch/exir/backend:compile_spec_schema",
        "//executorch/exir/backend:partitioner",
        "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
        "//executorch/exir/backend/test:backend_with_compiler_demo",
        "//executorch/exir/dialects:lib",
        "//executorch/exir/passes:propagate_device_pass",
    ],
)
Loading
Loading