Skip to content

Commit 30165d5

Browse files
jeffkbkim and facebook-github-bot
authored and committed
[torchrec][LocalShardsWrapper] Implement tensor padding for local shards wrapper (pytorch#163183)
Summary: X-link: pytorch/torchrec#3382 This diff implements the constant padding functionality (aten.constant_pad_nd.default) for `LocalShardsWrapper`. The method applies constant padding to the local shards based on the provided padding specification. Depending on the sharding type (RW, CW), the padding on [left, right, top, bottom] directions will be either applied to the first/last shard, or all local shards. New unit tests cover: - 1D (RW) top/bottom paddings - 2D (CW) left, right, top, bottom paddings - empty shards, number of dimensions > 2 Test Plan: ``` buck2 test fbcode//caffe2/test/distributed/tensor:shards_wrapper 2025-09-18T15:32:46.525914Z WARN buck2_interpreter_for_build::interpreter::functions::warning: ptxas 12.8 is not available on platform platform010-aarch64-compat 2025-09-18T15:32:46.525953Z WARN buck2_interpreter_for_build::interpreter::functions::warning: ptxas 12.8 is not available on platform platform010-compat 2025-09-18T15:32:46.525959Z WARN buck2_interpreter_for_build::interpreter::functions::warning: ptxas 12.8 is not available on platform platform010-libcxx Buck UI: https://www.internalfb.com/buck2/ffb34bcb-1555-4fa3-89c6-9c22d078606a Test UI: https://www.internalfb.com/intern/testinfra/testrun/12384899087608299 Network: Up: 159MiB Down: 13GiB (reSessionID-f734bd3c-19ca-44c9-919f-57203ac00be8) Loading targets. Remaining 0/5110 104336 dirs read, 1265395 targets declared Analyzing targets. Remaining 0/80346 3349033 actions, 4142832 artifacts declared Executing actions. Remaining 0/521855 149:06:17.8s exec time total Command: test. Finished 14 local, 397 remote, 199840 cache (99% hit) 148:27:40.9s exec time cached (99%) Time elapsed: 8:55.5s Tests finished: Pass 14. Fail 0. Fatal 0. Skip 0. Build failure 0 ``` Differential Revision: D82663766
1 parent 4967ad8 commit 30165d5

File tree

2 files changed

+574
-0
lines changed

2 files changed

+574
-0
lines changed
Lines changed: 373 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,373 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates
2+
# Owner(s): ["oncall: distributed"]
3+
import torch
4+
from torch.distributed.tensor._shards_wrapper import LocalShardsWrapper
5+
from torch.testing._internal.common_utils import run_tests, TestCase
6+
7+
8+
class LocalShardsWrapperPaddingTest(TestCase):
    """Test cases for constant padding functionality in LocalShardsWrapper.

    Exercises ``torch.ops.aten.constant_pad_nd.default`` dispatched on
    ``LocalShardsWrapper`` for:

    - 1D row-wise (RW) sharding: top/bottom padding (first/last shard only)
    - 2D column-wise (CW) sharding: left/right padding (first/last shard
      only) and top/bottom padding (applied to every shard)
    - error cases: empty shard list, malformed pad specs, >2D tensors
    - offset and storage-metadata bookkeeping after padding

    Pad specs follow the ``torch.nn.functional.pad`` convention: pairs of
    (before, after) amounts starting from the innermost dimension. For 2D
    that reads ``[left, right, top, bottom]``; for 1D, ``[top, bottom]``.
    """

    def test_empty_shards_padding(self) -> None:
        """Test padding with empty shards list."""
        lsw = LocalShardsWrapper([], [])
        pad_spec = [1, 2, 3, 4]
        pad_value = 5.0

        # Padding a wrapper with no shards has no defined result; only the
        # broad Exception base is asserted since the concrete type is not
        # part of the contract under test.
        self.assertRaises(
            Exception,
            torch.ops.aten.constant_pad_nd.default,
            lsw,
            pad_spec,
            pad_value,
        )

    def test_invalid_1d_rw_padding(self) -> None:
        """Test invalid padding on 1D tensor throws ValueError."""
        shard1 = torch.tensor([1.0, 2.0])
        shard2 = torch.tensor([3.0, 4.0])
        # NOTE(review): a single 2-tuple offset for two 1D shards looks
        # inconsistent with the [(0,), (2,)] style used by the other 1D
        # tests in this file; the assertion below only exercises pad-spec
        # validation, but confirm these offsets are intentional.
        lsw = LocalShardsWrapper([shard1, shard2], [(2, 0)])
        pad_spec = [1]  # invalid padding spec
        pad_value = 5.0

        self.assertRaises(
            ValueError,
            torch.ops.aten.constant_pad_nd.default,
            lsw,
            pad_spec,
            pad_value,
        )

    def test_invalid_2d_cw_padding(self) -> None:
        """Test invalid padding on 2D tensor throws ValueError."""
        shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
        shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
        lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)])
        pad_spec = [1, 2, 3]  # invalid padding spec
        pad_value = 5.0

        self.assertRaises(
            ValueError,
            torch.ops.aten.constant_pad_nd.default,
            lsw,
            pad_spec,
            pad_value,
        )

        # A single-element spec is rejected for 2D inputs as well.
        pad_spec = [1]

        self.assertRaises(
            ValueError,
            torch.ops.aten.constant_pad_nd.default,
            lsw,
            pad_spec,
            pad_value,
        )

    def test_single_shard_padding_2d(self) -> None:
        """Test padding with single 2D shard."""
        tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        lsw = LocalShardsWrapper([tensor], [(0, 0)])
        pad_spec = [1, 2, 3, 4]  # [left=1, right=2, top=3, bottom=4]
        pad_value = 0.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        # With a single shard the result must match padding the plain tensor.
        expected = torch.nn.functional.pad(
            tensor, pad_spec, mode="constant", value=pad_value
        )
        self.assertIsInstance(result, LocalShardsWrapper)
        self.assertEqual(len(result.local_shards()), 1)
        torch.testing.assert_close(result.local_shards()[0], expected)

    def test_single_shard_padding_1d(self) -> None:
        """Test padding with single 1D shard."""
        tensor = torch.tensor([1.0, 2.0, 3.0])
        lsw = LocalShardsWrapper([tensor], [(0,)])
        pad_spec = [2, 1]  # [top=2, bottom=1]
        pad_value = -1.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertIsInstance(result, LocalShardsWrapper)
        self.assertEqual(len(result.local_shards()), 1)

        # Single shard: must agree with padding the underlying tensor.
        expected = torch.nn.functional.pad(
            tensor, pad_spec, mode="constant", value=pad_value
        )
        torch.testing.assert_close(result.local_shards()[0], expected)

    def test_2d_cw_sharding_top_padding(self) -> None:
        """Test column-wise sharding with top padding (affects all shards)."""
        shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
        shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
        lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)])
        pad_spec = [0, 0, 2, 0]  # top=2
        pad_value = 0.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        # Both shards should have 2 rows added at top
        expected_shape = (4, 2)
        self.assertEqual(result.local_shards()[0].shape, expected_shape)
        self.assertEqual(result.local_shards()[1].shape, expected_shape)

        # New top rows hold the pad value; original rows are shifted down.
        torch.testing.assert_close(result.local_shards()[0][:2], torch.zeros(2, 2))
        torch.testing.assert_close(result.local_shards()[1][:2], torch.zeros(2, 2))
        torch.testing.assert_close(result.local_shards()[0][2:], shard1)
        torch.testing.assert_close(result.local_shards()[1][2:], shard2)

    def test_2d_cw_sharding_bottom_padding(self) -> None:
        """Test column-wise sharding with bottom padding (affects all shards)."""
        shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
        shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
        lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)])
        pad_spec = [0, 0, 0, 1]  # bottom=1
        pad_value = -1.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        expected_shape = (3, 2)
        self.assertEqual(result.local_shards()[0].shape, expected_shape)
        self.assertEqual(result.local_shards()[1].shape, expected_shape)

        # Original rows unchanged; one new bottom row of pad values per shard.
        torch.testing.assert_close(result.local_shards()[0][:2], shard1)
        torch.testing.assert_close(result.local_shards()[1][:2], shard2)
        torch.testing.assert_close(
            result.local_shards()[0][2:], torch.full((1, 2), -1.0)
        )
        torch.testing.assert_close(
            result.local_shards()[1][2:], torch.full((1, 2), -1.0)
        )

    def test_2d_cw_sharding_left_padding(self) -> None:
        """Test column-wise sharding with left padding (affects first shard only)."""
        shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
        shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
        lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)])
        pad_spec = [3, 0, 0, 0]  # left=3
        pad_value = 2.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        self.assertEqual(len(result.local_shards()), 2)
        # First shard should have 3 columns added at left
        self.assertEqual(result.local_shards()[0].shape, (2, 5))
        self.assertEqual(result.local_shards()[1].shape, (2, 2))

        # Check content
        torch.testing.assert_close(
            result.local_shards()[0][:, :3], torch.full((2, 3), 2.0)
        )
        torch.testing.assert_close(result.local_shards()[0][:, 3:], shard1)
        torch.testing.assert_close(result.local_shards()[1], shard2)

    def test_2d_cw_sharding_right_padding(self) -> None:
        """Test column-wise sharding with right padding (affects last shard only)."""
        shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
        shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
        lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)])
        pad_spec = [0, 2, 0, 0]  # right=2
        pad_value = 3.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        # Second shard should have 2 columns added at right
        expected_shard_1 = torch.tensor(
            [[1.0, 2.0],
             [5.0, 6.0]]
        )
        expected_shard_2 = torch.tensor(
            [[3.0, 4.0, 3.0, 3.0],
             [7.0, 8.0, 3.0, 3.0]]
        )
        self.assertEqual(len(result.local_shards()), 2)
        torch.testing.assert_close(result.local_shards()[0], expected_shard_1)
        torch.testing.assert_close(result.local_shards()[1], expected_shard_2)

        # 1D padding on 2D pads the last dimension
        pad_spec_2 = [0, 2]  # right=2
        result_2 = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec_2, pad_value)
        torch.testing.assert_close(result_2.local_shards()[0], expected_shard_1)
        torch.testing.assert_close(result_2.local_shards()[1], expected_shard_2)

    def test_2d_cw_sharding_mixed_padding(self) -> None:
        """Test column-wise sharding with mixed padding directions."""
        shard1 = torch.tensor([[1.0, 2.0], [5.0, 6.0]])
        shard2 = torch.tensor([[3.0, 4.0], [7.0, 8.0]])
        lsw = LocalShardsWrapper([shard1, shard2], [(0, 0), (0, 2)])
        pad_spec = [1, 2, 1, 1]  # [left=1, right=2, top=1, bottom=1]
        pad_value = 0.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        # left=1 goes to the first shard only; top/bottom go to both.
        expected_shard_1 = torch.tensor(
            [[0.0, 0.0, 0.0],
             [0.0, 1.0, 2.0],
             [0.0, 5.0, 6.0],
             [0.0, 0.0, 0.0]],
        )

        # right=2 goes to the last shard only; top/bottom go to both.
        expected_shard_2 = torch.tensor(
            [[0.0, 0.0, 0.0, 0.0],
             [3.0, 4.0, 0.0, 0.0],
             [7.0, 8.0, 0.0, 0.0],
             [0.0, 0.0, 0.0, 0.0]],
        )

        self.assertEqual(len(result.local_shards()), 2)
        torch.testing.assert_close(result.local_shards()[0], expected_shard_1)
        torch.testing.assert_close(result.local_shards()[1], expected_shard_2)

    def test_1d_rw_sharding_top_padding(self) -> None:
        """Test row-wise sharding with top padding (affects first shard only)."""
        shard1 = torch.tensor([1.0, 2.0, 3.0])
        shard2 = torch.tensor([4.0, 5.0, 6.0])
        lsw = LocalShardsWrapper([shard1, shard2], [(0,), (3,)])
        pad_spec = [2, 0]  # top=2
        pad_value = 0.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        expected_shard_1 = torch.tensor(
            [0.0, 0.0, 1.0, 2.0, 3.0],
        )
        expected_shard_2 = torch.tensor(
            [4.0, 5.0, 6.0],
        )

        self.assertEqual(len(result.local_shards()), 2)
        torch.testing.assert_close(result.local_shards()[0], expected_shard_1)
        torch.testing.assert_close(result.local_shards()[1], expected_shard_2)

    def test_1d_rw_sharding_bottom_padding(self) -> None:
        """Test row-wise sharding with bottom padding (affects last shard only)."""
        shard1 = torch.tensor([1.0, 2.0, 3.0])
        shard2 = torch.tensor([4.0, 5.0, 6.0])
        lsw = LocalShardsWrapper([shard1, shard2], [(0,), (3,)])
        pad_spec = [0, 1]  # bottom=1
        pad_value = -1.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        expected_shard_1 = torch.tensor(
            [1.0, 2.0, 3.0],
        )
        expected_shard_2 = torch.tensor(
            [4.0, 5.0, 6.0, -1.0],
        )

        self.assertEqual(len(result.local_shards()), 2)
        torch.testing.assert_close(result.local_shards()[0], expected_shard_1)
        torch.testing.assert_close(result.local_shards()[1], expected_shard_2)

    def test_1d_rw_sharding_mixed_padding(self) -> None:
        """Test row-wise sharding with mixed top/bottom padding."""
        shard1 = torch.tensor([1.0, 2.0])
        shard2 = torch.tensor([3.0, 4.0])
        lsw = LocalShardsWrapper([shard1, shard2], [(0,), (2,)])
        pad_spec = [1, 2]  # [top=1, bottom=2]
        pad_value = 5.0

        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        # top=1 lands on the first shard, bottom=2 on the last shard.
        expected_shard_1 = torch.tensor(
            [5.0, 1.0, 2.0],
        )
        expected_shard_2 = torch.tensor(
            [3.0, 4.0, 5.0, 5.0],
        )

        self.assertEqual(len(result.local_shards()), 2)
        torch.testing.assert_close(result.local_shards()[0], expected_shard_1)
        torch.testing.assert_close(result.local_shards()[1], expected_shard_2)

    def test_higher_dimensions_not_implemented(self) -> None:
        """Test that higher dimensional tensors raise NotImplementedError."""
        tensor_3d = torch.rand(2, 3, 4)  # 3D tensor
        lsw = LocalShardsWrapper([tensor_3d, tensor_3d], [(0, 0, 0), (2, 0, 0)])
        pad_spec = [1, 1, 1, 1, 1, 1]  # 3D padding spec
        pad_value = 0.0

        with self.assertRaises(NotImplementedError) as cm:
            torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, pad_value)

        # The error message should name both the unsupported rank and the
        # supported ones.
        self.assertIn("3D tensors is not supported", str(cm.exception))
        self.assertIn(
            "Only 1D and 2D tensors are currently supported", str(cm.exception)
        )

    def test_offsets_and_storage_metadata_after_padding_1d_rw(self) -> None:
        # Test 1D RW sharding with top+bottom padding
        shard1 = torch.tensor([1.0, 2.0])
        shard2 = torch.tensor([3.0, 4.0])
        original_offsets = [(0,), (2,)]
        lsw = LocalShardsWrapper([shard1, shard2], original_offsets)

        # Check original storage metadata
        original_storage = lsw.storage_metadata()
        self.assertEqual(original_storage.size, torch.Size([4]))  # [1,2,3,4]
        self.assertEqual(len(original_storage.chunks), 2)
        self.assertEqual(original_storage.chunks[0].offsets, torch.Size([0]))
        self.assertEqual(original_storage.chunks[0].sizes, torch.Size([2]))
        self.assertEqual(original_storage.chunks[1].offsets, torch.Size([2]))
        self.assertEqual(original_storage.chunks[1].sizes, torch.Size([2]))

        pad_spec = [1, 1]  # add 1 element at top and bottom
        result = torch.ops.aten.constant_pad_nd.default(lsw, pad_spec, 0.0)

        expected_offsets = [
            torch.Size([0]),
            torch.Size([3]),
        ]  # Second shard's offset shifted by 1
        self.assertEqual(result.local_offsets(), expected_offsets)

        result_storage = result.storage_metadata()

        # Global tensor should be: [0, 1, 2, 3, 4, 0] shape=[6]
        expected_global_size = torch.Size([6])
        self.assertEqual(result_storage.size, expected_global_size)

        self.assertEqual(len(result_storage.chunks), 2)

        # First chunk: [3] elements at offset [0] (size increased by top padding)
        # Second chunk: [3] elements at offset [3] (size increased by bottom padding, offset shifted)
        self.assertEqual(result_storage.chunks[0].offsets, torch.Size([0]))
        self.assertEqual(result_storage.chunks[0].sizes, torch.Size([3]))
        self.assertEqual(result_storage.chunks[1].offsets, torch.Size([3]))
        self.assertEqual(result_storage.chunks[1].sizes, torch.Size([3]))

    def test_offsets_and_storage_metadata_after_padding_2d_cw(self) -> None:
        # Test 2D CW sharding with left+right padding
        shard1_2d = torch.tensor([[1.0, 2.0], [5.0, 6.0]])  # [2, 2] columns 0-1
        shard2_2d = torch.tensor([[3.0, 4.0], [7.0, 8.0]])  # [2, 2] columns 2-3
        original_offsets_2d = [(0, 0), (0, 2)]
        lsw_2d = LocalShardsWrapper([shard1_2d, shard2_2d], original_offsets_2d)

        pad_spec_2d = [1, 1, 0, 0]  # [left=1, right=1, top=0, bottom=0]
        result_2d = torch.ops.aten.constant_pad_nd.default(lsw_2d, pad_spec_2d, 0.0)

        # Second shard's column offset shifts by the left padding amount.
        expected_offsets_2d = [
            torch.Size([0, 0]),
            torch.Size([0, 3]),
        ]
        self.assertEqual(result_2d.local_offsets(), expected_offsets_2d)

        result_storage_2d = result_2d.storage_metadata()

        # Global tensor should go from [2,4] to [2,6] (add 1 left + 1 right)
        expected_global_size_2d = torch.Size([2, 6])  # [2, 4+1+1]
        self.assertEqual(result_storage_2d.size, expected_global_size_2d)

        # First chunk: [2,3] at offset [0,0] (size increased by left padding)
        # Second chunk: [2,3] at offset [0,3] (size increased by right padding, offset shifted)
        self.assertEqual(result_storage_2d.chunks[0].offsets, torch.Size([0, 0]))
        self.assertEqual(result_storage_2d.chunks[0].sizes, torch.Size([2, 3]))
        self.assertEqual(result_storage_2d.chunks[1].offsets, torch.Size([0, 3]))
        self.assertEqual(result_storage_2d.chunks[1].sizes, torch.Size([2, 3]))
# Allow running this test file directly as a script.
if __name__ == "__main__":
    run_tests()

0 commit comments

Comments
 (0)