@@ -150,6 +150,7 @@ def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str], is_weig
     "F64": torch.float64,
     "I64": torch.int64,
     "F8_E4M3": torch.float8_e4m3fn,
+    "F8_E5M2": torch.float8_e5m2,
 }


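For context, a minimal sketch of how a safetensors dtype table like the one above is typically consumed. The dict name `str_to_torch_dtype` and the helper `resolve_dtype` are assumptions for illustration, not names taken from this file, and the new entry assumes a PyTorch build that exposes `torch.float8_e5m2`.

import torch

# Hypothetical module-level table mirroring the entries shown in the hunk above.
str_to_torch_dtype = {
    "F64": torch.float64,
    "I64": torch.int64,
    "F8_E4M3": torch.float8_e4m3fn,  # float8 dtypes exist only in recent PyTorch releases
    "F8_E5M2": torch.float8_e5m2,    # the newly added entry
}

def resolve_dtype(safetensors_dtype: str) -> torch.dtype:
    # Fail loudly on dtype strings the table does not cover.
    try:
        return str_to_torch_dtype[safetensors_dtype]
    except KeyError as exc:
        raise ValueError(f"Unsupported safetensors dtype: {safetensors_dtype}") from exc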
@@ -525,6 +526,43 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype,
         return param


+class ReduceFromModelParallelRegion(torch.autograd.Function):
+    """
+    All-reduce in forward pass, identity in backward pass.
+    This is the `g` function in the paper: https://arxiv.org/abs/1909.08053
+    """
+
+    @staticmethod
+    def forward(ctx, x, device_mesh):
+        if device_mesh.size() == 1:
+            return x
+        dist.all_reduce(x, op=dist.ReduceOp.SUM, group=device_mesh.get_group())
+        return x
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output
+
+
+class CopyToModelParallelRegion(torch.autograd.Function):
+    """
+    Copy in forward pass, all-reduce in backward pass.
+    This is the `f` function in the paper: https://arxiv.org/abs/1909.08053
+    """
+
+    @staticmethod
+    def forward(ctx, x, device_mesh):
+        ctx.device_mesh = device_mesh
+        return x
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if ctx.device_mesh.size() == 1:
+            return grad_output
+        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=ctx.device_mesh.get_group())
+        return grad_output
+
+
 class ColwiseParallel(TensorParallelLayer):
     """
     General tensor parallel layer for transformers.
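As a reading aid, here is a minimal sketch of how the two regions added above compose in the Megatron-LM scheme they cite (https://arxiv.org/abs/1909.08053): `f` (CopyToModelParallelRegion) wraps the input of a column-parallel matmul and `g` (ReduceFromModelParallelRegion) wraps the output of the following row-parallel matmul, so the forward pass needs a single all-reduce and the backward pass one more. The function name and weight-shard layout below are assumptions for illustration, not code from this PR.

import torch.nn.functional as F

def sharded_mlp_forward(x, w_col_shard, w_row_shard, device_mesh):
    # w_col_shard: local slice of the first weight, shape [hidden_local, hidden]
    # w_row_shard: local slice of the second weight, shape [hidden, hidden_local]
    x = CopyToModelParallelRegion.apply(x, device_mesh)      # f: identity fwd, all-reduce bwd
    h = F.gelu(F.linear(x, w_col_shard))                     # column-parallel GEMM, local output slice
    z = F.linear(h, w_row_shard)                             # row-parallel GEMM, partial sum only
    z = ReduceFromModelParallelRegion.apply(z, device_mesh)  # g: all-reduce fwd, identity bwd
    return z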
@@ -547,15 +585,8 @@ def __init__(

     @staticmethod
     def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
-        # TODO: figure out dynamo support for instance method and switch this to instance method
         # annotate module input placements/sharding with input_layouts
         input_tensor = inputs[0]
-        if not isinstance(input_tensor, DTensor):
-            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
-
-        # transform the input layouts to the desired layouts of ColwiseParallel
-        if input_layouts != desired_input_layouts:
-            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
         return input_tensor

     def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
@@ -564,41 +595,19 @@ def partition_tensor(self, param, empty_param, param_type, param_casting_dtype,
         # weight would become Shard(1)
         if param_type == "bias":
             parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
-            shard = [Shard(-1)]
         else:
-            shard = [Shard(-2)]
             parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2)

         parameter = parameter.to(param_casting_dtype)
         if to_contiguous:
             parameter = parameter.contiguous()
-        if self.use_dtensor:
-            parameter = DTensor.from_local(
-                parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
-            )
+
         return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())

     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
-        # outputs is a shard on last dimension DTensor, i.e. Shard(-1)
-        if outputs.placements != output_layouts:
-            outputs = outputs.redistribute(placements=output_layouts, async_op=False)
-        # back to local tensor
-        return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
-
-
-class PackedColwiseParallel(ColwiseParallel):
-    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
-        # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
-        # means Colwise as Linear is input * weight^T + bias, where
-        # weight would become Shard(1)
-        parameter = get_packed_weights(param, empty_param, device_mesh, rank, -2)
-        parameter = parameter.to(param_casting_dtype)
-        if to_contiguous:
-            parameter = parameter.contiguous()
-        if self.use_dtensor:
-            parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False)
-        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
+        outputs = CopyToModelParallelRegion.apply(outputs, device_mesh)
+        return outputs


 class RowwiseParallel(TensorParallelLayer):
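For intuition, a plain-torch approximation of the colwise partitioning in the hunk above, assuming `get_tensor_shard` slices a replicated parameter into equal per-rank pieces along the requested dimension; the helper name `colwise_shard` is made up for this sketch.

import torch

def colwise_shard(param: torch.Tensor, param_type: str, rank: int, tp_size: int) -> torch.Tensor:
    # Column-parallel nn.Linear splits out_features across ranks:
    # weight is [out_features, in_features] -> shard dim -2, bias is [out_features] -> shard dim -1.
    dim = -1 if param_type == "bias" else -2
    return torch.tensor_split(param, tp_size, dim=dim)[rank].contiguous()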
@@ -635,23 +644,15 @@ def __init__(
         self.use_dtensor = use_dtensor

     def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
-        # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
-        # means Rowwise as nn.Linear is input * weight^T + bias, where
-        # weight would become Shard(0)
-        if param_type != "bias":
-            parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
-            shard = [Shard(-1)]
-        else:
-            shard = [Replicate()]
+        if param_type == "bias":
             parameter = param[:]
+        else:
+            parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)

         parameter = parameter.to(param_casting_dtype)
         if to_contiguous:
             parameter = parameter.contiguous()
-        if self.use_dtensor:
-            parameter = DTensor.from_local(
-                parameter, device_mesh, shard, run_check=False, shape=empty_param.size(), stride=empty_param.stride()
-            )
+
         return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())

     @staticmethod
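The same kind of approximation for the rowwise case above: the weight is sliced along in_features, while the bias is kept whole on every rank (the `param[:]` copy), because it should only be applied once to the reduced output. The helper name is illustrative only.

import torch

def rowwise_shard(param: torch.Tensor, param_type: str, rank: int, tp_size: int) -> torch.Tensor:
    # Row-parallel nn.Linear splits in_features: weight [out_features, in_features] -> shard dim -1;
    # the bias stays replicated and is applied only after the all-reduce.
    if param_type == "bias":
        return param.clone()
    return torch.tensor_split(param, tp_size, dim=-1)[rank].contiguous()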
@@ -661,24 +662,14 @@ def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_
             mod.bias = None

         input_tensor = inputs[0]
-        if not isinstance(input_tensor, DTensor):
-            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
-
-        if input_layouts != desired_input_layouts:
-            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
         return input_tensor

     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
-        # Rowwise sharding produces partial output, depending on output layouts:
-        # 1. to replicate -> allreduce
-        # 2. to shard -> reduce_scatter
-        if outputs.placements != output_layouts:
-            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
+        outputs = ReduceFromModelParallelRegion.apply(outputs, device_mesh)
         if hasattr(mod, "_bias"):
             outputs += mod._bias
-        # back to local tensor if use_local_output is True
-        return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
+        return outputs

     def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
         module._distribute_module_applied = True
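One more sketch to spell out why RowwiseParallel strips the bias in its input hook (`mod.bias = None`, presumably stashing it as `mod._bias`) and re-adds it only in the output hook above: each rank's local matmul is a partial sum over its slice of in_features, so a bias added before the all-reduce would be counted tp_size times. The standalone function below is illustrative, not the PR's exact code path.

import torch.nn.functional as F

def row_parallel_linear_forward(mod, x_local, device_mesh):
    y = F.linear(x_local, mod.weight)                        # partial result, deliberately no bias
    y = ReduceFromModelParallelRegion.apply(y, device_mesh)  # sum partial results across TP ranks
    if getattr(mod, "_bias", None) is not None:
        y = y + mod._bias                                    # bias applied exactly once
    return y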
@@ -703,6 +694,21 @@ def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
         )


+class PackedColwiseParallel(ColwiseParallel):
+    def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
+        # NOTE(3outeille): need to be deprecated as no longer using dtensors
+        # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
+        # means Colwise as Linear is input * weight^T + bias, where
+        # weight would become Shard(1)
+        parameter = get_packed_weights(param, empty_param, device_mesh, rank, -2)
+        parameter = parameter.to(param_casting_dtype)
+        if to_contiguous:
+            parameter = parameter.contiguous()
+        if self.use_dtensor:
+            parameter = DTensor.from_local(parameter, device_mesh, [Shard(-2)], run_check=False)
+        return nn.Parameter(parameter, requires_grad=parameter.is_floating_point())
+
+
 class PackedRowwiseParallel(RowwiseParallel):
     def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
         # colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)