Skip to content

Commit 6bb69c1

Browse files
committed
Add dynamic=True to torch.compile call in nvfp4 packing
Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent c2ffbac commit 6bb69c1

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ def compress_weight(
7171
zero_point: Optional[torch.Tensor] = None,
7272
g_idx: Optional[torch.Tensor] = None,
7373
) -> Dict[str, torch.Tensor]:
74-
7574
quantized_weight = quantize(
7675
x=weight,
7776
scale=scale,
@@ -91,7 +90,6 @@ def decompress_weight(
9190
compressed_data: Dict[str, Tensor],
9291
quantization_args: Optional[QuantizationArgs] = None,
9392
) -> torch.Tensor:
94-
9593
weight = compressed_data["weight_packed"]
9694
scale = compressed_data["weight_scale"]
9795
global_scale = compressed_data["weight_global_scale"]
@@ -105,7 +103,7 @@ def decompress_weight(
105103
return decompressed_weight
106104

107105

108-
@torch.compile(fullgraph=True)
106+
@torch.compile(fullgraph=True, dynamic=True)
109107
def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
110108
"""
111109
Packs a tensor with values in the fp4 range into uint8.
@@ -154,8 +152,9 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
154152
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
155153
)
156154

155+
157156
# reference: https://github.com/vllm-project/vllm/pull/16362
158-
@torch.compile(fullgraph=True)
157+
@torch.compile(fullgraph=True, dynamic=True)
159158
def unpack_fp4_from_uint8(
160159
a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
161160
) -> torch.Tensor:

0 commit comments

Comments
 (0)