Commit c0213e5

jeffkbkim authored and facebook-github-bot committed
Implement tensor padding for local shards wrapper (#3382)
Summary:

X-link: pytorch/pytorch#163183

This diff implements the constant padding functionality (aten.constant_pad_nd.default) for `LocalShardsWrapper`. The method applies constant padding to the local shards based on the provided padding specification. Depending on the sharding type (RW, CW), padding in the [left, right, top, bottom] directions is applied either to only the first/last shard or to all local shards.

New unit tests cover:
- 1D (RW) top/bottom padding
- 2D (CW) left, right, top, and bottom padding
- empty shards and tensors with more than 2 dimensions

Differential Revision: D82663766
1 parent afa92f0 commit c0213e5
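
As a rough usage sketch of what the new dispatch entry enables (not taken from the diff: the shard sizes, the offsets, and the assumption that torch.nn.functional.pad with mode="constant" routes through aten.constant_pad_nd.default are illustrative):

    import torch
    import torch.nn.functional as F
    from torchrec.distributed.shards_wrapper import LocalShardsWrapper

    # Hypothetical example: two column-wise (CW) shards of a logical 4 x 8 tensor.
    shards = [torch.ones(4, 4), torch.full((4, 4), 2.0)]
    offsets = [(0, 0), (0, 4)]
    lsw = LocalShardsWrapper(shards, offsets)

    # pad_spec is [pad_left, pad_right, pad_top, pad_bottom]: per the summary,
    # left padding should land only on the first shard, while top padding is
    # applied to every shard.
    padded = F.pad(lsw, [1, 0, 2, 0], mode="constant", value=0.0)

The padded result still represents one logical tensor, with shard offsets and storage metadata updated by the _compute_updated_metadata helper added in this commit.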

File tree

2 files changed: +533, -1 lines changed

torchrec/distributed/shards_wrapper.py

Lines changed: 204 additions & 1 deletion
@@ -73,7 +73,7 @@ def __new__(
                 cat_tensor_shape[1] += shard.size()[1]
 
         # in cases of sharding optimizer rowwise, we calculate total tensor size by "concat" on first tensor dimension
-        if len(local_shards) > 1 and local_shards[0].ndim == 1:  # column-wise sharding
+        if len(local_shards) > 1 and local_shards[0].ndim == 1:  # row-wise sharding
             for shard in local_shards[1:]:
                 cat_tensor_shape[0] += shard.size()[0]
 
@@ -119,6 +119,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
             aten.copy_.default: cls.handle_copy_,
             aten.zeros_like.default: cls.handle_zeros_like,
             aten.empty_like.default: cls.handle_empty_like,
+            aten.constant_pad_nd.default: cls.handle_constant_pad_nd,
         }
 
         if func in dispatcher:
@@ -279,6 +280,208 @@ def handle_new_empty(args, kwargs):
             self_ls.local_offsets(),
         )
 
+    @staticmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def handle_constant_pad_nd(args, kwargs):
+        """
+        Apply constant padding to LocalShardsWrapper.
+
+        The padding is based off of the following ideas:
+        - The resulting wrapper represents the padded version of the logical tensor.
+        - Each shard is padded based on the sharding type + dimension that is padded.
+        - For instance, CW shards padded on the left most col will have only padding on the first CW shard.
+        - Padding the top row will apply to all CW shards.
+        """
+        self_lsw = args[0]
+        pad_spec = args[1]
+        pad_value = args[2] if len(args) > 2 else 0.0
+
+        if len(self_lsw.local_shards()) == 0:
+            raise NotImplementedError(
+                "Padding empty LocalShardsWrapper is not supported."
+            )
+
+        local_shards = self_lsw.local_shards()
+
+        if len(local_shards) == 1:
+            padded_shard = torch.nn.functional.pad(
+                local_shards[0], pad_spec, mode="constant", value=pad_value
+            )
+            return LocalShardsWrapper([padded_shard], self_lsw.local_offsets())
+
+        padded_shards = list(local_shards)
+
+        if local_shards[0].ndim == 2:
+            # 2D Column-wise sharding: [pad_left, pad_right, pad_top, pad_bottom]
+            if len(pad_spec) == 2:
+                # Single dimension padding happens on the left most column
+                pad_spec = pad_spec + [0, 0]
+
+            if len(pad_spec) != 4:
+                raise ValueError(
+                    f"Padding spec must be of length 4 for 2D tensors, got {len(pad_spec)}"
+                )
+
+            pad_left, pad_right, pad_top, pad_bottom = (
+                pad_spec[0],
+                pad_spec[1],
+                pad_spec[2],
+                pad_spec[3],
+            )
+
+            if pad_top > 0:
+                padded_shards = [
+                    torch.nn.functional.pad(
+                        shard, [0, 0, pad_top, 0], mode="constant", value=pad_value
+                    )
+                    for shard in padded_shards
+                ]
+            if pad_bottom > 0:
+                padded_shards = [
+                    torch.nn.functional.pad(
+                        shard, [0, 0, 0, pad_bottom], mode="constant", value=pad_value
+                    )
+                    for shard in padded_shards
+                ]
+            if pad_left > 0:
+                padded_shards[0] = torch.nn.functional.pad(
+                    padded_shards[0],
+                    [pad_left, 0, 0, 0],
+                    mode="constant",
+                    value=pad_value,
+                )
+            if pad_right > 0:
+                padded_shards[-1] = torch.nn.functional.pad(
+                    padded_shards[-1],
+                    [0, pad_right, 0, 0],
+                    mode="constant",
+                    value=pad_value,
+                )
+        elif local_shards[0].ndim == 1:
+            # 1D Row-wise sharding: [pad_top, pad_bottom]
+            if len(pad_spec) != 2:
+                raise ValueError(
+                    f"Padding spec must be of length 2 for 1D tensors, got {len(pad_spec)}"
+                )
+            pad_top, pad_bottom = pad_spec[0], pad_spec[1]
+
+            if pad_top > 0:
+                padded_shards[0] = torch.nn.functional.pad(
+                    padded_shards[0], [pad_top, 0], mode="constant", value=pad_value
+                )
+            if pad_bottom > 0:
+                padded_shards[-1] = torch.nn.functional.pad(
+                    padded_shards[-1], [0, pad_bottom], mode="constant", value=pad_value
+                )
+        else:
+            raise NotImplementedError(
+                f"Padding for {local_shards[0].ndim}D tensors is not supported. "
+                f"Only 1D and 2D tensors are currently supported."
+            )
+
+        # Update offsets and storage metadata
+        original_storage = self_lsw.storage_metadata()
+        updated_offsets, updated_storage = LocalShardsWrapper._compute_updated_metadata(
+            original_storage,
+            self_lsw.local_offsets(),
+            pad_spec,
+            local_shards[0].ndim,
+            padded_shards,
+        )
+
+        result = LocalShardsWrapper(padded_shards, updated_offsets)
+        result._storage_meta = updated_storage
+        return result
+
+    @staticmethod
+    def _compute_updated_metadata(
+        original_storage: TensorStorageMetadata,
+        original_offsets: list[torch.Size],
+        pad_spec: list[int],
+        ndim: int,
+        padded_shards: list[torch.Tensor],
+    ) -> tuple[list[tuple[int, ...]], TensorStorageMetadata]:
+        """
+        Compute updated offsets and storage metadata after padding is applied.
+
+        Args:
+            original_storage: Original storage metadata
+            original_offsets: Original shard offsets
+            pad_spec: Padding specification
+            ndim: Number of dimensions (1=RW or 2=CW)
+            padded_shards: Padded shard tensors
+
+        Returns:
+            Tuple of (updated_offsets, updated_storage_metadata)
+        """
+        if ndim == 1:  # 1D RW
+            pad_top, pad_bottom = pad_spec[0], pad_spec[1]
+
+            updated_offsets = []
+            for i, offset in enumerate(original_offsets):
+                if i == 0:
+                    # First shard: offset stays the same (absorbs top padding)
+                    updated_offsets.append(tuple(offset))
+                else:
+                    # Subsequent shards: shift by top padding amount
+                    new_offset = (offset[0] + pad_top,)
+                    updated_offsets.append(new_offset)
+
+            new_global_size = torch.Size(
+                [original_storage.size[0] + pad_top + pad_bottom]
+            )
+
+        elif ndim == 2:  # 2D CW
+            pad_left, pad_right, pad_top, pad_bottom = (
+                pad_spec[0],
+                pad_spec[1],
+                pad_spec[2],
+                pad_spec[3],
+            )
+
+            updated_offsets = []
+            for i, offset in enumerate(original_offsets):
+                row_offset = offset[0]
+                col_offset = offset[1]
+
+                # Top/bottom padding doesn't affect offsets
+                # Left padding affects column offsets
+                if i == 0:
+                    # First shard: column offset stays the same (absorbs left padding)
+                    new_2d_offset = (row_offset, col_offset)
+                else:
+                    # Subsequent shards: shift column offset by left padding amount
+                    new_2d_offset = (row_offset, col_offset + pad_left)
+
+                updated_offsets.append(new_2d_offset)
+
+            new_global_size = torch.Size(
+                [
+                    original_storage.size[0] + pad_top + pad_bottom,
+                    original_storage.size[1] + pad_left + pad_right,
+                ]
+            )
+
+        else:
+            raise NotImplementedError(f"Metadata computation for {ndim}D not supported")
+
+        updated_chunks = [
+            ChunkStorageMetadata(
+                offsets=torch.Size(offset),
+                sizes=shard.size(),
+            )
+            for offset, shard in zip(updated_offsets, padded_shards)
+        ]
+
+        updated_storage = TensorStorageMetadata(
+            properties=original_storage.properties,
+            size=new_global_size,
+            chunks=updated_chunks,
+        )
+
+        return updated_offsets, updated_storage
+
     @property
     def device(self) -> torch._C.device:  # type: ignore[override]
         return (
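
To make the offset bookkeeping in _compute_updated_metadata concrete, here is a small standalone sketch of the 1D (RW) rule described in the comments above; the shard offsets, lengths, and the helper name below are made-up for illustration and are not part of the change:

    import torch

    def shifted_offsets_1d(offsets, pad_top):
        # Rule from the diff: the first shard absorbs the top padding, so its
        # offset is unchanged; every later shard shifts down by pad_top.
        return [off if i == 0 else (off[0] + pad_top,) for i, off in enumerate(offsets)]

    original_offsets = [(0,), (4,), (8,)]            # three RW shards of length 4
    print(shifted_offsets_1d(original_offsets, 2))   # [(0,), (6,), (10,)]

    # The global (logical) length grows on both ends: pad_top=2, pad_bottom=3.
    new_global_size = torch.Size([12 + 2 + 3])       # torch.Size([17])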
