@@ -49,15 +49,18 @@ TEST_F(XLAShardingTest, GetShardShape) {
4949 {0 , 1 },
5050 {2 , 3 },
5151 });
52- auto sharding = xla::HloSharding::Tile (mesh).ToProto ();
52+ auto xla_sharding = xla::HloSharding::Tile (mesh).ToProto ();
53+ torch_xla::OpSharding sharding (xla_sharding, std::nullopt );
5354 auto sharding_spec =
5455 std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
5556
5657 auto shard_shape = ShardingUtil::GetShardShape (sharding_spec);
5758 // For tiled sharding, each dimension should be halved
5859 EXPECT_EQ (shard_shape, std::vector<int64_t >({4 , 4 }));
5960
60- sharding_spec->sharding = xla::HloSharding::Replicate ().ToProto ();
61+ xla_sharding = xla::HloSharding::Replicate ().ToProto ();
62+ sharding = torch_xla::OpSharding (xla_sharding, std::nullopt );
63+ sharding_spec->sharding = sharding;
6164 shard_shape = ShardingUtil::GetShardShape (sharding_spec);
6265 // For replicated sharding, each dimension should be preserved
6366 EXPECT_EQ (shard_shape, std::vector<int64_t >({8 , 7 }));
@@ -73,7 +76,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
7376 {0 , 1 },
7477 {2 , 3 },
7578 });
76- auto sharding = xla::HloSharding::Tile (mesh).ToProto ();
79+ auto xla_sharding = xla::HloSharding::Tile (mesh).ToProto ();
80+ torch_xla::OpSharding sharding (xla_sharding, std::nullopt );
7781 auto sharding_spec =
7882 std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
7983 auto shard_shape = ShardingUtil::GetShardShape (sharding_spec);
@@ -102,7 +106,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
102106 EXPECT_EQ (slice.step (), 1 );
103107 }
104108 }
105- sharding = xla::HloSharding::Replicate ().ToProto ();
109+ xla_sharding = xla::HloSharding::Replicate ().ToProto ();
110+ sharding = torch_xla::OpSharding (xla_sharding, std::nullopt );
106111 sharding_spec->sharding = sharding;
107112 shard_shape = ShardingUtil::GetShardShape (sharding_spec);
108113 replica_and_indices = ShardingUtil::GetShardReplicaAndIndicesForDevices (
@@ -125,11 +130,12 @@ TEST_F(XLAShardingTest, ShardTensor) {
125130 at::Tensor tensor = at::ones ({8 }, at::TensorOptions (at::kFloat ));
126131 xla::Shape tensor_shape =
127132 CreateComputationShapeFromTensor (tensor, bridge::GetDefaultDevice ());
128- xla::OpSharding sharding =
133+ xla::OpSharding xla_sharding =
129134 xla::HloSharding::Tile1D (
130135 CreateComputationShapeFromTensor (tensor, bridge::GetDefaultDevice ()),
131136 devices.size ())
132137 .ToProto ();
138+ torch_xla::OpSharding sharding (xla_sharding, std::nullopt );
133139 auto sharding_spec =
134140 std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
135141 auto shards = ShardingUtil::ShardTensor (tensor, sharding_spec, devices,
@@ -147,7 +153,8 @@ TEST_F(XLAShardingTest, ShardTensor) {
147153 {0 , 1 , 2 , 3 },
148154 {4 , 5 , 6 , 7 },
149155 });
150- sharding = xla::HloSharding::Tile (mesh).ToProto ();
156+ xla_sharding = xla::HloSharding::Tile (mesh).ToProto ();
157+ sharding = torch_xla::OpSharding (xla_sharding, std::nullopt );
151158 sharding_spec =
152159 std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
153160 shards = ShardingUtil::ShardTensor (tensor, sharding_spec, devices,
@@ -159,15 +166,19 @@ TEST_F(XLAShardingTest, ShardTensor) {
159166 // 3D tiled, the first dim is replicated and the last halved. The last shard
160167 // size should be smaller in dim=1 because it's not evenly divisible.
161168 xla::Array3D<int64_t > cube ({{{0 , 1 }, {2 , 3 }, {4 , 5 }, {6 , 7 }}});
162- sharding_spec->sharding = xla::HloSharding::Tile (cube).ToProto ();
169+ xla_sharding = xla::HloSharding::Tile (cube).ToProto ();
170+ sharding = torch_xla::OpSharding (xla_sharding, std::nullopt );
171+ sharding_spec->sharding = sharding;
163172 shards = ShardingUtil::ShardTensor (tensor, sharding_spec, devices,
164173 /* padded=*/ false );
165174 EXPECT_EQ (shards.size (), 8 );
166175 EXPECT_EQ (shards[0 ].sizes (), c10::ArrayRef<long >({8 , 2 , 2 }));
167176 EXPECT_EQ (shards[7 ].sizes (), c10::ArrayRef<long >({8 , 1 , 2 }));
168177
169178 // Replicated, all shards should be identical.
170- sharding_spec->sharding = xla::HloSharding::Replicate ().ToProto ();
179+ xla_sharding = xla::HloSharding::Replicate ().ToProto ();
180+ sharding = torch_xla::OpSharding (xla_sharding, std::nullopt );
181+ sharding_spec->sharding = sharding;
171182 shards = ShardingUtil::ShardTensor (tensor, sharding_spec, devices,
172183 /* padded=*/ false );
173184 EXPECT_EQ (shards.size (), 8 );
@@ -181,7 +192,8 @@ TEST_F(XLAShardingTest, ShardTensor) {
181192 tensor_shape =
182193 CreateComputationShapeFromTensor (tensor, bridge::GetDefaultDevice ());
183194 xla::Array4D<int64_t > tesseract ({{{{0 , 1 }, {2 , 3 }, {4 , 5 }, {6 , 7 }}}});
184- sharding = xla::HloSharding::Tile (tesseract).ToProto ();
195+ xla_sharding = xla::HloSharding::Tile (tesseract).ToProto ();
196+ sharding = torch_xla::OpSharding (xla_sharding, std::nullopt );
185197 sharding_spec =
186198 std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
187199 shards = ShardingUtil::ShardTensor (tensor, sharding_spec, devices,
@@ -205,7 +217,8 @@ TEST_F(XLAShardingTest, ShardTensor) {
205217 CreateComputationShapeFromTensor (tensor, bridge::GetDefaultDevice ());
206218 xla::Array<int64_t > hypercube (std::vector<int64_t >{1 , 1 , 2 , 2 , 2 });
207219 hypercube.FillIota (0 );
208- sharding = xla::HloSharding::Tile (hypercube).ToProto ();
220+ xla_sharding = xla::HloSharding::Tile (hypercube).ToProto ();
221+ sharding = torch_xla::OpSharding (xla_sharding, std::nullopt );
209222 sharding_spec =
210223 std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
211224 shards = ShardingUtil::ShardTensor (tensor, sharding_spec, devices,
@@ -233,7 +246,8 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
233246 {4 , 5 , 0 , 1 },
234247 {6 , 7 , 2 , 3 },
235248 });
236- auto sharding = xla::HloSharding::Tile (mesh).ToProto ();
249+ auto xla_sharding = xla::HloSharding::Tile (mesh).ToProto ();
250+ torch_xla::OpSharding sharding (xla_sharding, std::nullopt );
237251 auto sharding_spec =
238252 std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
239253 // For devices at the start of the mesh, all shards should have the same
@@ -250,7 +264,9 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
250264 {0 , 1 , 4 , 5 },
251265 {2 , 3 , 6 , 7 },
252266 });
253- sharding_spec->sharding = xla::HloSharding::Tile (mesh).ToProto ();
267+ xla_sharding = xla::HloSharding::Tile (mesh).ToProto ();
268+ sharding = torch_xla::OpSharding (xla_sharding, std::nullopt );
269+ sharding_spec->sharding = sharding;
254270 shards = ShardingUtil::ShardTensor (tensor, sharding_spec, devices,
255271 /* padded=*/ false );
256272 EXPECT_EQ (shards.size (), 4 );
@@ -277,7 +293,8 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) {
277293 {{7 }},
278294 });
279295
280- auto sharding = xla::HloSharding::Tile (mesh).ToProto ();
296+ auto xla_sharding = xla::HloSharding::Tile (mesh).ToProto ();
297+ torch_xla::OpSharding sharding (xla_sharding, std::nullopt );
281298 auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
282299 sharding, global_shape, /* minibatch=*/ true );
283300 auto shards = ShardingUtil::ShardTensor (minibatch_tensor, sharding_spec,
@@ -291,17 +308,20 @@ TEST_F(XLAShardingTest, EqualShardingSpecs) {
291308 auto tensor = at::ones ({8 , 7 }, at::TensorOptions (at::kFloat ));
292309 xla::Shape tensor_shape =
293310 CreateComputationShapeFromTensor (tensor, bridge::GetDefaultDevice ());
294- XLATensor::ShardingSpec tiled_2d (xla::HloSharding::Tile ({
295- {0 , 1 , 2 , 3 },
296- {4 , 5 , 6 , 7 },
297- })
298- .ToProto (),
299- tensor_shape);
300- XLATensor::ShardingSpec tiled_3d (
301- xla::HloSharding::Tile ({{{0 , 1 }, {2 , 3 }, {4 , 5 }, {6 , 7 }}}).ToProto (),
302- tensor_shape);
303- XLATensor::ShardingSpec replicated (xla::HloSharding::Replicate ().ToProto (),
304- tensor_shape);
311+ auto xla_sharding = xla::HloSharding::Tile ({
312+ {0 , 1 , 2 , 3 },
313+ {4 , 5 , 6 , 7 },
314+ })
315+ .ToProto ();
316+ torch_xla::OpSharding sharding (xla_sharding, std::nullopt );
317+ XLATensor::ShardingSpec tiled_2d (sharding, tensor_shape);
318+ xla_sharding =
319+ xla::HloSharding::Tile ({{{0 , 1 }, {2 , 3 }, {4 , 5 }, {6 , 7 }}}).ToProto ();
320+ sharding = torch_xla::OpSharding (xla_sharding, std::nullopt );
321+ XLATensor::ShardingSpec tiled_3d (sharding, tensor_shape);
322+ xla_sharding = xla::HloSharding::Replicate ().ToProto ();
323+ sharding = torch_xla::OpSharding (xla_sharding, std::nullopt );
324+ XLATensor::ShardingSpec replicated (sharding, tensor_shape);
305325 EXPECT_TRUE (ShardingUtil::EqualShardingSpecs (tiled_2d, tiled_2d));
306326 EXPECT_FALSE (ShardingUtil::EqualShardingSpecs (tiled_2d, tiled_3d));
307327 EXPECT_TRUE (ShardingUtil::EqualShardingSpecs (replicated, replicated));
@@ -322,12 +342,17 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
322342 std::vector<std::string> devices (3 );
323343 std::fill_n (devices.begin (), devices.size (),
324344 bridge::GetDefaultDevice ()->toString ());
345+ auto replicate_xla_sharding = xla::HloSharding::Replicate ().ToProto ();
346+ auto unknown_xla_sharding = xla::HloSharding::Unknown ().ToProto ();
347+ torch_xla::OpSharding replicate_sharding (replicate_xla_sharding,
348+ std::nullopt );
349+ torch_xla::OpSharding unknown_sharding (unknown_xla_sharding, std::nullopt );
325350 std::vector<XLATensor::ShardingSpecPtr> shardings = {
326351 nullptr ,
327- std::make_shared<XLATensor::ShardingSpec>(
328- xla::HloSharding::Replicate (). ToProto (), tensor_shape),
329- std::make_shared<XLATensor::ShardingSpec>(
330- xla::HloSharding::Unknown (). ToProto (), tensor_shape)};
352+ std::make_shared<XLATensor::ShardingSpec>(replicate_sharding,
353+ tensor_shape),
354+ std::make_shared<XLATensor::ShardingSpec>(unknown_sharding,
355+ tensor_shape)};
331356 std::vector<torch::lazy::BackendDataPtr> tensors_data =
332357 CreateTensorsData (tensors, shardings, devices);
333358
@@ -386,13 +411,21 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
386411 auto y = xla::Add (x, xla::ConstantR0<float >(&b, 3 ));
387412 xla::XlaComputation xla_computation =
388413 ConsumeValue (b.Build (/* remove_dynamic_dimensions=*/ false ));
414+
415+ std::vector<torch::lazy::BackendDataPtr> parameters_data;
416+ parameters_data.push_back (
417+ torch_xla::runtime::GetComputationClientOrDie ()->CreateDataPlaceholder (
418+ bridge::GetDefaultDevice ()->toString (), std::move (shape)));
419+
389420 std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
390421 instances.push_back ({std::move (xla_computation),
391422 bridge::GetDefaultDevice ()->toString (),
392423 {bridge::GetDefaultDevice ()->toString ()},
393424 &shape,
394425 /* should_wrap_parameter=*/ false ,
395- /* is_sharded=*/ true });
426+ /* is_sharded=*/ true ,
427+ /* allow_spmd_sharding_propagation_to_output=*/ true ,
428+ /* parameters_data=*/ parameters_data});
396429
397430 std::vector<
398431 std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
@@ -416,11 +449,12 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
416449 if (n_devices > 1 ) {
417450 // Tiled sharding requires multiple devices.
418451 EXPECT_TRUE (xla::protobuf_util::HaveSameSerialization (
419- tiled, sharding_specs[0 ]->sharding ));
452+ tiled, sharding_specs[0 ]->sharding . GetXlaOpSharding () ));
420453 } else {
421454 // Single device execution defaults to replication sharding.
422455 EXPECT_TRUE (xla::protobuf_util::HaveSameSerialization (
423- xla::HloSharding::Replicate ().ToProto (), sharding_specs[0 ]->sharding ));
456+ xla::HloSharding::Replicate ().ToProto (),
457+ sharding_specs[0 ]->sharding .GetXlaOpSharding ()));
424458 }
425459
426460 // Check if the placeholder is on a SPMD device (sharded) with no real values.
0 commit comments