Conversation

@Aminsed commented on Sep 27, 2025

Add an NCHW BatchNorm forward pass (two-pass algorithm; fp32 accumulation). API: triton_kernels.batchnorm_forward(...) returns (y, mean, var).

  • Tests: 21 cases pass against PyTorch across 2D/4D inputs and fp32/fp16/bf16 dtypes (tolerances: fp32 1e-5/1e-6; fp16/bf16 3e-2/3e-3).
  • Perf (RTX A6000): e.g., (64, 128, 32, 32) fp16 in training mode runs in 0.212 ms vs 0.290 ms for the PyTorch reference (~1.37× speedup). Benchmark script: python/triton_kernels/bench/bench_batchnorm.py.
  • Limits: NCHW layout only; no running-stat updates; no backward pass or fused variants.

Closes #900.
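
For context, a minimal usage and parity check against PyTorch might look like the sketch below. The exact argument names and defaults of `batchnorm_forward` are assumptions here, not taken from the PR; only the `(y, mean, var)` return signature comes from the description above.

```python
import torch
import triton_kernels

# Hypothetical usage sketch: argument names/order and the eps default are
# assumptions, not the PR's documented signature.
x = torch.randn(64, 128, 32, 32, device="cuda", dtype=torch.float16)
weight = torch.randn(128, device="cuda", dtype=torch.float16)
bias = torch.randn(128, device="cuda", dtype=torch.float16)

y, mean, var = triton_kernels.batchnorm_forward(x, weight, bias, eps=1e-5)

# PyTorch reference in training mode: batch statistics, no running-stat
# updates (running_mean/running_var passed as None).
ref = torch.nn.functional.batch_norm(
    x, None, None, weight=weight, bias=bias, training=True, eps=1e-5
)
torch.testing.assert_close(y, ref, rtol=3e-2, atol=3e-3)
```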

…

- Two-pass kernels: stats (sum/sumsq) + normalize
- Dtypes: fp32/fp16/bf16; training/eval; PyTorch parity
- Tests: 21 cases across shapes/dtypes/eps
- Re-export in triton_kernels.__init__

Closes: triton-lang#900
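
To make the two-pass structure concrete, here is a simplified sketch of that approach: a per-channel sum/sum-of-squares reduction kernel accumulating in fp32, followed by an elementwise normalize kernel. This is not the PR's actual code; kernel and wrapper names are hypothetical, and performance details (tiling, vectorization, autotuning) are omitted.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _bn_stats_kernel(x_ptr, sum_ptr, sumsq_ptr, N, C, HW, BLOCK: tl.constexpr):
    # Pass 1: one program per channel reduces sum and sum-of-squares over
    # all N*HW elements of that channel, accumulating in fp32.
    c = tl.program_id(0)
    acc_sum = tl.zeros((BLOCK,), dtype=tl.float32)
    acc_sq = tl.zeros((BLOCK,), dtype=tl.float32)
    num = N * HW
    for start in range(0, num, BLOCK):
        idx = start + tl.arange(0, BLOCK)
        mask = idx < num
        n = idx // HW
        hw = idx % HW
        offs = (n * C + c) * HW + hw  # contiguous NCHW offset
        x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
        acc_sum += x
        acc_sq += x * x
    tl.store(sum_ptr + c, tl.sum(acc_sum, axis=0))
    tl.store(sumsq_ptr + c, tl.sum(acc_sq, axis=0))


@triton.jit
def _bn_normalize_kernel(x_ptr, y_ptr, mean_ptr, var_ptr, w_ptr, b_ptr,
                         total, C, HW, eps, BLOCK: tl.constexpr):
    # Pass 2: elementwise normalization; the channel index is recovered
    # from the flat NCHW offset.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < total
    c = (offs // HW) % C
    x = tl.load(x_ptr + offs, mask=mask, other=0.0).to(tl.float32)
    mean = tl.load(mean_ptr + c, mask=mask, other=0.0)
    var = tl.load(var_ptr + c, mask=mask, other=0.0)
    w = tl.load(w_ptr + c, mask=mask, other=1.0).to(tl.float32)
    b = tl.load(b_ptr + c, mask=mask, other=0.0).to(tl.float32)
    y = (x - mean) / tl.sqrt(var + eps) * w + b
    tl.store(y_ptr + offs, y.to(y_ptr.dtype.element_ty), mask=mask)


def batchnorm_forward_sketch(x, weight, bias, eps=1e-5, BLOCK=1024):
    # Hypothetical host wrapper (not the PR's triton_kernels.batchnorm_forward).
    assert x.ndim == 4 and x.is_contiguous(), "expects contiguous NCHW input"
    N, C, H, W = x.shape
    HW = H * W
    sums = torch.empty(C, device=x.device, dtype=torch.float32)
    sumsq = torch.empty(C, device=x.device, dtype=torch.float32)
    _bn_stats_kernel[(C,)](x, sums, sumsq, N, C, HW, BLOCK=BLOCK)
    count = N * HW
    mean = sums / count
    var = sumsq / count - mean * mean  # biased (batch) variance, as in training
    y = torch.empty_like(x)
    total = x.numel()
    grid = (triton.cdiv(total, BLOCK),)
    _bn_normalize_kernel[grid](x, y, mean, var, weight, bias,
                               total, C, HW, eps, BLOCK=BLOCK)
    return y, mean, var
```

A production version would parallelize the per-channel reduction across multiple programs rather than looping in a single program per channel; the serial loop above is kept only for clarity.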
@Aminsed requested a review from ptillet as a code owner on September 27, 2025, 04:34
Successfully merging this pull request may close these issues.

Implement BatchNorm in triton