@@ -49,15 +49,18 @@ TEST_F(XLAShardingTest, GetShardShape) {
49
49
{0 , 1 },
50
50
{2 , 3 },
51
51
});
52
- auto sharding = xla::HloSharding::Tile (mesh).ToProto ();
52
+ auto xla_sharding = xla::HloSharding::Tile (mesh).ToProto ();
53
+ torch_xla::OpSharding sharding (xla_sharding, std::nullopt);
53
54
auto sharding_spec =
54
55
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
55
56
56
57
auto shard_shape = ShardingUtil::GetShardShape (sharding_spec);
57
58
// For tiled sharding, each dimension should be halved
58
59
EXPECT_EQ (shard_shape, std::vector<int64_t >({4 , 4 }));
59
60
60
- sharding_spec->sharding = xla::HloSharding::Replicate ().ToProto ();
61
+ xla_sharding = xla::HloSharding::Replicate ().ToProto ();
62
+ sharding = torch_xla::OpSharding (xla_sharding, std::nullopt);
63
+ sharding_spec->sharding = sharding;
61
64
shard_shape = ShardingUtil::GetShardShape (sharding_spec);
62
65
// For replicated sharding, each dimension should be preserved
63
66
EXPECT_EQ (shard_shape, std::vector<int64_t >({8 , 7 }));
@@ -73,7 +76,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
73
76
{0 , 1 },
74
77
{2 , 3 },
75
78
});
76
- auto sharding = xla::HloSharding::Tile (mesh).ToProto ();
79
+ auto xla_sharding = xla::HloSharding::Tile (mesh).ToProto ();
80
+ torch_xla::OpSharding sharding (xla_sharding, std::nullopt);
77
81
auto sharding_spec =
78
82
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
79
83
auto shard_shape = ShardingUtil::GetShardShape (sharding_spec);
@@ -102,7 +106,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
102
106
EXPECT_EQ (slice.step (), 1 );
103
107
}
104
108
}
105
- sharding = xla::HloSharding::Replicate ().ToProto ();
109
+ xla_sharding = xla::HloSharding::Replicate ().ToProto ();
110
+ sharding = torch_xla::OpSharding (xla_sharding, std::nullopt);
106
111
sharding_spec->sharding = sharding;
107
112
shard_shape = ShardingUtil::GetShardShape (sharding_spec);
108
113
replica_and_indices = ShardingUtil::GetShardReplicaAndIndicesForDevices (
@@ -125,11 +130,12 @@ TEST_F(XLAShardingTest, ShardTensor) {
125
130
at::Tensor tensor = at::ones ({8 }, at::TensorOptions (at::kFloat ));
126
131
xla::Shape tensor_shape =
127
132
CreateComputationShapeFromTensor (tensor, bridge::GetDefaultDevice ());
128
- xla::OpSharding sharding =
133
+ xla::OpSharding xla_sharding =
129
134
xla::HloSharding::Tile1D (
130
135
CreateComputationShapeFromTensor (tensor, bridge::GetDefaultDevice ()),
131
136
devices.size ())
132
137
.ToProto ();
138
+ torch_xla::OpSharding sharding (xla_sharding, std::nullopt);
133
139
auto sharding_spec =
134
140
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
135
141
auto shards = ShardingUtil::ShardTensor (tensor, sharding_spec, devices,
@@ -147,7 +153,8 @@ TEST_F(XLAShardingTest, ShardTensor) {
147
153
{0 , 1 , 2 , 3 },
148
154
{4 , 5 , 6 , 7 },
149
155
});
150
- sharding = xla::HloSharding::Tile (mesh).ToProto ();
156
+ xla_sharding = xla::HloSharding::Tile (mesh).ToProto ();
157
+ sharding = torch_xla::OpSharding (xla_sharding, std::nullopt);
151
158
sharding_spec =
152
159
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
153
160
shards = ShardingUtil::ShardTensor (tensor, sharding_spec, devices,
@@ -159,15 +166,19 @@ TEST_F(XLAShardingTest, ShardTensor) {
159
166
// 3D tiled, the first dim is replicated and the last halved. The last shard
160
167
// size should be smaller in dim=1 because it's not evenly divisible.
161
168
xla::Array3D<int64_t > cube ({{{0 , 1 }, {2 , 3 }, {4 , 5 }, {6 , 7 }}});
162
- sharding_spec->sharding = xla::HloSharding::Tile (cube).ToProto ();
169
+ xla_sharding = xla::HloSharding::Tile (cube).ToProto ();
170
+ sharding = torch_xla::OpSharding (xla_sharding, std::nullopt);
171
+ sharding_spec->sharding = sharding;
163
172
shards = ShardingUtil::ShardTensor (tensor, sharding_spec, devices,
164
173
/* padded=*/ false );
165
174
EXPECT_EQ (shards.size (), 8 );
166
175
EXPECT_EQ (shards[0 ].sizes (), c10::ArrayRef<long >({8 , 2 , 2 }));
167
176
EXPECT_EQ (shards[7 ].sizes (), c10::ArrayRef<long >({8 , 1 , 2 }));
168
177
169
178
// Replicated, all shards should be identical.
170
- sharding_spec->sharding = xla::HloSharding::Replicate ().ToProto ();
179
+ xla_sharding = xla::HloSharding::Replicate ().ToProto ();
180
+ sharding = torch_xla::OpSharding (xla_sharding, std::nullopt);
181
+ sharding_spec->sharding = sharding;
171
182
shards = ShardingUtil::ShardTensor (tensor, sharding_spec, devices,
172
183
/* padded=*/ false );
173
184
EXPECT_EQ (shards.size (), 8 );
@@ -181,7 +192,8 @@ TEST_F(XLAShardingTest, ShardTensor) {
181
192
tensor_shape =
182
193
CreateComputationShapeFromTensor (tensor, bridge::GetDefaultDevice ());
183
194
xla::Array4D<int64_t > tesseract ({{{{0 , 1 }, {2 , 3 }, {4 , 5 }, {6 , 7 }}}});
184
- sharding = xla::HloSharding::Tile (tesseract).ToProto ();
195
+ xla_sharding = xla::HloSharding::Tile (tesseract).ToProto ();
196
+ sharding = torch_xla::OpSharding (xla_sharding, std::nullopt);
185
197
sharding_spec =
186
198
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
187
199
shards = ShardingUtil::ShardTensor (tensor, sharding_spec, devices,
@@ -205,7 +217,8 @@ TEST_F(XLAShardingTest, ShardTensor) {
205
217
CreateComputationShapeFromTensor (tensor, bridge::GetDefaultDevice ());
206
218
xla::Array<int64_t > hypercube (std::vector<int64_t >{1 , 1 , 2 , 2 , 2 });
207
219
hypercube.FillIota (0 );
208
- sharding = xla::HloSharding::Tile (hypercube).ToProto ();
220
+ xla_sharding = xla::HloSharding::Tile (hypercube).ToProto ();
221
+ sharding = torch_xla::OpSharding (xla_sharding, std::nullopt);
209
222
sharding_spec =
210
223
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
211
224
shards = ShardingUtil::ShardTensor (tensor, sharding_spec, devices,
@@ -233,7 +246,8 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
233
246
{4 , 5 , 0 , 1 },
234
247
{6 , 7 , 2 , 3 },
235
248
});
236
- auto sharding = xla::HloSharding::Tile (mesh).ToProto ();
249
+ auto xla_sharding = xla::HloSharding::Tile (mesh).ToProto ();
250
+ torch_xla::OpSharding sharding (xla_sharding, std::nullopt);
237
251
auto sharding_spec =
238
252
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
239
253
// For devices at the start of the mesh, all shards should have the same
@@ -250,7 +264,9 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
250
264
{0 , 1 , 4 , 5 },
251
265
{2 , 3 , 6 , 7 },
252
266
});
253
- sharding_spec->sharding = xla::HloSharding::Tile (mesh).ToProto ();
267
+ xla_sharding = xla::HloSharding::Tile (mesh).ToProto ();
268
+ sharding = torch_xla::OpSharding (xla_sharding, std::nullopt);
269
+ sharding_spec->sharding = sharding;
254
270
shards = ShardingUtil::ShardTensor (tensor, sharding_spec, devices,
255
271
/* padded=*/ false );
256
272
EXPECT_EQ (shards.size (), 4 );
@@ -277,7 +293,8 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) {
277
293
{{7 }},
278
294
});
279
295
280
- auto sharding = xla::HloSharding::Tile (mesh).ToProto ();
296
+ auto xla_sharding = xla::HloSharding::Tile (mesh).ToProto ();
297
+ torch_xla::OpSharding sharding (xla_sharding, std::nullopt);
281
298
auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
282
299
sharding, global_shape, /* minibatch=*/ true );
283
300
auto shards = ShardingUtil::ShardTensor (minibatch_tensor, sharding_spec,
@@ -291,17 +308,20 @@ TEST_F(XLAShardingTest, EqualShardingSpecs) {
291
308
auto tensor = at::ones ({8 , 7 }, at::TensorOptions (at::kFloat ));
292
309
xla::Shape tensor_shape =
293
310
CreateComputationShapeFromTensor (tensor, bridge::GetDefaultDevice ());
294
- XLATensor::ShardingSpec tiled_2d (xla::HloSharding::Tile ({
295
- {0 , 1 , 2 , 3 },
296
- {4 , 5 , 6 , 7 },
297
- })
298
- .ToProto (),
299
- tensor_shape);
300
- XLATensor::ShardingSpec tiled_3d (
301
- xla::HloSharding::Tile ({{{0 , 1 }, {2 , 3 }, {4 , 5 }, {6 , 7 }}}).ToProto (),
302
- tensor_shape);
303
- XLATensor::ShardingSpec replicated (xla::HloSharding::Replicate ().ToProto (),
304
- tensor_shape);
311
+ auto xla_sharding = xla::HloSharding::Tile ({
312
+ {0 , 1 , 2 , 3 },
313
+ {4 , 5 , 6 , 7 },
314
+ })
315
+ .ToProto ();
316
+ torch_xla::OpSharding sharding (xla_sharding, std::nullopt);
317
+ XLATensor::ShardingSpec tiled_2d (sharding, tensor_shape);
318
+ xla_sharding =
319
+ xla::HloSharding::Tile ({{{0 , 1 }, {2 , 3 }, {4 , 5 }, {6 , 7 }}}).ToProto ();
320
+ sharding = torch_xla::OpSharding (xla_sharding, std::nullopt);
321
+ XLATensor::ShardingSpec tiled_3d (sharding, tensor_shape);
322
+ xla_sharding = xla::HloSharding::Replicate ().ToProto ();
323
+ sharding = torch_xla::OpSharding (xla_sharding, std::nullopt);
324
+ XLATensor::ShardingSpec replicated (sharding, tensor_shape);
305
325
EXPECT_TRUE (ShardingUtil::EqualShardingSpecs (tiled_2d, tiled_2d));
306
326
EXPECT_FALSE (ShardingUtil::EqualShardingSpecs (tiled_2d, tiled_3d));
307
327
EXPECT_TRUE (ShardingUtil::EqualShardingSpecs (replicated, replicated));
@@ -322,12 +342,17 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
322
342
std::vector<std::string> devices (3 );
323
343
std::fill_n (devices.begin (), devices.size (),
324
344
bridge::GetDefaultDevice ()->toString ());
345
+ auto replicate_xla_sharding = xla::HloSharding::Replicate ().ToProto ();
346
+ auto unknown_xla_sharding = xla::HloSharding::Unknown ().ToProto ();
347
+ torch_xla::OpSharding replicate_sharding (replicate_xla_sharding,
348
+ std::nullopt);
349
+ torch_xla::OpSharding unknown_sharding (unknown_xla_sharding, std::nullopt);
325
350
std::vector<XLATensor::ShardingSpecPtr> shardings = {
326
351
nullptr ,
327
- std::make_shared<XLATensor::ShardingSpec>(
328
- xla::HloSharding::Replicate (). ToProto (), tensor_shape),
329
- std::make_shared<XLATensor::ShardingSpec>(
330
- xla::HloSharding::Unknown (). ToProto (), tensor_shape)};
352
+ std::make_shared<XLATensor::ShardingSpec>(replicate_sharding,
353
+ tensor_shape),
354
+ std::make_shared<XLATensor::ShardingSpec>(unknown_sharding,
355
+ tensor_shape)};
331
356
std::vector<torch::lazy::BackendDataPtr> tensors_data =
332
357
CreateTensorsData (tensors, shardings, devices);
333
358
@@ -386,13 +411,21 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
386
411
auto y = xla::Add (x, xla::ConstantR0<float >(&b, 3 ));
387
412
xla::XlaComputation xla_computation =
388
413
ConsumeValue (b.Build (/* remove_dynamic_dimensions=*/ false ));
414
+
415
+ std::vector<torch::lazy::BackendDataPtr> parameters_data;
416
+ parameters_data.push_back (
417
+ torch_xla::runtime::GetComputationClientOrDie ()->CreateDataPlaceholder (
418
+ bridge::GetDefaultDevice ()->toString (), std::move (shape)));
419
+
389
420
std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
390
421
instances.push_back ({std::move (xla_computation),
391
422
bridge::GetDefaultDevice ()->toString (),
392
423
{bridge::GetDefaultDevice ()->toString ()},
393
424
&shape,
394
425
/* should_wrap_parameter=*/ false ,
395
- /* is_sharded=*/ true });
426
+ /* is_sharded=*/ true ,
427
+ /* allow_spmd_sharding_propagation_to_output=*/ true ,
428
+ /* parameters_data=*/ parameters_data});
396
429
397
430
std::vector<
398
431
std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
@@ -416,11 +449,12 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
416
449
if (n_devices > 1 ) {
417
450
// Tiled sharding requires multiple devices.
418
451
EXPECT_TRUE (xla::protobuf_util::HaveSameSerialization (
419
- tiled, sharding_specs[0 ]->sharding ));
452
+ tiled, sharding_specs[0 ]->sharding . GetXlaOpSharding () ));
420
453
} else {
421
454
// Sincle device execution defaults to replication sharding.
422
455
EXPECT_TRUE (xla::protobuf_util::HaveSameSerialization (
423
- xla::HloSharding::Replicate ().ToProto (), sharding_specs[0 ]->sharding ));
456
+ xla::HloSharding::Replicate ().ToProto (),
457
+ sharding_specs[0 ]->sharding .GetXlaOpSharding ()));
424
458
}
425
459
426
460
// Check if the placeholder is on a SPMD device (sharded) with no real values.
0 commit comments