@@ -42,6 +42,34 @@ class XLAShardingTest : public AtenXlaTensorTestBase {
   }
 };
 
+TEST_F(XLAShardingTest, NormalizeTileAssignment) {
+  // Test with an empty tile assignment
+  std::vector<int64_t> empty_tile_assignment = {};
+  auto normalized =
+      ShardingUtil::NormalizeTileAssignment(empty_tile_assignment);
+  EXPECT_TRUE(normalized.empty());
+
+  // Test with positive values
+  std::vector<int64_t> positive_tile_assignment = {3, 1, 4, 2};
+  normalized = ShardingUtil::NormalizeTileAssignment(positive_tile_assignment);
+  EXPECT_EQ(normalized, std::vector<int64_t>({2, 0, 3, 1}));
+
+  // Test with all identical values
+  std::vector<int64_t> identical_tile_assignment = {5, 5, 5, 5};
+  normalized = ShardingUtil::NormalizeTileAssignment(identical_tile_assignment);
+  EXPECT_EQ(normalized, std::vector<int64_t>({0, 0, 0, 0}));
+
+  // Test with negative values
+  std::vector<int64_t> negative_tile_assignment = {-3, -1, -4, -2};
+  EXPECT_THROW(ShardingUtil::NormalizeTileAssignment(negative_tile_assignment),
+               std::runtime_error);
+
+  // Test with mixed positive and negative values
+  std::vector<int64_t> mixed_tile_assignment = {3, -1, 4, 2};
+  EXPECT_THROW(ShardingUtil::NormalizeTileAssignment(mixed_tile_assignment),
+               std::runtime_error);
+}
+
 TEST_F(XLAShardingTest, GetShardShape) {
   auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
   xla::Shape tensor_shape =
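Note on the expected behavior: the assertions in the new test are consistent with a normalization that shifts every device ID down by the minimum ID and rejects negative inputs. Below is a minimal sketch under that reading; the helper name and body are illustrative only, not the actual `ShardingUtil::NormalizeTileAssignment` implementation.

```cpp
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Illustrative stand-in, inferred only from the test expectations above:
// reject negative device IDs, then shift so the smallest ID becomes 0.
std::vector<int64_t> NormalizeTileAssignmentSketch(
    const std::vector<int64_t>& tile_assignment) {
  if (tile_assignment.empty()) return {};
  if (std::any_of(tile_assignment.begin(), tile_assignment.end(),
                  [](int64_t id) { return id < 0; })) {
    throw std::runtime_error("tile assignment contains negative device IDs");
  }
  int64_t min_id =
      *std::min_element(tile_assignment.begin(), tile_assignment.end());
  std::vector<int64_t> normalized;
  normalized.reserve(tile_assignment.size());
  for (int64_t id : tile_assignment) {
    normalized.push_back(id - min_id);
  }
  // {3, 1, 4, 2} -> {2, 0, 3, 1}; {5, 5, 5, 5} -> {0, 0, 0, 0}
  return normalized;
}
```

Both the `{3, 1, 4, 2}` and the all-identical `{5, 5, 5, 5}` cases are consistent with this subtract-the-minimum rule.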
@@ -51,7 +79,8 @@ TEST_F(XLAShardingTest, GetShardShape) {
       {2, 3},
   });
   auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3};
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
@@ -60,7 +89,7 @@ TEST_F(XLAShardingTest, GetShardShape) {
   EXPECT_EQ(shard_shape, std::vector<int64_t>({4, 4}));
 
   xla_sharding = xla::HloSharding::Replicate().ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec->sharding = sharding;
   shard_shape = ShardingUtil::GetShardShape(sharding_spec);
   // For replicated sharding, each dimension should be preserved
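The `{4, 4}` expectation above follows from ceiling division: an 8x7 tensor tiled over a 2x2 mesh yields shards of ceil(8/2) x ceil(7/2) = 4x4. A small illustrative check, with hypothetical helper names (not torch_xla API):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative helpers: for tiled sharding, the shard extent in each
// dimension is the ceiling of tensor_dim / mesh_dim.
int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

std::vector<int64_t> ExpectedShardShape(const std::vector<int64_t>& dims,
                                        const std::vector<int64_t>& mesh) {
  std::vector<int64_t> shard_shape;
  for (size_t i = 0; i < dims.size(); ++i) {
    shard_shape.push_back(CeilDiv(dims[i], mesh[i]));
  }
  return shard_shape;  // dims {8, 7}, mesh {2, 2} -> {4, 4}
}
```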
@@ -78,7 +107,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
       {2, 3},
   });
   auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3};
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
@@ -108,7 +138,7 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
     }
   }
   xla_sharding = xla::HloSharding::Replicate().ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec->sharding = sharding;
   shard_shape = ShardingUtil::GetShardShape(sharding_spec);
   replica_and_indices = ShardingUtil::GetShardReplicaAndIndicesForDevices(
@@ -126,6 +156,7 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
 TEST_F(XLAShardingTest, ShardTensor) {
   std::vector<std::string> devices = {"TPU:0", "TPU:1", "TPU:2", "TPU:3",
                                       "TPU:4", "TPU:5", "TPU:6", "TPU:7"};
+  std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7};
 
   // 1D tiled
   at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
@@ -136,7 +167,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
           CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()),
           devices.size())
           .ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -155,7 +186,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
       {4, 5, 6, 7},
   });
   xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -168,7 +199,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
   // size should be smaller in dim=1 because it's not evenly divisible.
   xla::Array3D<int64_t> cube({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}});
   xla_sharding = xla::HloSharding::Tile(cube).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec->sharding = sharding;
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
                                      /*padded=*/false);
@@ -178,7 +209,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
 
   // Replicated, all shards should be identical.
   xla_sharding = xla::HloSharding::Replicate().ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec->sharding = sharding;
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
                                      /*padded=*/false);
@@ -194,7 +225,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
       CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
   xla::Array4D<int64_t> tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}});
   xla_sharding = xla::HloSharding::Tile(tesseract).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -219,7 +250,7 @@ TEST_F(XLAShardingTest, ShardTensor) {
   xla::Array<int64_t> hypercube(std::vector<int64_t>{1, 1, 2, 2, 2});
   hypercube.FillIota(0);
   xla_sharding = xla::HloSharding::Tile(hypercube).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -248,7 +279,8 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
       {6, 7, 2, 3},
   });
   auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  std::vector<int64_t> denormalized_tile_assignment = {4, 5, 0, 1, 6, 7, 2, 3};
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   // For devices at the start of the mesh, all shards should have the same
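Each denormalized tile assignment in this test mirrors the device mesh flattened in row-major order (e.g., the mesh ending in row `{6, 7, 2, 3}` above flattens to `{4, 5, 0, 1, 6, 7, 2, 3}`). A sketch of that relationship, using a hypothetical helper that is not part of `ShardingUtil`:

```cpp
#include <cstdint>
#include <vector>

// Illustrative helper: flatten a device mesh in row-major order to obtain
// the denormalized tile assignment passed to torch_xla::OpSharding above.
std::vector<int64_t> FlattenMesh(
    const std::vector<std::vector<int64_t>>& mesh) {
  std::vector<int64_t> flat;
  for (const auto& row : mesh) {
    flat.insert(flat.end(), row.begin(), row.end());
  }
  return flat;  // {{4, 5, 0, 1}, {6, 7, 2, 3}} -> {4, 5, 0, 1, 6, 7, 2, 3}
}
```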
@@ -266,7 +298,8 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
       {2, 3, 6, 7},
   });
   xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  denormalized_tile_assignment = {0, 1, 4, 5, 2, 3, 6, 7};
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   sharding_spec->sharding = sharding;
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
                                      /*padded=*/false);
@@ -295,7 +328,8 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) {
   });
 
   auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7};
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
       sharding, global_shape, /*minibatch=*/true);
   auto shards = ShardingUtil::ShardTensor(minibatch_tensor, sharding_spec,
@@ -314,14 +348,15 @@ TEST_F(XLAShardingTest, EqualShardingSpecs) {
                        {4, 5, 6, 7},
                    })
                        .ToProto();
-  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  std::vector<int64_t> denormalized_tile_assignment = {0, 1, 2, 3, 4, 5, 6, 7};
+  torch_xla::OpSharding sharding(xla_sharding, denormalized_tile_assignment);
   XLATensor::ShardingSpec tiled_2d(sharding, tensor_shape);
   xla_sharding =
       xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   XLATensor::ShardingSpec tiled_3d(sharding, tensor_shape);
   xla_sharding = xla::HloSharding::Replicate().ToProto();
-  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding = torch_xla::OpSharding(xla_sharding, denormalized_tile_assignment);
   XLATensor::ShardingSpec replicated(sharding, tensor_shape);
   EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_2d));
   EXPECT_FALSE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_3d));