Skip to content

Commit 71a3345

Browse files
committed
[ET Device Support] Propagate device info from TensorSpec into serialized Tensor
Pull Request resolved: #18079 Propagate device information from `TensorSpec.device` (set by `PropagateDevicePass`) to the serialized `schema.Tensor` in the emitted PTE file, to make runtime further aware of it. ghstack-source-id: 352274088 @exported-using-ghexport Differential Revision: [D95899706](https://our.internmc.facebook.com/intern/diff/D95899706/)
1 parent 323e818 commit 71a3345

File tree

3 files changed

+137
-1
lines changed

3 files changed

+137
-1
lines changed

exir/emit/test/BUCK

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,15 @@ fbcode_target(_kind = runtime.python_test,
2626
"//executorch/exir:schema",
2727
"//executorch/exir/backend/test/demos/rpc:executor_backend_partitioner",
2828
"//executorch/exir/backend:backend_api",
29+
"//executorch/exir/backend:compile_spec_schema",
30+
"//executorch/exir/backend:partitioner",
31+
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
32+
"//executorch/exir/backend/test:backend_with_compiler_demo",
2933
"//executorch/exir/emit:lib",
3034
"//executorch/exir/passes:const_prop_pass",
3135
"//executorch/exir/passes:constant_prop_pass",
3236
"//executorch/exir/passes:init_mutable_pass",
37+
"//executorch/exir/passes:propagate_device_pass",
3338
"//executorch/exir/tests:lib",
3439
"//executorch/exir/tests:models",
3540
"//executorch/extension/pybindings:portable_lib",

exir/emit/test/test_emit.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2518,3 +2518,124 @@ def forward(self):
25182518
for j in range(2):
25192519
expected_storage.append(j * 16 + i)
25202520
self.assertEqual([int(v) for v in storage_values], expected_storage)
2521+
2522+
def test_emit_device_info_propagated_to_serialized_tensor(self) -> None:
    """Check that a device set by PropagateDevicePass (via the delegate's
    TARGET_DEVICE compile spec) is recorded as ExtraTensorInfo.device_type
    on the tensors in the emitted program."""
    from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
        generate_pattern_op_partitions,
    )
    from executorch.exir.backend.compile_spec_schema import CompileSpec
    from executorch.exir.backend.partitioner import (
        DelegationSpec,
        Partitioner,
        PartitionResult,
    )
    from executorch.exir.backend.test.backend_with_compiler_demo import (
        BackendWithCompilerDemo,
    )
    from executorch.exir.passes.propagate_device_pass import (
        TARGET_DEVICE_COMPILE_SPEC_KEY,
    )
    from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

    class AddOnlySupport(OperatorSupportBase):
        # Delegate only edge-dialect aten.add.Tensor nodes.
        def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
            if node.op != "call_function":
                return False
            return node.target in [exir_ops.edge.aten.add.Tensor]

    class CudaTaggingPartitioner(Partitioner):
        # Tags every add partition for BackendWithCompilerDemo and attaches
        # a compile spec declaring the target device as cuda:0.
        def __init__(self):
            super().__init__()
            specs = [
                CompileSpec("max_value", bytes([4])),
                CompileSpec(TARGET_DEVICE_COMPILE_SPEC_KEY, b"cuda:0"),
            ]
            self.delegation_spec = DelegationSpec(
                BackendWithCompilerDemo.__name__, specs
            )

        def partition(self, exported_program) -> PartitionResult:
            partition_tags = {}
            partitions = generate_pattern_op_partitions(
                exported_program.graph_module,
                op_support=any_chain(AddOnlySupport()),
            )
            for part in partitions:
                # The tag only depends on the partition id, so compute it once.
                tag = f"tag{part.id}"
                for node in part.nodes:
                    node.meta["delegation_tag"] = tag
                    partition_tags[tag] = self.delegation_spec
            return PartitionResult(
                tagged_exported_program=exported_program,
                partition_tags=partition_tags,
            )

    class Model(torch.nn.Module):
        def forward(self, a, b):
            return torch.add(a, b)

    example_inputs = (torch.randn(2, 2), torch.randn(2, 2))
    edge = to_edge(
        export(Model(), example_inputs),
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    et_prog = edge.to_backend(CudaTaggingPartitioner()).to_executorch()
    program = et_prog._emitter_output.program

    plan = program.execution_plan[0]
    # Sanity: the add must actually have been delegated.
    self.assertGreater(len(plan.delegates), 0)

    # Collect every serialized tensor whose extra info marks it as CUDA.
    cuda_tensors = []
    for value in plan.values:
        tensor = value.val
        if not isinstance(tensor, Tensor):
            continue
        info = tensor.extra_tensor_info
        if info is not None and info.device_type == schema.DeviceType.CUDA:
            cuda_tensors.append(tensor)

    # add(a, b) produces 1 delegate output tensor that should be CUDA
    self.assertEqual(
        len(cuda_tensors),
        1,
        f"Expected exactly 1 CUDA tensor for delegated add, got {len(cuda_tensors)}",
    )
2608+
2609+
def test_emit_cpu_tensors_no_extra_device_info(self) -> None:
    """For a purely-CPU program (the default, with no device-propagating
    delegate), no serialized tensor should carry a CUDA device_type in its
    ExtraTensorInfo."""

    class Model(torch.nn.Module):
        def forward(self, a, b):
            return torch.add(a, b)

    example_inputs = (torch.randn(2, 2), torch.randn(2, 2))
    et_prog = to_edge(
        export(Model(), example_inputs),
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    ).to_executorch()

    plan = et_prog._emitter_output.program.execution_plan[0]

    # Count serialized tensors that were (incorrectly) tagged as CUDA.
    cuda_count = 0
    for value in plan.values:
        tensor = value.val
        if not isinstance(tensor, Tensor):
            continue
        info = tensor.extra_tensor_info
        if info is not None and info.device_type == schema.DeviceType.CUDA:
            cuda_count += 1

    self.assertEqual(
        cuda_count,
        0,
        "No tensor should have CUDA device when model runs entirely on CPU",
    )

exir/tensor.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,16 @@ def to_list(
365365
tensor_size = to_list(spec.shape)
366366
tensor_dim_order = to_list(spec.dim_order)
367367

368+
extra_tensor_info = spec.extra_tensor_info
369+
# Propagate device from TensorSpec into ExtraTensorInfo for serialization.
370+
if spec.device != schema.DeviceType.CPU:
371+
if extra_tensor_info is None:
372+
extra_tensor_info = schema.ExtraTensorInfo(
373+
device_type=spec.device,
374+
)
375+
else:
376+
extra_tensor_info.device_type = spec.device
377+
368378
flatbuffer_tensor = schema.Tensor(
369379
scalar_type=scalar_type_enum(spec.scalar_type),
370380
# The runtime currently only supports tensors with offsets of zero.
@@ -376,7 +386,7 @@ def to_list(
376386
allocation_info=allocation_info,
377387
layout=layout_enum(spec.layout),
378388
shape_dynamism=spec.shape_dynamism,
379-
extra_tensor_info=spec.extra_tensor_info,
389+
extra_tensor_info=extra_tensor_info,
380390
)
381391
return flatbuffer_tensor
382392

0 commit comments

Comments
 (0)