Conversation
@crcrpar crcrpar commented Nov 11, 2025

What does this PR do?

As per the title, this PR enables NVFP4 in benchmark_inference.py using NVFuser's NVFP4 kernels.


On GB200, pjnl-20251113:

nvfp4

$ NVFUSER_ENABLE="id_model(all)" python thunder/benchmarks/benchmark_inference.py --output-length 2 --enable-nvfp4 --mode thunder
...
============================================================
BENCHMARK RESULTS - meta-llama/Llama-4-Maverick-17B-128E thunder
============================================================

Throughput Metrics:
  Overall Throughput: 114.00 tokens/sec
  Prefill Throughput: 211098.56 tokens/sec
  Decode Throughput: 128.33 tokens/sec
  Latency: 10.03 ms/token

Latency Breakdown:
  Time to First Token (TTFT): 12.23 ms
  Time Between Output Tokens (TBOT): 7.82 ms
  Prefill Time: 12.23 ms
  Decode Time: 7.82 ms
  Total Generation Time: 20.05 ms

Memory Usage:
  Current Memory: 14.23 GB
  Peak Memory: 15.19 GB

Variance Analysis:
  Throughput Std Dev: 26.42 ms
  TTFT Std Dev: 26.08 ms

bf16

$ python thunder/benchmarks/benchmark_inference.py --output-length 2 --mode thunder
...
============================================================
BENCHMARK RESULTS - meta-llama/Llama-4-Maverick-17B-128E thunder
============================================================

Throughput Metrics:
  Overall Throughput: 120.80 tokens/sec
  Prefill Throughput: 220544.66 tokens/sec
  Decode Throughput: 137.55 tokens/sec
  Latency: 8.28 ms/token

Latency Breakdown:
  Time to First Token (TTFT): 9.29 ms
  Time Between Output Tokens (TBOT): 7.27 ms
  Prefill Time: 9.29 ms
  Decode Time: 7.27 ms
  Total Generation Time: 16.56 ms

Memory Usage:
  Current Memory: 37.39 GB
  Peak Memory: 38.34 GB

Variance Analysis:
  Throughput Std Dev: 0.19 ms
  TTFT Std Dev: 0.12 ms

cc: @IvanYashchuk

@crcrpar crcrpar requested a review from jjsjann123 November 11, 2025 13:15
@crcrpar crcrpar changed the base branch from crpa/try-nvfuer5230 to main November 14, 2025 07:40
@crcrpar crcrpar force-pushed the still-nvfp4-run-failing branch from 2652cb1 to 0cff702 November 14, 2025 07:41
@crcrpar crcrpar changed the title from "[nvfp4 benchmark_inference] Let TorchDynamo work w/o errors" to "[benchmark_inference] Enable NVFP4 with NVFuser's NVFP4 kernels" Nov 14, 2025
Comment on lines +757 to +744
parser.add_argument(
    "--enable-nvfp4",
    action="store_true",
    help="Enable NVFP4 quantization for MoE GroupedSwiGLU layers (has nvfuser grouped_mm support)",
)
@crcrpar (author):

This seems to require NVFUSER_ENABLE="id_model(all)" at the moment. We might want to set the env var when this option is set.
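
A minimal sketch of how that could look, assuming the parsed namespace is available as args with an enable_nvfp4 attribute and that NVFuser still reads NVFUSER_ENABLE after this point in the program (both are assumptions, not something this PR does):

import os

def maybe_force_id_model(args) -> None:
    # Hypothetical helper: set id_model(all) automatically when --enable-nvfp4 is
    # passed so users don't have to remember the env var. Appending to an existing
    # NVFUSER_ENABLE value is an assumption about how nvfuser combines options.
    if not getattr(args, "enable_nvfp4", False):
        return
    existing = os.environ.get("NVFUSER_ENABLE", "")
    if "id_model(all)" not in existing:
        os.environ["NVFUSER_ENABLE"] = ",".join(filter(None, [existing, "id_model(all)"]))

This would only help if it runs before NVFuser picks the variable up, which is worth verifying before wiring it into the benchmark.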

Collaborator:

linking nvfuser issue: NVIDIA/Fuser#5200

@jjsjann123 jjsjann123 left a comment:

Since we are going to merge this with the nvfuser benchmark, let's merge it as-is and follow up with cleanup in the written-out model.

parser.add_argument(
    "--quantize-linear",
    action="store_true",
    help="[Experimental] Quantize nn.Linear to NVFP4. Note: nvfuser has not yet implemented nvfp4_matmul translator",
)
Collaborator:

I'm getting a hang with --quantize-linear

@crcrpar (author):

let's remove it

@crcrpar (author):

removed

dtype=activation.dtype,
)
for i in range(fp4_weight.size(0)):
    # NOTE: dequantize here doesn't look right, since we have (g, k, n)
Collaborator:

Note that this is not used since we have registered a translation rule for this op in nvfuser, so I don't think we have to bother fixing it for now.

@jjsjann123:

tagging @tbqh

crcrpar and others added 12 commits November 17, 2025 01:34
Signed-off-by: Masaki Kozuki <[email protected]>
…ns in inference benchmark. Enhance `_quantize_llama4` to conditionally quantize linear layers. Update command-line arguments for NVFP4 registration and quantization control. Adjust custom operations to ensure correct tensor shapes and handling.
… Update `_quantize_llama4` to simplify linear layer quantization handling. Modify command-line arguments for NVFP4 to clarify usage and remove deprecated options. Add warnings for experimental features and ensure proper registration of custom ops.
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
Signed-off-by: Masaki Kozuki <[email protected]>
@crcrpar crcrpar force-pushed the still-nvfp4-run-failing branch from 26a9de1 to 58482e8 November 17, 2025 09:34
# This handles both 2D (tokens, hidden) and 3D (batch, seq_len, hidden) inputs
out_features = fp4_weight.size(2)
output_shape = activation.shape[:-1] + (out_features,)
return torch.empty(output_shape, device=activation.device, dtype=torch.bfloat16)
Collaborator:

Should this function also verify that weight, activation and other relevant tensors are on the same device?
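
A minimal sketch of such a check, written as a standalone helper; calling it from this function with activation and fp4_weight (plus any other relevant tensors) is an assumption about where it would fit:

import torch

def _check_same_device(**tensors: torch.Tensor) -> None:
    # Verify that all named tensors live on the same device and report the
    # offenders by name if they do not.
    devices = {name: t.device for name, t in tensors.items()}
    if len(set(devices.values())) > 1:
        raise ValueError(f"expected all inputs on the same device, got {devices}")

# Hypothetical call site inside the function above:
# _check_same_device(activation=activation, fp4_weight=fp4_weight)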


new_moe.routed_experts.gate_proj.weight.data.copy_(gate_proj_w.transpose(-1, -2))
new_moe.routed_experts.up_proj.weight.data.copy_(up_proj_w.transpose(-1, -2))
new_moe.routed_experts.gate_proj.weight.data.copy_(gate_proj_w)
Collaborator:

I think this would revert the changes from #2659, leading to a perf regression for the BF16 grouped_mm path.

Collaborator:

Thanks a ton for pointing that out!

It's probably a good idea to have a better separation between the bf16 and fp4 code paths, but I could at least put this inside a conditional guarded by dtype.
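
A minimal sketch of that guard; the author suggests keying it off the dtype, so the boolean use_nvfp4 below is only a stand-in, and the destination weights are assumed to be allocated with the matching layout in each branch (the copy targets come from the diff above):

if use_nvfp4:
    # NVFP4 path: keep the untransposed copy introduced in this PR.
    new_moe.routed_experts.gate_proj.weight.data.copy_(gate_proj_w)
    new_moe.routed_experts.up_proj.weight.data.copy_(up_proj_w)
else:
    # BF16 path: keep the transposed copy from #2659 to avoid the grouped_mm perf regression.
    new_moe.routed_experts.gate_proj.weight.data.copy_(gate_proj_w.transpose(-1, -2))
    new_moe.routed_experts.up_proj.weight.data.copy_(up_proj_w.transpose(-1, -2))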

scale_factors[i] = linear_to_swizzled_128_4(cur_scale_factors)

return fp4_weight, scale_factors, global_scales, ab_strides, c_strides
return fp4_weight.transpose(-1, -2), scale_factors, global_scales
Collaborator:

Is it OK to transpose just fp4_weight but not scale_factors, since the scale_factors were calculated before the transpose (or maybe the downstream code accounts for this)?

Collaborator:

Yes. The reason is that accesses through those pointers are computed by the kernel by hand, so the stride here is really just used for validation.

The requirement is that both the weight and the scale factor have the k dimension as the fastest-varying one, which is what the quantization function produces.
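
For illustration, the metadata-only nature of transpose can be checked directly; the shapes below are hypothetical and only show that the bytes in memory (and therefore the fastest-varying dimension) are untouched:

import torch

w = torch.arange(2 * 3 * 4, dtype=torch.uint8).reshape(2, 3, 4)  # last dim has stride 1
wt = w.transpose(-1, -2)  # logically (2, 4, 3); no data movement

print(w.stride(), wt.stride())        # (12, 4, 1) vs. (12, 1, 4)
print(w.data_ptr() == wt.data_ptr())  # True: same storage, only strides differ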
