@@ -51,7 +51,8 @@ TEST_F(XLAShardingTest, GetShardShape) {
51
51
{2 , 3 },
52
52
});
53
53
auto xla_sharding = xla::HloSharding::Tile (mesh).ToProto ();
54
- torch_xla::OpSharding sharding (xla_sharding, std::nullopt );
54
+ std::vector<int64_t > denormalized_tile_assignment = {0 , 1 , 2 , 3 };
55
+ torch_xla::OpSharding sharding (xla_sharding, denormalized_tile_assignment);
55
56
auto sharding_spec =
56
57
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
57
58
@@ -60,7 +61,7 @@ TEST_F(XLAShardingTest, GetShardShape) {
60
61
EXPECT_EQ (shard_shape, std::vector<int64_t >({4 , 4 }));
61
62
62
63
xla_sharding = xla::HloSharding::Replicate ().ToProto ();
63
- sharding = torch_xla::OpSharding (xla_sharding, std:: nullopt );
64
+ sharding = torch_xla::OpSharding (xla_sharding, denormalized_tile_assignment );
64
65
sharding_spec->sharding = sharding;
65
66
shard_shape = ShardingUtil::GetShardShape (sharding_spec);
66
67
// For replicated sharding, each dimension should be preserved
@@ -78,7 +79,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
78
79
{2 , 3 },
79
80
});
80
81
auto xla_sharding = xla::HloSharding::Tile (mesh).ToProto ();
81
- torch_xla::OpSharding sharding (xla_sharding, std::nullopt );
82
+ std::vector<int64_t > denormalized_tile_assignment = {0 , 1 , 2 , 3 };
83
+ torch_xla::OpSharding sharding (xla_sharding, denormalized_tile_assignment);
82
84
auto sharding_spec =
83
85
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
84
86
auto shard_shape = ShardingUtil::GetShardShape (sharding_spec);
@@ -108,7 +110,7 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
108
110
}
109
111
}
110
112
xla_sharding = xla::HloSharding::Replicate ().ToProto ();
111
- sharding = torch_xla::OpSharding (xla_sharding, std:: nullopt );
113
+ sharding = torch_xla::OpSharding (xla_sharding, denormalized_tile_assignment );
112
114
sharding_spec->sharding = sharding;
113
115
shard_shape = ShardingUtil::GetShardShape (sharding_spec);
114
116
replica_and_indices = ShardingUtil::GetShardReplicaAndIndicesForDevices (
@@ -126,6 +128,7 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
126
128
TEST_F (XLAShardingTest, ShardTensor) {
127
129
std::vector<std::string> devices = {" TPU:0" , " TPU:1" , " TPU:2" , " TPU:3" ,
128
130
" TPU:4" , " TPU:5" , " TPU:6" , " TPU:7" };
131
+ std::vector<int64_t > denormalized_tile_assignment = {0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 };
129
132
130
133
// 1D tiled
131
134
at::Tensor tensor = at::ones ({8 }, at::TensorOptions (at::kFloat ));
@@ -136,7 +139,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
136
139
CreateComputationShapeFromTensor (tensor, bridge::GetDefaultDevice ()),
137
140
devices.size ())
138
141
.ToProto ();
139
- torch_xla::OpSharding sharding (xla_sharding, std:: nullopt );
142
+ torch_xla::OpSharding sharding (xla_sharding, denormalized_tile_assignment );
140
143
auto sharding_spec =
141
144
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
142
145
auto shards = ShardingUtil::ShardTensor (tensor, sharding_spec, devices,
@@ -155,7 +158,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
155
158
{4 , 5 , 6 , 7 },
156
159
});
157
160
xla_sharding = xla::HloSharding::Tile (mesh).ToProto ();
158
- sharding = torch_xla::OpSharding (xla_sharding, std:: nullopt );
161
+ sharding = torch_xla::OpSharding (xla_sharding, denormalized_tile_assignment );
159
162
sharding_spec =
160
163
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
161
164
shards = ShardingUtil::ShardTensor (tensor, sharding_spec, devices,
@@ -168,7 +171,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
168
171
// size should be smaller in dim=1 because it's not evenly divisible.
169
172
xla::Array3D<int64_t > cube ({{{0 , 1 }, {2 , 3 }, {4 , 5 }, {6 , 7 }}});
170
173
xla_sharding = xla::HloSharding::Tile (cube).ToProto ();
171
- sharding = torch_xla::OpSharding (xla_sharding, std:: nullopt );
174
+ sharding = torch_xla::OpSharding (xla_sharding, denormalized_tile_assignment );
172
175
sharding_spec->sharding = sharding;
173
176
shards = ShardingUtil::ShardTensor (tensor, sharding_spec, devices,
174
177
/* padded=*/ false );
@@ -178,7 +181,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
178
181
179
182
// Replicated, all shards should be identical.
180
183
xla_sharding = xla::HloSharding::Replicate ().ToProto ();
181
- sharding = torch_xla::OpSharding (xla_sharding, std:: nullopt );
184
+ sharding = torch_xla::OpSharding (xla_sharding, denormalized_tile_assignment );
182
185
sharding_spec->sharding = sharding;
183
186
shards = ShardingUtil::ShardTensor (tensor, sharding_spec, devices,
184
187
/* padded=*/ false );
@@ -194,7 +197,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
194
197
CreateComputationShapeFromTensor (tensor, bridge::GetDefaultDevice ());
195
198
xla::Array4D<int64_t > tesseract ({{{{0 , 1 }, {2 , 3 }, {4 , 5 }, {6 , 7 }}}});
196
199
xla_sharding = xla::HloSharding::Tile (tesseract).ToProto ();
197
- sharding = torch_xla::OpSharding (xla_sharding, std:: nullopt );
200
+ sharding = torch_xla::OpSharding (xla_sharding, denormalized_tile_assignment );
198
201
sharding_spec =
199
202
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
200
203
shards = ShardingUtil::ShardTensor (tensor, sharding_spec, devices,
@@ -219,7 +222,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
219
222
xla::Array<int64_t > hypercube (std::vector<int64_t >{1 , 1 , 2 , 2 , 2 });
220
223
hypercube.FillIota (0 );
221
224
xla_sharding = xla::HloSharding::Tile (hypercube).ToProto ();
222
- sharding = torch_xla::OpSharding (xla_sharding, std:: nullopt );
225
+ sharding = torch_xla::OpSharding (xla_sharding, denormalized_tile_assignment );
223
226
sharding_spec =
224
227
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
225
228
shards = ShardingUtil::ShardTensor (tensor, sharding_spec, devices,
@@ -248,7 +251,8 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
248
251
{6 , 7 , 2 , 3 },
249
252
});
250
253
auto xla_sharding = xla::HloSharding::Tile (mesh).ToProto ();
251
- torch_xla::OpSharding sharding (xla_sharding, std::nullopt );
254
+ std::vector<int64_t > denormalized_tile_assignment = {4 , 5 , 0 , 1 , 6 , 7 , 2 , 3 };
255
+ torch_xla::OpSharding sharding (xla_sharding, denormalized_tile_assignment);
252
256
auto sharding_spec =
253
257
std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
254
258
// For devices at the start of the mesh, all shards should have the same
@@ -266,7 +270,8 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
266
270
{2 , 3 , 6 , 7 },
267
271
});
268
272
xla_sharding = xla::HloSharding::Tile (mesh).ToProto ();
269
- sharding = torch_xla::OpSharding (xla_sharding, std::nullopt );
273
+ denormalized_tile_assignment = {0 , 1 , 4 , 5 , 2 , 3 , 6 , 7 };
274
+ sharding = torch_xla::OpSharding (xla_sharding, denormalized_tile_assignment);
270
275
sharding_spec->sharding = sharding;
271
276
shards = ShardingUtil::ShardTensor (tensor, sharding_spec, devices,
272
277
/* padded=*/ false );
@@ -295,7 +300,8 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) {
295
300
});
296
301
297
302
auto xla_sharding = xla::HloSharding::Tile (mesh).ToProto ();
298
- torch_xla::OpSharding sharding (xla_sharding, std::nullopt );
303
+ std::vector<int64_t > denormalized_tile_assignment = {0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 };
304
+ torch_xla::OpSharding sharding (xla_sharding, denormalized_tile_assignment);
299
305
auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
300
306
sharding, global_shape, /* minibatch=*/ true );
301
307
auto shards = ShardingUtil::ShardTensor (minibatch_tensor, sharding_spec,
@@ -314,14 +320,15 @@ TEST_F(XLAShardingTest, EqualShardingSpecs) {
314
320
{4 , 5 , 6 , 7 },
315
321
})
316
322
.ToProto ();
317
- torch_xla::OpSharding sharding (xla_sharding, std::nullopt );
323
+ std::vector<int64_t > denormalized_tile_assignment = {0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 };
324
+ torch_xla::OpSharding sharding (xla_sharding, denormalized_tile_assignment);
318
325
XLATensor::ShardingSpec tiled_2d (sharding, tensor_shape);
319
326
xla_sharding =
320
327
xla::HloSharding::Tile ({{{0 , 1 }, {2 , 3 }, {4 , 5 }, {6 , 7 }}}).ToProto ();
321
- sharding = torch_xla::OpSharding (xla_sharding, std:: nullopt );
328
+ sharding = torch_xla::OpSharding (xla_sharding, denormalized_tile_assignment );
322
329
XLATensor::ShardingSpec tiled_3d (sharding, tensor_shape);
323
330
xla_sharding = xla::HloSharding::Replicate ().ToProto ();
324
- sharding = torch_xla::OpSharding (xla_sharding, std:: nullopt );
331
+ sharding = torch_xla::OpSharding (xla_sharding, denormalized_tile_assignment );
325
332
XLATensor::ShardingSpec replicated (sharding, tensor_shape);
326
333
EXPECT_TRUE (ShardingUtil::EqualShardingSpecs (tiled_2d, tiled_2d));
327
334
EXPECT_FALSE (ShardingUtil::EqualShardingSpecs (tiled_2d, tiled_3d));
0 commit comments