@@ -73,7 +73,7 @@ def __new__(
                 cat_tensor_shape[1] += shard.size()[1]

         # in cases of sharding optimizer rowwise, we calculate total tensor size by "concat" on first tensor dimension
-        if len(local_shards) > 1 and local_shards[0].ndim == 1:  # column-wise sharding
+        if len(local_shards) > 1 and local_shards[0].ndim == 1:  # row-wise sharding
             for shard in local_shards[1:]:
                 cat_tensor_shape[0] += shard.size()[0]

@@ -119,6 +119,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
             aten.copy_.default: cls.handle_copy_,
             aten.zeros_like.default: cls.handle_zeros_like,
             aten.empty_like.default: cls.handle_empty_like,
+            aten.constant_pad_nd.default: cls.handle_constant_pad_nd,
         }

         if func in dispatcher:
@@ -279,6 +280,208 @@ def handle_new_empty(args, kwargs):
             self_ls.local_offsets(),
         )

+    @staticmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def handle_constant_pad_nd(args, kwargs):
+        """
+        Apply constant padding to a LocalShardsWrapper.
+
+        The padding is based on the following ideas:
+        - The resulting wrapper represents the padded version of the logical tensor.
+        - Each shard is padded according to the sharding type and the dimension being padded.
+        - For instance, padding the left-most column of CW shards only pads the first CW shard.
+        - Padding the top row applies to all CW shards.
+        """
+        self_lsw = args[0]
+        pad_spec = args[1]
+        pad_value = args[2] if len(args) > 2 else 0.0
+
+        if len(self_lsw.local_shards()) == 0:
+            raise NotImplementedError(
+                "Padding empty LocalShardsWrapper is not supported."
+            )
+
+        local_shards = self_lsw.local_shards()
+
+        if len(local_shards) == 1:
+            padded_shard = torch.nn.functional.pad(
+                local_shards[0], pad_spec, mode="constant", value=pad_value
+            )
+            return LocalShardsWrapper([padded_shard], self_lsw.local_offsets())
+
+        padded_shards = list(local_shards)
+
+        if local_shards[0].ndim == 2:
+            # 2D Column-wise sharding: [pad_left, pad_right, pad_top, pad_bottom]
+            if len(pad_spec) == 2:
+                # A length-2 pad spec pads only the column (last) dimension
+                pad_spec = pad_spec + [0, 0]
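+                # Note: F.pad specs are ordered starting from the last dimension,
+                # so e.g. [1, 2] pads one column on the left and two on the right
+                # while leaving the rows untouched.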
+
+            if len(pad_spec) != 4:
+                raise ValueError(
+                    f"Padding spec must be of length 4 for 2D tensors, got {len(pad_spec)}"
+                )
+
+            pad_left, pad_right, pad_top, pad_bottom = (
+                pad_spec[0],
+                pad_spec[1],
+                pad_spec[2],
+                pad_spec[3],
+            )
+
+            if pad_top > 0:
+                padded_shards = [
+                    torch.nn.functional.pad(
+                        shard, [0, 0, pad_top, 0], mode="constant", value=pad_value
+                    )
+                    for shard in padded_shards
+                ]
+            if pad_bottom > 0:
+                padded_shards = [
+                    torch.nn.functional.pad(
+                        shard, [0, 0, 0, pad_bottom], mode="constant", value=pad_value
+                    )
+                    for shard in padded_shards
+                ]
+            if pad_left > 0:
+                padded_shards[0] = torch.nn.functional.pad(
+                    padded_shards[0],
+                    [pad_left, 0, 0, 0],
+                    mode="constant",
+                    value=pad_value,
+                )
+            if pad_right > 0:
+                padded_shards[-1] = torch.nn.functional.pad(
+                    padded_shards[-1],
+                    [0, pad_right, 0, 0],
+                    mode="constant",
+                    value=pad_value,
+                )
+        elif local_shards[0].ndim == 1:
+            # 1D Row-wise sharding: [pad_top, pad_bottom]
+            if len(pad_spec) != 2:
+                raise ValueError(
+                    f"Padding spec must be of length 2 for 1D tensors, got {len(pad_spec)}"
+                )
+            pad_top, pad_bottom = pad_spec[0], pad_spec[1]
+
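+            # Illustrative example (sizes assumed): two RW shards of length 8
+            # with pad_spec [2, 3] become lengths 10 and 11; only the first
+            # shard absorbs the top padding and only the last the bottom.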
+            if pad_top > 0:
+                padded_shards[0] = torch.nn.functional.pad(
+                    padded_shards[0], [pad_top, 0], mode="constant", value=pad_value
+                )
+            if pad_bottom > 0:
+                padded_shards[-1] = torch.nn.functional.pad(
+                    padded_shards[-1], [0, pad_bottom], mode="constant", value=pad_value
+                )
+        else:
+            raise NotImplementedError(
+                f"Padding for {local_shards[0].ndim}D tensors is not supported. "
+                f"Only 1D and 2D tensors are currently supported."
+            )
+
+        # Update offsets and storage metadata
+        original_storage = self_lsw.storage_metadata()
+        updated_offsets, updated_storage = LocalShardsWrapper._compute_updated_metadata(
+            original_storage,
+            self_lsw.local_offsets(),
+            pad_spec,
+            local_shards[0].ndim,
+            padded_shards,
+        )
+
+        result = LocalShardsWrapper(padded_shards, updated_offsets)
+        result._storage_meta = updated_storage
+        return result
+
+    @staticmethod
+    def _compute_updated_metadata(
+        original_storage: TensorStorageMetadata,
+        original_offsets: list[torch.Size],
+        pad_spec: list[int],
+        ndim: int,
+        padded_shards: list[torch.Tensor],
+    ) -> tuple[list[tuple[int, ...]], TensorStorageMetadata]:
+        """
+        Compute updated offsets and storage metadata after padding is applied.
+
+        Args:
+            original_storage: Original storage metadata
+            original_offsets: Original shard offsets
+            pad_spec: Padding specification
+            ndim: Number of dimensions (1 = RW, 2 = CW)
+            padded_shards: Padded shard tensors
+
+        Returns:
+            Tuple of (updated_offsets, updated_storage_metadata)
+        """
+        if ndim == 1:  # 1D RW
+            pad_top, pad_bottom = pad_spec[0], pad_spec[1]
+
+            updated_offsets = []
+            for i, offset in enumerate(original_offsets):
+                if i == 0:
+                    # First shard: offset stays the same (absorbs top padding)
+                    updated_offsets.append(tuple(offset))
+                else:
+                    # Subsequent shards: shift by top padding amount
+                    new_offset = (offset[0] + pad_top,)
+                    updated_offsets.append(new_offset)
+
+            new_global_size = torch.Size(
+                [original_storage.size[0] + pad_top + pad_bottom]
+            )
+
+        elif ndim == 2:  # 2D CW
+            pad_left, pad_right, pad_top, pad_bottom = (
+                pad_spec[0],
+                pad_spec[1],
+                pad_spec[2],
+                pad_spec[3],
+            )
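+            # Illustrative example (values assumed): CW shards at offsets (0, 0)
+            # and (0, 4) with pad_left=1 keep (0, 0) and shift to (0, 5); row
+            # padding changes the global row count but not the offsets.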
+
+            updated_offsets = []
+            for i, offset in enumerate(original_offsets):
+                row_offset = offset[0]
+                col_offset = offset[1]
+
+                # Top/bottom padding doesn't affect offsets
+                # Left padding affects column offsets
+                if i == 0:
+                    # First shard: column offset stays the same (absorbs left padding)
+                    new_2d_offset = (row_offset, col_offset)
+                else:
+                    # Subsequent shards: shift column offset by left padding amount
+                    new_2d_offset = (row_offset, col_offset + pad_left)
+
+                updated_offsets.append(new_2d_offset)
+
+            new_global_size = torch.Size(
+                [
+                    original_storage.size[0] + pad_top + pad_bottom,
+                    original_storage.size[1] + pad_left + pad_right,
+                ]
+            )
+
+        else:
+            raise NotImplementedError(f"Metadata computation for {ndim}D not supported")
+
+        updated_chunks = [
+            ChunkStorageMetadata(
+                offsets=torch.Size(offset),
+                sizes=shard.size(),
+            )
+            for offset, shard in zip(updated_offsets, padded_shards)
+        ]
+
+        updated_storage = TensorStorageMetadata(
+            properties=original_storage.properties,
+            size=new_global_size,
+            chunks=updated_chunks,
+        )
+
+        return updated_offsets, updated_storage
+
     @property
     def device(self) -> torch._C.device:  # type: ignore[override]
         return (