@@ -73,7 +73,7 @@ def __new__(
                cat_tensor_shape[1] += shard.size()[1]

        # in cases of sharding optimizer rowwise, we calculate total tensor size by "concat" on first tensor dimension
-        if len(local_shards) > 1 and local_shards[0].ndim == 1:  # column-wise sharding
+        if len(local_shards) > 1 and local_shards[0].ndim == 1:  # row-wise sharding
            for shard in local_shards[1:]:
                cat_tensor_shape[0] += shard.size()[0]

@@ -119,6 +119,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            aten.copy_.default: cls.handle_copy_,
            aten.zeros_like.default: cls.handle_zeros_like,
            aten.empty_like.default: cls.handle_empty_like,
+            aten.constant_pad_nd.default: cls.handle_constant_pad_nd,
        }

        if func in dispatcher:
@@ -279,6 +280,187 @@ def handle_new_empty(args, kwargs):
            self_ls.local_offsets(),
        )

+    @staticmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def handle_constant_pad_nd(args, kwargs):
+        """
+        Apply constant padding to LocalShardsWrapper.
+
+        The padding is based on the following ideas:
+        - The resulting wrapper represents the padded version of the logical tensor.
+        - Each shard is padded based on the sharding type and the dimension being padded.
+        - For instance, padding the leftmost column of a CW-sharded tensor only pads the first CW shard.
+        - Padding the top row applies to all CW shards.
+        """
+        self_lsw = args[0]
+        pad_spec = args[1]
+        pad_value = args[2] if len(args) > 2 else 0.0
+
+        if len(self_lsw.local_shards()) == 0:
+            raise NotImplementedError("Padding empty LocalShardsWrapper is not supported.")
+
+        local_shards = self_lsw.local_shards()
+
+        if len(local_shards) == 1:
+            padded_shard = torch.nn.functional.pad(
+                local_shards[0], pad_spec, mode="constant", value=pad_value
+            )
+            return LocalShardsWrapper([padded_shard], self_lsw.local_offsets())
+
+        padded_shards = list(local_shards)
+
+        if local_shards[0].ndim == 2:
+            # 2D column-wise sharding: [pad_left, pad_right, pad_top, pad_bottom]
+            pad_left, pad_right, pad_top, pad_bottom = pad_spec[0], pad_spec[1], pad_spec[2], pad_spec[3]
+
+            if pad_top > 0:
+                padded_shards = [
+                    torch.nn.functional.pad(
+                        shard, [0, 0, pad_top, 0], mode="constant", value=pad_value
+                    )
+                    for shard in padded_shards
+                ]
+            if pad_bottom > 0:
+                padded_shards = [
+                    torch.nn.functional.pad(
+                        shard, [0, 0, 0, pad_bottom], mode="constant", value=pad_value
+                    )
+                    for shard in padded_shards
+                ]
+            if pad_left > 0:
+                padded_shards[0] = torch.nn.functional.pad(
+                    padded_shards[0],
+                    [pad_left, 0, 0, 0],
+                    mode="constant",
+                    value=pad_value
+                )
+            if pad_right > 0:
+                padded_shards[-1] = torch.nn.functional.pad(
+                    padded_shards[-1],
+                    [0, pad_right, 0, 0],
+                    mode="constant",
+                    value=pad_value
+                )
+        elif local_shards[0].ndim == 1:
+            # 1D row-wise sharding: [pad_top, pad_bottom]
+            pad_top, pad_bottom = pad_spec[0], pad_spec[1]
+
+            if pad_top > 0:
+                padded_shards[0] = torch.nn.functional.pad(
+                    padded_shards[0], [pad_top, 0], mode="constant", value=pad_value
+                )
+            if pad_bottom > 0:
+                padded_shards[-1] = torch.nn.functional.pad(
+                    padded_shards[-1], [0, pad_bottom], mode="constant", value=pad_value
+                )
+        else:
+            raise NotImplementedError(
+                f"Padding for {local_shards[0].ndim}D tensors is not supported. "
+                f"Only 1D and 2D tensors are currently supported."
+            )
+
+        # Update offsets and storage metadata
+        original_storage = self_lsw.storage_metadata()
+        updated_offsets, updated_storage = LocalShardsWrapper._compute_updated_metadata(
+            original_storage,
+            self_lsw.local_offsets(),
+            pad_spec, local_shards[0].ndim,
+            padded_shards
+        )
+
+        result = LocalShardsWrapper(padded_shards, updated_offsets)
+        result._storage_meta = updated_storage
+        return result
+
+    @staticmethod
+    def _compute_updated_metadata(
+        original_storage: TensorStorageMetadata,
+        original_offsets: list[torch.Size],
+        pad_spec: list[int],
+        ndim: int,
+        padded_shards: list[torch.Tensor],
+    ) -> tuple[list[torch.Size], TensorStorageMetadata]:
+        """
+        Compute updated offsets and storage metadata after padding is applied.
+
+        Args:
+            original_storage: Original storage metadata
+            original_offsets: Original shard offsets
+            pad_spec: Padding specification
+            ndim: Number of dimensions (1 = RW, 2 = CW)
+            padded_shards: Padded shard tensors
+
+        Returns:
+            Tuple of (updated_offsets, updated_storage_metadata)
+        """
+        if ndim == 1:  # 1D RW
+            pad_top, pad_bottom = pad_spec[0], pad_spec[1]
+
+            updated_offsets = []
+            for i, offset in enumerate(original_offsets):
+                if i == 0:
+                    # First shard: offset stays the same (absorbs top padding)
+                    updated_offsets.append(offset)
+                else:
+                    # Subsequent shards: shift by top padding amount
+                    new_offset = (offset[0] + pad_top,)
+                    updated_offsets.append(torch.Size(new_offset))
+
+            new_global_size = torch.Size(
+                [original_storage.size[0] + pad_top + pad_bottom]
+            )
+
+        elif ndim == 2:  # 2D CW
+            pad_left, pad_right, pad_top, pad_bottom = (
+                pad_spec[0],
+                pad_spec[1],
+                pad_spec[2],
+                pad_spec[3]
+            )
+
+            updated_offsets = []
+            for i, offset in enumerate(original_offsets):
+                row_offset = offset[0]
+                col_offset = offset[1]
+
+                # Top/bottom padding doesn't affect offsets
+                # Left padding affects column offsets
+                if i == 0:
+                    # First shard: column offset stays the same (absorbs left padding)
+                    new_offset = (row_offset, col_offset)
+                else:
+                    # Subsequent shards: shift column offset by left padding amount
+                    new_offset = (row_offset, col_offset + pad_left)
+
+                updated_offsets.append(torch.Size(new_offset))
+
+            new_global_size = torch.Size(
+                [
+                    original_storage.size[0] + pad_top + pad_bottom,
+                    original_storage.size[1] + pad_left + pad_right
+                ]
+            )
+
+        else:
+            raise NotImplementedError(f"Metadata computation for {ndim}D not supported")
+
+        updated_chunks = [
+            ChunkStorageMetadata(
+                offsets=offset,
+                sizes=shard.size(),
+            )
+            for offset, shard in zip(updated_offsets, padded_shards)
+        ]
+
+        updated_storage = TensorStorageMetadata(
+            properties=original_storage.properties,
+            size=new_global_size,
+            chunks=updated_chunks,
+        )
+
+        return updated_offsets, updated_storage
+
    @property
    def device(self) -> torch._C.device:  # type: ignore[override]
        return (
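For context, here is a minimal usage sketch of the new handler (not part of the diff). It assumes LocalShardsWrapper is importable from torchrec.distributed.shards_wrapper (the import path may differ across versions) and that torch.nn.functional.pad on the wrapper routes to aten.constant_pad_nd and therefore to handle_constant_pad_nd; the expected values in the comments follow from the padding rules described in the docstring above.

# Hypothetical usage sketch, not part of this PR; import path is an assumption.
import torch
import torch.nn.functional as F
from torchrec.distributed.shards_wrapper import LocalShardsWrapper

# Two column-wise shards of a logical (4, 6) tensor: columns [0, 4) and [4, 6).
lsw = LocalShardsWrapper(
    [torch.ones(4, 4), torch.ones(4, 2)],
    [(0, 0), (0, 4)],
)

# pad_spec for 2D follows F.pad's convention: [left, right, top, bottom].
# This should dispatch to aten.constant_pad_nd and land in handle_constant_pad_nd.
padded = F.pad(lsw, [1, 2, 3, 0], mode="constant", value=0.0)

# Per the handler's rules: top padding applies to every shard, left padding only
# to the first shard, right padding only to the last shard.
print([tuple(s.shape) for s in padded.local_shards()])  # expected [(7, 5), (7, 4)]
print(tuple(padded.storage_metadata().size))            # expected (7, 9)
print([tuple(o) for o in padded.local_offsets()])       # expected [(0, 0), (0, 5)]

The expected offsets also illustrate _compute_updated_metadata: the first shard absorbs the left padding, so only the later shards' column offsets shift by pad_left, while the global size grows by the total padding on each dimension.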