@@ -413,20 +413,28 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
413413 xla::XlaComputation xla_computation =
414414 GetValueOrThrow (b.Build (/* remove_dynamic_dimensions=*/ false ));
415415
416- std::vector<torch::lazy::BackendDataPtr> parameters_data;
417- parameters_data.push_back (
416+ std::vector<XLATensorPtr> tensors{XLATensor::Create (
418417 torch_xla::runtime::GetComputationClientOrDie ()->CreateDataPlaceholder (
419- bridge::GetDefaultDevice ()->toString (), std::move (shape)));
418+ bridge::GetDefaultDevice ()->toString (), std::move (shape)))};
419+ std::vector<std::vector<int64_t >> denormalized_tile_assignments;
420+ for (auto tensor : tensors) {
421+ auto sharding_spec = tensor->sharding_spec ();
422+ if (sharding_spec) {
423+ denormalized_tile_assignments.push_back (
424+ sharding_spec->sharding .GetDenormalizedTileAssignment ());
425+ }
426+ }
420427
421428 std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
422- instances.push_back ({std::move (xla_computation),
423- bridge::GetDefaultDevice ()->toString (),
424- {bridge::GetDefaultDevice ()->toString ()},
425- &shape,
426- /* should_wrap_parameter=*/ false ,
427- /* is_sharded=*/ true ,
428- /* allow_spmd_sharding_propagation_to_output=*/ true ,
429- /* parameters_data=*/ parameters_data});
429+ instances.push_back (
430+ {std::move (xla_computation),
431+ bridge::GetDefaultDevice ()->toString (),
432+ {bridge::GetDefaultDevice ()->toString ()},
433+ &shape,
434+ /* should_wrap_parameter=*/ false ,
435+ /* is_sharded=*/ true ,
436+ /* allow_spmd_sharding_propagation_to_output=*/ true ,
437+ /* denormalized_tile_assignments=*/ denormalized_tile_assignments});
430438
431439 std::vector<
432440 std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
@@ -437,9 +445,6 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
437445 " add" , std::move (computations[0 ]->move_computation ()));
438446
439447 // Prepare output sharding propagation, expect a sharded output placeholder.
440- std::vector<XLATensorPtr> tensors{XLATensor::Create (
441- torch_xla::runtime::GetComputationClientOrDie ()->CreateDataPlaceholder (
442- bridge::GetDefaultDevice ()->toString (), std::move (shape)))};
443448 std::vector<torch::lazy::BackendDataPtr> data_placeholders;
444449 std::vector<XLATensor::ShardingSpecPtr> sharding_specs;
445450 ShardingUtil::PrepareOutputShardingPropagation (
0 commit comments