@@ -618,6 +618,10 @@ class FP8Linear(nn.Module):
618618
619619 This layer can hold FP8 weights with per-layer scale factors and
620620 automatically dequantizes during forward pass for compatibility.
621+
622+ The FP8 weight is stored as `fp8_weight_storage` to avoid nn.Module's
623+ special handling of 'weight'. Access via `.weight` property is provided
624+ for compatibility with LoRA loaders.
621625 """
622626
623627 def __init__ (
@@ -627,6 +631,7 @@ def __init__(
627631 bias : bool = True ,
628632 device : Optional [torch .device ] = None ,
629633 dtype : Optional [torch .dtype ] = None ,
634+ compute_dtype : Optional [torch .dtype ] = None ,
630635 ):
631636 """
632637 Initialize FP8 linear layer.
@@ -636,24 +641,49 @@ def __init__(
636641 out_features: Output feature dimension
637642 bias: Whether to include bias
638643 device: Target device
639- dtype: Compute dtype (for activations)
644+ dtype: Compute dtype (for activations) - deprecated, use compute_dtype
645+ compute_dtype: Compute dtype (for activations)
640646 """
641647 super ().__init__ ()
642648
643649 self .in_features = in_features
644650 self .out_features = out_features
645- self .compute_dtype = dtype or torch .bfloat16
651+ self .compute_dtype = compute_dtype or dtype or torch .bfloat16
646652
647- # Initialize as empty - will be filled during checkpoint loading
648- self .weight : Optional [QuantizedTensor ] = None
649- self .bias : Optional [ torch . Tensor ] = None
653+ # Use unique name to avoid nn.Module intercepting 'weight'
654+ self .fp8_weight_storage : Optional [QuantizedTensor ] = None
655+ self ._has_bias = bias
650656
651657 if bias :
652658 self .register_buffer (
653659 '_bias' ,
654660 torch .zeros (out_features , device = device , dtype = self .compute_dtype )
655661 )
656662
663+ @property
664+ def weight (self ) -> Optional [QuantizedTensor ]:
665+ """Get the FP8 quantized weight (for LoRA compatibility)."""
666+ return self .fp8_weight_storage
667+
668+ @weight .setter
669+ def weight (self , value : Optional [QuantizedTensor ]):
670+ """Set the FP8 quantized weight."""
671+ self .fp8_weight_storage = value
672+
673+ @property
674+ def bias (self ) -> Optional [torch .Tensor ]:
675+ """Get bias tensor."""
676+ return self ._bias if self ._has_bias and hasattr (self , '_bias' ) else None
677+
678+ @bias .setter
679+ def bias (self , value : Optional [torch .Tensor ]):
680+ """Set bias tensor."""
681+ if value is not None and self ._has_bias :
682+ if hasattr (self , '_bias' ):
683+ self ._bias .copy_ (value )
684+ else :
685+ self .register_buffer ('_bias' , value )
686+
657687 def set_fp8_weight (
658688 self ,
659689 fp8_weight : torch .Tensor ,
@@ -668,7 +698,7 @@ def set_fp8_weight(
668698 scale: Per-layer scale factor
669699 orig_dtype: Original dtype for dequantization
670700 """
671- self .weight = QuantizedTensor .from_fp8_with_scale (
701+ self .fp8_weight_storage = QuantizedTensor .from_fp8_with_scale (
672702 fp8_weight , scale , orig_dtype
673703 )
674704
@@ -682,13 +712,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
682712 Returns:
683713 Output tensor
684714 """
685- if self .weight is None :
715+ if self .fp8_weight_storage is None :
686716 raise RuntimeError ("Weight not set. Call set_fp8_weight first." )
687717
688718 # Dequantize weight for forward pass
689- weight = self .weight .dequantize ().to (x .dtype )
719+ weight = self .fp8_weight_storage .dequantize ().to (x .dtype )
690720
691- bias = self .bias if hasattr (self , '_bias' ) and self . _bias is not None else None
721+ bias = self ._bias if self . _has_bias and hasattr (self , '_bias' ) else None
692722 if bias is not None :
693723 bias = bias .to (x .dtype )
694724
@@ -748,169 +778,3 @@ def is_fp8_scaled_checkpoint(state_dict: Dict[str, torch.Tensor]) -> bool:
748778 return True
749779
750780 return has_fp8 and has_scale
751-
752-
class FP8Linear(nn.Module):
    """
    Linear layer that stores weights in FP8 format and dequantizes on forward.

    This allows ~50% memory savings compared to BF16 while maintaining
    numerical accuracy through per-tensor scaling.

    The FP8 payload lives in two registered buffers ('weight_fp8' and
    'weight_scale'), so both appear in the module's state_dict under those
    exact names.

    Args:
        in_features: Input feature dimension
        out_features: Output feature dimension
        bias: Whether to include bias
        device: Target device
        compute_dtype: Dtype for computation (default: bfloat16)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device: Optional[torch.device] = None,
        compute_dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.compute_dtype = compute_dtype

        # FP8 weight storage; zeros until set_fp8_weight/from_linear fills it.
        self.register_buffer(
            'weight_fp8',
            torch.zeros(out_features, in_features, dtype=torch.float8_e4m3fn, device=device)
        )
        # Scale is always a scalar but store as 0-dim tensor so it travels
        # with the module (device moves, state_dict) like any buffer.
        self.register_buffer(
            'weight_scale',
            torch.tensor(1.0, dtype=torch.float32, device=device)
        )

        # Optional bias kept as a trainable Parameter in compute dtype;
        # register_parameter(None) keeps 'bias' present-but-None when absent.
        if bias:
            self.bias = nn.Parameter(
                torch.zeros(out_features, dtype=compute_dtype, device=device)
            )
        else:
            self.register_parameter('bias', None)

    def set_fp8_weight(
        self,
        weight_fp8: torch.Tensor,
        scale: torch.Tensor,
    ) -> None:
        """
        Set the FP8 weight and scale in place (shapes must already match).

        Args:
            weight_fp8: Weight tensor in FP8 format
            scale: Scale factor for dequantization (scalar)
        """
        self.weight_fp8.copy_(weight_fp8)
        # Handle scalar scale - extract the value if it's a tensor.
        # NOTE(review): a non-scalar scale is silently collapsed to its first
        # element, discarding any per-channel factors — confirm callers only
        # ever pass per-tensor scales.
        if scale.numel() == 1:
            self.weight_scale.fill_(scale.item())
        else:
            self.weight_scale.fill_(scale.flatten()[0].item())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with on-the-fly dequantization.

        Args:
            x: Input tensor

        Returns:
            Output tensor (in compute_dtype, since x is cast before F.linear)
        """
        # Dequantize weight: fp8 * scale -> compute_dtype
        weight = self.weight_fp8.to(self.compute_dtype) * self.weight_scale.to(self.compute_dtype)

        # Convert input to compute dtype if needed
        if x.dtype != self.compute_dtype:
            x = x.to(self.compute_dtype)

        # Standard linear operation
        return F.linear(x, weight, self.bias)

    @classmethod
    def from_linear(
        cls,
        linear: nn.Linear,
        device: Optional[torch.device] = None,
    ) -> 'FP8Linear':
        """
        Create FP8Linear from an existing nn.Linear layer.

        Quantizes weights to FP8 format with computed scale.

        Args:
            linear: Source linear layer
            device: Target device (defaults to the source weight's device)

        Returns:
            FP8Linear layer with quantized weights
        """
        device = device or linear.weight.device
        has_bias = linear.bias is not None
        # Compute dtype mirrors the source layer so outputs stay comparable.
        compute_dtype = linear.weight.dtype

        fp8_linear = cls(
            linear.in_features,
            linear.out_features,
            bias=has_bias,
            device=device,
            compute_dtype=compute_dtype,
        )

        # Quantize weight.
        # NOTE(review): scale="recalculate" presumably asks the layout helper
        # to derive a fresh per-tensor scale from the data — confirm against
        # TensorCoreFP8Layout.quantize's contract.
        qdata, params = TensorCoreFP8Layout.quantize(
            linear.weight.detach(),
            scale="recalculate",
        )
        fp8_linear.set_fp8_weight(qdata.to(device), params['scale'].to(device))

        # Copy bias
        if has_bias:
            fp8_linear.bias.data.copy_(linear.bias.data)

        return fp8_linear

    def extra_repr(self) -> str:
        # Shown by print(module); mirrors nn.Linear's extra_repr format.
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'
885-
def convert_linear_to_fp8(
    module: nn.Module,
    device: Optional[torch.device] = None,
    skip_patterns: Optional[list] = None,
) -> nn.Module:
    """
    Recursively convert all nn.Linear layers to FP8Linear.

    Children whose name contains any of the skip patterns are left as-is
    (the search recurses into them instead). The module is modified in
    place and also returned.

    Args:
        module: Module to convert
        device: Target device
        skip_patterns: List of name patterns to skip

    Returns:
        Module with converted layers
    """
    patterns = skip_patterns or []

    for child_name, child in list(module.named_children()):
        matches_skip = any(pat in child_name for pat in patterns)

        if isinstance(child, nn.Linear) and not matches_skip:
            # Replace this Linear with its FP8 equivalent.
            setattr(module, child_name, FP8Linear.from_linear(child, device=device))
        else:
            # Not a convertible Linear: descend into its children.
            convert_linear_to_fp8(child, device=device, skip_patterns=patterns)

    return module
0 commit comments