2424from compressed_tensors .utils import get_offloaded_device
2525from compressed_tensors .utils .helpers import ParameterizedDefaultDict
2626from torch import Tensor , device , dtype
27- from torch .nn import Linear , Module , Parameter
27+ from torch .nn import Module , Parameter
2828
2929
3030@TransformFactory .register ("random-matrix" )
@@ -52,14 +52,14 @@ def create_transform(self, module: Module, args: TransformArgs):
5252 """
5353 assert hasattr (module , "weight" )
5454 size = get_transform_size (module , args .location , self .scheme .head_dim )
55- dtype = module . weight . dtype
55+ dtype = self . scheme . precision
5656 device = get_offloaded_device (module )
5757
5858 weight = self .weights [size , dtype , device ]
5959 if args .inverse :
6060 weight = self .inverses [weight ]
6161
62- return RandomMatrixTransform (weight , args , type (module ))
62+ return RandomMatrixTransform (weight , self . scheme , args , type (module ))
6363
6464 def _create_weight (self , size : int , dtype : dtype , device : device ) -> Parameter :
6565 # TODO: verify that weight is invertible (has non-zero determinant)
@@ -78,25 +78,34 @@ class RandomMatrixTransform(TransformBase):
7878 def __init__ (
7979 self ,
8080 weight : Tensor ,
81+ scheme : TransformScheme ,
8182 args : TransformArgs ,
8283 module_type : type [torch .nn .Module ],
8384 ):
8485 super ().__init__ ()
8586 self .weight = weight # is an inverse if args.inverse
87+ self .scheme = scheme
8688 self .args = args
8789 self .module_type = module_type
90+ self ._precision = scheme .precision if args .is_online () else torch .float64
8891
8992 def forward (self , value : Tensor ) -> Parameter :
9093 return apply_transform_weight (
91- self .weight , value , self .args .location , self .module_type
92- )
94+ self .weight .to (self ._precision ),
95+ value .to (self ._precision ),
96+ self .args .location ,
97+ self .module_type ,
98+ ).to (value .dtype )
9399
94100 def right_inverse (self , value : Tensor ) -> Tensor :
95101 inverse = high_precision_invert (self .weight )
96102 return apply_transform_weight (
97- inverse , value , self .args .location , self .module_type
98- )
103+ inverse .to (self ._precision ),
104+ value .to (self ._precision ),
105+ self .args .location ,
106+ self .module_type ,
107+ ).to (value .dtype )
99108
100109
101110def high_precision_invert (weight : Tensor ) -> Tensor :
102- return torch .linalg .inv (weight .to (torch .float32 )).to (weight .dtype )
111+ return torch .linalg .inv (weight .to (torch .float64 )).to (weight .dtype )
0 commit comments