@@ -71,7 +71,6 @@ def compress_weight(
     zero_point: Optional[torch.Tensor] = None,
     g_idx: Optional[torch.Tensor] = None,
 ) -> Dict[str, torch.Tensor]:
-
     quantized_weight = quantize(
         x=weight,
         scale=scale,
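For context, the `quantize` call above snaps the scaled weight onto the FP4 (E2M1) value grid before packing. Below is a minimal round-to-nearest sketch of that step; `sketch_quantize` is a hypothetical helper for illustration, not the function called in the hunk, and it ignores the library's actual rounding and group handling.

```python
import torch

# Positive E2M1 magnitudes (these match the module's lookup table further
# below); the signed grid mirrors them through zero.
E2M1 = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
GRID = torch.cat([-E2M1.flip(0), E2M1])

def sketch_quantize(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    scaled = (x / scale).clamp(-6.0, 6.0)              # FP4 representable range
    idx = (scaled.unsqueeze(-1) - GRID).abs().argmin(dim=-1)
    return GRID[idx]                                   # nearest FP4 value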
@@ -91,7 +90,6 @@ def decompress_weight(
     compressed_data: Dict[str, Tensor],
     quantization_args: Optional[QuantizationArgs] = None,
 ) -> torch.Tensor:
-
     weight = compressed_data["weight_packed"]
     scale = compressed_data["weight_scale"]
     global_scale = compressed_data["weight_global_scale"]
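`decompress_weight` reads the packed codes plus two scale tensors. The sketch below shows one way the scales might recombine after unpacking, assuming NVFP4-style scaling where a per-group local scale is divided by a single tensor-wide global scale; the group size and the exact scale arithmetic are assumptions, not taken from this diff.

```python
import torch

def sketch_dequant(unpacked: torch.Tensor,       # FP4 values as floats, (m, n)
                   weight_scale: torch.Tensor,   # per-group scales, (m, n // g)
                   weight_global_scale: torch.Tensor) -> torch.Tensor:
    g = unpacked.shape[-1] // weight_scale.shape[-1]        # assumed group size
    per_elem = weight_scale.to(torch.float32).repeat_interleave(g, dim=-1)
    return unpacked.to(torch.float32) * per_elem / weight_global_scale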
@@ -105,7 +103,7 @@ def decompress_weight(
     return decompressed_weight
 
 
-@torch.compile(fullgraph=True)
+@torch.compile(fullgraph=True, dynamic=True)
 def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
     """
     Packs a tensor with values in the fp4 range into uint8.
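The substantive change in this hunk is `dynamic=True`: without it, `torch.compile` specializes the graph on the first input shapes it sees and recompiles when `pack_fp4_to_uint8` is later called on differently shaped weights; with it, the compiler traces with symbolic sizes so one graph can serve many shapes. A standalone illustration of the flag (a toy function, unrelated to the packing code):

```python
import torch

@torch.compile(fullgraph=True, dynamic=True)
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

double(torch.randn(4, 8))     # first call compiles with symbolic shapes
double(torch.randn(16, 32))   # new shape reuses the graph, no recompile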
@@ -154,8 +152,9 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
     [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
 )
 
+
 # reference: : https://github.com/vllm-project/vllm/pull/16362
-@torch.compile(fullgraph=True)
+@torch.compile(fullgraph=True, dynamic=True)
 def unpack_fp4_from_uint8(
     a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
 ) -> torch.Tensor:
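The two decorated functions implement the same byte layout from opposite ends: each uint8 carries two 4-bit E2M1 codes, with a sign bit and a 3-bit index into the magnitude table above. Below is a self-contained round-trip sketch under those assumptions; the nibble order and exact bit layout here are guesses for illustration, not necessarily what the compiled kernels use.

```python
import torch

E2M1 = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def sketch_pack(x: torch.Tensor) -> torch.Tensor:
    sign = (x < 0).to(torch.uint8) << 3                 # bit 3: sign
    idx = (x.abs().unsqueeze(-1) == E2M1).int().argmax(dim=-1)
    codes = sign | idx.to(torch.uint8)                  # one 4-bit code each
    lo, hi = codes[..., 0::2], codes[..., 1::2]         # even last dim assumed
    return (hi << 4) | lo                               # two codes per byte

def sketch_unpack(packed: torch.Tensor) -> torch.Tensor:
    lo, hi = packed & 0x0F, packed >> 4
    codes = torch.stack([lo, hi], dim=-1).flatten(-2)   # restore code order
    vals = E2M1[(codes & 0x07).long()]                  # magnitude lookup
    return torch.where((codes & 0x08) != 0, -vals, vals)

x = torch.tensor([0.0, -6.0, 1.5, 3.0])
assert torch.equal(sketch_unpack(sketch_pack(x)), x)    # lossless round trip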