Skip to content

Commit ba85784

Browse files
apply transform weights in float64 precision
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent b082929 commit ba85784

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,11 @@ def forward(self, value: Tensor) -> Tensor:
105105
weight = weight.T
106106

107107
return (
108-
apply_transform_weight(weight, value, self.args.location, self.module_type)
108+
apply_transform_weight(
109+
weight.to(torch.float64),
110+
value.to(torch.float64),
111+
self.args.location,
112+
self.module_type,
113+
)
109114
/ self._scale
110-
)
115+
).to(weight.dtype)

0 commit comments

Comments
 (0)