@@ -2518,3 +2518,124 @@ def forward(self):
25182518 for j in range (2 ):
25192519 expected_storage .append (j * 16 + i )
25202520 self .assertEqual ([int (v ) for v in storage_values ], expected_storage )
2521+
def test_emit_device_info_propagated_to_serialized_tensor(self) -> None:
    """Check that a delegate's target device reaches the serialized tensors.

    A single ``torch.add`` is lowered to ``BackendWithCompilerDemo`` with a
    compile spec of ``TARGET_DEVICE_COMPILE_SPEC_KEY = b"cuda:0"``. After
    ``to_executorch()`` the first execution plan must contain at least one
    delegate and exactly one tensor whose
    ``extra_tensor_info.device_type`` is ``schema.DeviceType.CUDA``.
    """
    from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
        generate_pattern_op_partitions,
    )
    from executorch.exir.backend.compile_spec_schema import CompileSpec
    from executorch.exir.backend.partitioner import (
        DelegationSpec,
        Partitioner,
        PartitionResult,
    )
    from executorch.exir.backend.test.backend_with_compiler_demo import (
        BackendWithCompilerDemo,
    )
    from executorch.exir.passes.propagate_device_pass import (
        TARGET_DEVICE_COMPILE_SPEC_KEY,
    )
    from torch.fx.passes.operator_support import any_chain, OperatorSupportBase

    class AddOnlySupport(OperatorSupportBase):
        """Operator support that accepts only the edge ``aten.add.Tensor`` op."""

        def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
            if node.op != "call_function":
                return False
            return node.target in [exir_ops.edge.aten.add.Tensor]

    class CudaTaggingPartitioner(Partitioner):
        """Tags every supported node for the demo backend with a cuda:0 spec."""

        def __init__(self):
            super().__init__()
            compile_specs = [
                CompileSpec("max_value", bytes([4])),
                CompileSpec(TARGET_DEVICE_COMPILE_SPEC_KEY, b"cuda:0"),
            ]
            self.delegation_spec = DelegationSpec(
                BackendWithCompilerDemo.__name__, compile_specs
            )

        def partition(self, exported_program) -> PartitionResult:
            matched = generate_pattern_op_partitions(
                exported_program.graph_module,
                op_support=any_chain(AddOnlySupport()),
            )
            tags = {}
            for part in matched:
                # One tag per partition; it is recorded only once the
                # partition actually contributes a node.
                tag = f"tag{part.id}"
                for member in part.nodes:
                    member.meta["delegation_tag"] = tag
                    tags[tag] = self.delegation_spec
            return PartitionResult(
                tagged_exported_program=exported_program,
                partition_tags=tags,
            )

    class AddModel(torch.nn.Module):
        def forward(self, a, b):
            return torch.add(a, b)

    module = AddModel()
    example_args = (torch.randn(2, 2), torch.randn(2, 2))

    edge_program = to_edge(
        export(module, example_args),
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    delegated = edge_program.to_backend(CudaTaggingPartitioner())
    executorch_program = delegated.to_executorch()
    first_plan = executorch_program._emitter_output.program.execution_plan[0]

    self.assertGreater(len(first_plan.delegates), 0)

    # Count the serialized tensors that carry a CUDA device annotation.
    cuda_count = 0
    for entry in first_plan.values:
        candidate = entry.val
        if not isinstance(candidate, Tensor):
            continue
        info = candidate.extra_tensor_info
        if info is not None and info.device_type == schema.DeviceType.CUDA:
            cuda_count += 1

    # add(a, b) produces 1 delegate output tensor that should be CUDA
    self.assertEqual(
        cuda_count,
        1,
        f"Expected exactly 1 CUDA tensor for delegated add, got {cuda_count}",
    )
2608+
def test_emit_cpu_tensors_no_extra_device_info(self) -> None:
    """Check that a non-delegated program emits no CUDA-tagged tensors.

    A single ``torch.add`` is exported and taken straight to
    ``to_executorch()`` without any ``to_backend`` call. No tensor in the
    first execution plan may have ``extra_tensor_info.device_type`` equal
    to ``schema.DeviceType.CUDA``.

    Note: only the device type is asserted here; the test does not check
    whether ``extra_tensor_info`` itself is ``None``.
    """

    class AddModel(torch.nn.Module):
        def forward(self, a, b):
            return torch.add(a, b)

    module = AddModel()
    example_args = (torch.randn(2, 2), torch.randn(2, 2))

    edge_program = to_edge(
        export(module, example_args),
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    executorch_program = edge_program.to_executorch()
    first_plan = executorch_program._emitter_output.program.execution_plan[0]

    # Count the serialized tensors that carry a CUDA device annotation.
    cuda_count = 0
    for entry in first_plan.values:
        candidate = entry.val
        if not isinstance(candidate, Tensor):
            continue
        info = candidate.extra_tensor_info
        if info is not None and info.device_type == schema.DeviceType.CUDA:
            cuda_count += 1

    self.assertEqual(
        cuda_count,
        0,
        "No tensor should have CUDA device when model runs entirely on CPU",
    )
0 commit comments