@@ -73,7 +73,7 @@ def __new__(
73
73
cat_tensor_shape [1 ] += shard .size ()[1 ]
74
74
75
75
# in cases of sharding optimizer rowwise, we calculate total tensor size by "concat" on first tensor dimension
76
- if len (local_shards ) > 1 and local_shards [0 ].ndim == 1 : # column -wise sharding
76
+ if len (local_shards ) > 1 and local_shards [0 ].ndim == 1 : # row -wise sharding
77
77
for shard in local_shards [1 :]:
78
78
cat_tensor_shape [0 ] += shard .size ()[0 ]
79
79
@@ -119,6 +119,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
119
119
aten .copy_ .default : cls .handle_copy_ ,
120
120
aten .zeros_like .default : cls .handle_zeros_like ,
121
121
aten .empty_like .default : cls .handle_empty_like ,
122
+ aten .constant_pad_nd .default : cls .handle_constant_pad_nd ,
122
123
}
123
124
124
125
if func in dispatcher :
@@ -279,6 +280,193 @@ def handle_new_empty(args, kwargs):
279
280
self_ls .local_offsets (),
280
281
)
281
282
283
+ @staticmethod
284
+ # pyre-fixme[3]: Return type must be annotated.
285
+ # pyre-fixme[2]: Parameter must be annotated.
286
+ def handle_constant_pad_nd (args , kwargs ):
287
+ """
288
+ Apply constant padding to LocalShardsWrapper.
289
+
290
+ The padding is based off of the following ideas:
291
+ - The resulting wrapper represents the padded version of the logical tensor.
292
+ - Each shard is padded based on the sharding type + dimension that is padded.
293
+ - For instance, CW shards padded on the left most col will have only padding on the first CW shard.
294
+ - Padding the top row will apply to all CW shards.
295
+ """
296
+ self_lsw = args [0 ]
297
+ pad_spec = args [1 ]
298
+ pad_value = args [2 ] if len (args ) > 2 else 0.0
299
+
300
+ if len (self_lsw .local_shards ()) == 0 :
301
+ raise Exception ("Cannot pad empty LocalShardsWrapper" )
302
+
303
+ local_shards = self_lsw .local_shards ()
304
+
305
+ if len (local_shards ) == 1 :
306
+ padded_shard = torch .nn .functional .pad (
307
+ local_shards [0 ], pad_spec , mode = "constant" , value = pad_value
308
+ )
309
+ return LocalShardsWrapper ([padded_shard ], self_lsw .local_offsets ())
310
+
311
+ padded_shards = list (local_shards )
312
+
313
+ if local_shards [0 ].ndim == 2 :
314
+ # 2D Column-wise sharding: [pad_left, pad_right, pad_top, pad_bottom]
315
+ pad_left , pad_right , pad_top , pad_bottom = (
316
+ pad_spec [0 ],
317
+ pad_spec [1 ],
318
+ pad_spec [2 ],
319
+ pad_spec [3 ],
320
+ )
321
+
322
+ if pad_top > 0 :
323
+ padded_shards = [
324
+ torch .nn .functional .pad (
325
+ shard , [0 , 0 , pad_top , 0 ], mode = "constant" , value = pad_value
326
+ )
327
+ for shard in padded_shards
328
+ ]
329
+ if pad_bottom > 0 :
330
+ padded_shards = [
331
+ torch .nn .functional .pad (
332
+ shard , [0 , 0 , 0 , pad_bottom ], mode = "constant" , value = pad_value
333
+ )
334
+ for shard in padded_shards
335
+ ]
336
+ if pad_left > 0 :
337
+ padded_shards [0 ] = torch .nn .functional .pad (
338
+ padded_shards [0 ],
339
+ [pad_left , 0 , 0 , 0 ],
340
+ mode = "constant" ,
341
+ value = pad_value ,
342
+ )
343
+ if pad_right > 0 :
344
+ padded_shards [- 1 ] = torch .nn .functional .pad (
345
+ padded_shards [- 1 ],
346
+ [0 , pad_right , 0 , 0 ],
347
+ mode = "constant" ,
348
+ value = pad_value ,
349
+ )
350
+ elif local_shards [0 ].ndim == 1 :
351
+ # 1D Row-wise sharding: [pad_top, pad_bottom]
352
+ pad_top , pad_bottom = pad_spec [0 ], pad_spec [1 ]
353
+
354
+ if pad_top > 0 :
355
+ padded_shards [0 ] = torch .nn .functional .pad (
356
+ padded_shards [0 ], [pad_top , 0 ], mode = "constant" , value = pad_value
357
+ )
358
+ if pad_bottom > 0 :
359
+ padded_shards [- 1 ] = torch .nn .functional .pad (
360
+ padded_shards [- 1 ], [0 , pad_bottom ], mode = "constant" , value = pad_value
361
+ )
362
+ else :
363
+ raise NotImplementedError (
364
+ f"Padding for { local_shards [0 ].ndim } D tensors is not supported. "
365
+ f"Only 1D and 2D tensors are currently supported."
366
+ )
367
+
368
+ # Update offsets and storage metadata
369
+ original_storage = self_lsw .storage_metadata ()
370
+ updated_offsets , updated_storage = LocalShardsWrapper ._compute_updated_metadata (
371
+ original_storage ,
372
+ self_lsw .local_offsets (),
373
+ pad_spec ,
374
+ local_shards [0 ].ndim ,
375
+ padded_shards ,
376
+ )
377
+
378
+ result = LocalShardsWrapper (padded_shards , updated_offsets )
379
+ result ._storage_meta = updated_storage
380
+ return result
381
+
382
+ @staticmethod
383
+ def _compute_updated_metadata (
384
+ original_storage : TensorStorageMetadata ,
385
+ original_offsets : list [torch .Size ],
386
+ pad_spec : list [int ],
387
+ ndim : int ,
388
+ padded_shards : list [torch .Tensor ],
389
+ ) -> tuple [list [torch .Size ], TensorStorageMetadata ]:
390
+ """
391
+ Compute updated offsets and storage metadata after padding is applied.
392
+
393
+ Args:
394
+ original_storage: Original storage metadata
395
+ original_offsets: Original shard offsets
396
+ pad_spec: Padding specification
397
+ ndim: Number of dimensions (1=RW or 2=CW)
398
+ padded_shards: Padded shard tensors
399
+
400
+ Returns:
401
+ Tuple of (updated_offsets, updated_storage_metadata)
402
+ """
403
+ if ndim == 1 : # 1D RW
404
+ pad_top , pad_bottom = pad_spec [0 ], pad_spec [1 ]
405
+
406
+ updated_offsets = []
407
+ for i , offset in enumerate (original_offsets ):
408
+ if i == 0 :
409
+ # First shard: offset stays the same (absorbs top padding)
410
+ updated_offsets .append (offset )
411
+ else :
412
+ # Subsequent shards: shift by top padding amount
413
+ new_offset = (offset [0 ] + pad_top ,)
414
+ updated_offsets .append (torch .Size (new_offset ))
415
+
416
+ new_global_size = torch .Size (
417
+ [original_storage .size [0 ] + pad_top + pad_bottom ]
418
+ )
419
+
420
+ elif ndim == 2 : # 2D CW
421
+ pad_left , pad_right , pad_top , pad_bottom = (
422
+ pad_spec [0 ],
423
+ pad_spec [1 ],
424
+ pad_spec [2 ],
425
+ pad_spec [3 ],
426
+ )
427
+
428
+ updated_offsets = []
429
+ for i , offset in enumerate (original_offsets ):
430
+ row_offset = offset [0 ]
431
+ col_offset = offset [1 ]
432
+
433
+ # Top/bottom padding doesn't affect offsets
434
+ # Left padding affects column offsets
435
+ if i == 0 :
436
+ # First shard: column offset stays the same (absorbs left padding)
437
+ new_offset = (row_offset , col_offset )
438
+ else :
439
+ # Subsequent shards: shift column offset by left padding amount
440
+ new_offset = (row_offset , col_offset + pad_left )
441
+
442
+ updated_offsets .append (torch .Size (new_offset ))
443
+
444
+ new_global_size = torch .Size (
445
+ [
446
+ original_storage .size [0 ] + pad_top + pad_bottom ,
447
+ original_storage .size [1 ] + pad_left + pad_right ,
448
+ ]
449
+ )
450
+
451
+ else :
452
+ raise NotImplementedError (f"Metadata computation for { ndim } D not supported" )
453
+
454
+ updated_chunks = [
455
+ ChunkStorageMetadata (
456
+ offsets = offset ,
457
+ sizes = shard .size (),
458
+ )
459
+ for offset , shard in zip (updated_offsets , padded_shards )
460
+ ]
461
+
462
+ updated_storage = TensorStorageMetadata (
463
+ properties = original_storage .properties ,
464
+ size = new_global_size ,
465
+ chunks = updated_chunks ,
466
+ )
467
+
468
+ return updated_offsets , updated_storage
469
+
282
470
@property
283
471
def device (self ) -> torch ._C .device : # type: ignore[override]
284
472
return (
0 commit comments