@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import math
-from typing import Optional, Union
+from typing import Optional
 
 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
@@ -26,7 +25,7 @@
 from compressed_tensors.utils import get_execution_device, get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
-from torch.nn import Linear, Module, Parameter
+from torch.nn import Module, Parameter
 
 
 @TransformFactory.register("hadamard")
@@ -54,14 +53,14 @@ def create_transform(self, module: Module, args: TransformArgs):
         """
         assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
-        dtype = module.weight.dtype
+        dtype = self.scheme.precision
         device = get_offloaded_device(module)
         exec_device = get_execution_device(module)
 
         factory_kwargs = {"construct_device": exec_device}
         weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
         perm = self.perms[weight] if self.scheme.randomize else None
-        return HadamardTransform(weight, perm, args, type(module))
+        return HadamardTransform(weight, perm, self.scheme, args, type(module))
 
     def _create_weight(
         self,
@@ -85,15 +84,18 @@ def __init__(
         self,
         weight: Parameter,
         perm: Optional[Parameter],
+        scheme: TransformScheme,
         args: TransformArgs,
         module_type: type[torch.nn.Module],
     ):
         super().__init__()
         self.weight = weight
         self.perm = perm
+        self.scheme = scheme
         self.args = args
         self.module_type = module_type
-        self._scale = math.sqrt(weight.size(0))
+        self._scale = torch.tensor(weight.size(0), dtype=self.scheme.precision).sqrt()
+        self._precision = scheme.precision if args.is_online() else torch.float64
 
     def forward(self, value: Tensor) -> Tensor:
         weight = self.weight
@@ -105,6 +107,11 @@ def forward(self, value: Tensor) -> Tensor:
             weight = weight.T
 
         return (
-            apply_transform_weight(weight, value, self.args.location, self.module_type)
+            apply_transform_weight(
+                weight.to(self._precision),
+                value.to(self._precision),
+                self.args.location,
+                self.module_type,
+            )
             / self._scale
-        )
+        ).to(value.dtype)
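
The diff threads `TransformScheme.precision` through the factory: the Hadamard weight is now constructed at the scheme's precision rather than the module's weight dtype, online transforms apply at that precision while offline ones fall back to `torch.float64`, and `forward` casts its result back to the input dtype. Below is a minimal standalone sketch of that upcast/apply/downcast pattern; `sylvester_hadamard` and `apply_hadamard` are illustrative helpers, not part of the compressed-tensors API.

```python
import torch


def sylvester_hadamard(n: int) -> torch.Tensor:
    """Build an n x n Hadamard matrix via Sylvester's construction (n a power of two)."""
    h = torch.ones(1, 1)
    while h.size(0) < n:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h


def apply_hadamard(value: torch.Tensor, precision: torch.dtype = torch.float64) -> torch.Tensor:
    # Upcast both operands to the working precision, rotate, normalize by
    # sqrt(n), then cast back to the caller's dtype, mirroring forward() above.
    weight = sylvester_hadamard(value.size(-1))
    scale = torch.tensor(value.size(-1), dtype=precision).sqrt()
    out = value.to(precision) @ weight.to(precision) / scale
    return out.to(value.dtype)


x = torch.randn(2, 8, dtype=torch.bfloat16)
y = apply_hadamard(x)
assert y.dtype == x.dtype
# H is symmetric and H @ H == n * I, so applying the normalized transform
# twice recovers the input up to bfloat16 rounding.
assert torch.allclose(apply_hadamard(y).float(), x.float(), atol=0.1)
```

Doing the matmul in a wider precision and downcasting only the final result keeps the rotation numerically orthonormal, which matters when the transform is fused offline into quantized weights.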