
Commit 245de8e

add torch.nn.functional.scaled_grouped_mm. needs testing
Signed-off-by: Masaki Kozuki <[email protected]>
1 parent f8648aa

4 files changed, +676 -0 lines changed

thunder/core/prims.py

Lines changed: 123 additions & 0 deletions
@@ -270,6 +270,7 @@ class PrimIDs(Enum):
     # Linear algebra prims (Mostly experimental)
     MATMUL = auto()
     _GROUPED_MM = auto()  # Used for grouped matmuls
+    SCALED_GROUPED_MM = auto()  # Used for scaled grouped matmuls
     # NN prims (Experimental!)
     CONVOLUTION = auto()
     EMBEDDING = auto()
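
Note: the enum value above only registers the primitive's identity; the meta function and make_prim call that bind it appear later in this diff. Conceptually, a scaled grouped matmul dequantizes each group's operands by their scales before multiplying. The following standalone PyTorch sketch is illustrative only (not this commit's implementation); it assumes per-group 1D scales and the (m, k) x (groups, k, n) layout, with offsets holding the cumulative end row of each group:

import torch

# Illustrative reference (not part of this commit): dequantize each group by
# its scale, then matmul group-by-group.
# a: (m, k) with rows partitioned by offsets; b: (groups, k, n);
# scale_a, scale_b: per-group 1D scales; offsets: cumulative end indices.
def reference_scaled_grouped_mm(a, b, scale_a, scale_b, offsets, out_dtype=torch.float32):
    out = torch.empty(a.shape[0], b.shape[2], dtype=out_dtype, device=a.device)
    start = 0
    for g, end in enumerate(offsets.tolist()):
        lhs = a[start:end].to(torch.float32) * scale_a[g]
        rhs = b[g].to(torch.float32) * scale_b[g]
        out[start:end] = (lhs @ rhs).to(out_dtype)
        start = end
    return out
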
@@ -3792,6 +3793,128 @@ def _grouped_mm_meta(a: TensorProxy, b: TensorProxy, offsets: TensorProxy) -> Te
     )


+def scaled_grouped_mm_meta(
+    a: TensorProxy,
+    b: TensorProxy,
+    scale_a: TensorProxy,
+    scale_b: TensorProxy,
+    offsets: None | TensorProxy = None,
+    bias: None | TensorProxy = None,
+    scale_result: None | TensorProxy = None,
+    out_dtype: None | dtypes.dtype = None,
+) -> TensorProxy:
+    """Meta function for the scaled_grouped_mm primitive.
+
+    Similar to _grouped_mm but with scale tensors for quantization/dequantization.
+    Accepts the following shape combinations:
+    1. (m, k) x (k, n) -> (groups, m, n)
+    2. (groups, m, k) x (k, n) -> (m, n)
+    3. (m, k) x (groups, k, n) -> (m, n)
+
+    Args:
+        a: Input tensor of shape (groups, m, k) or (m, k)
+        b: Input tensor of shape (groups, k, n) or (k, n)
+        scale_a: Scale tensor for a
+        scale_b: Scale tensor for b
+        offsets: Optional offset tensor of shape (groups,)
+        bias: Optional bias tensor
+        scale_result: Optional scale tensor for result
+        out_dtype: Optional output dtype
+
+    Returns:
+        TensorProxy with shape (groups, m, n) or (m, n)
+    """
+    # Validate types
+    utils.check_type(a, TensorProxy)
+    utils.check_type(b, TensorProxy)
+    utils.check_type(scale_a, TensorProxy)
+    utils.check_type(scale_b, TensorProxy)
+
+    # Accept 2D or 3D tensors
+    utils.check(a.ndim in (2, 3), lambda: f"Expected a to have 2 or 3 dimensions, got {a.ndim}")
+    utils.check(b.ndim in (2, 3), lambda: f"Expected b to have 2 or 3 dimensions, got {b.ndim}")
+
+    # Compute output shape using same logic as _grouped_mm
+    if offsets is not None:
+        utils.check_type(offsets, TensorProxy)
+        utils.check(offsets.ndim == 1, lambda: f"`offsets` must be a vector, got shape {offsets.shape}")
+
+        if a.ndim == 2 and b.ndim == 2:
+            utils.check(a.shape[1] == b.shape[0], lambda: f"Inner dimension mismatch: {a.shape} vs {b.shape}")
+            out_shape = (offsets.shape[0], a.shape[0], b.shape[1])
+        elif a.ndim == 3 and b.ndim == 2:
+            utils.check(a.shape[2] == b.shape[0], lambda: f"Inner dimension mismatch: {a.shape} vs {b.shape}")
+            utils.check(a.shape[0] == offsets.shape[0], lambda: f"Group count mismatch: {a.shape} vs {offsets.shape}")
+            out_shape = (a.shape[1], b.shape[1])
+        elif a.ndim == 2 and b.ndim == 3:
+            utils.check(a.shape[1] == b.shape[1], lambda: f"Inner dimension mismatch: {a.shape} vs {b.shape}")
+            utils.check(b.shape[0] == offsets.shape[0], lambda: f"Group count mismatch: {b.shape} vs {offsets.shape}")
+            out_shape = (a.shape[0], b.shape[2])
+        else:
+            utils.check(False, lambda: f"Unexpected shape combination: {a.shape} and {b.shape}")
+    else:
+        # Without offsets, fall back to standard matmul shape logic
+        if a.ndim == 2 and b.ndim == 2:
+            utils.check(a.shape[1] == b.shape[0], lambda: f"Inner dimension mismatch: {a.shape} vs {b.shape}")
+            out_shape = (a.shape[0], b.shape[1])
+        elif a.ndim == 3 and b.ndim == 2:
+            utils.check(a.shape[2] == b.shape[0], lambda: f"Inner dimension mismatch: {a.shape} vs {b.shape}")
+            out_shape = (a.shape[0], a.shape[1], b.shape[1])
+        elif a.ndim == 2 and b.ndim == 3:
+            utils.check(a.shape[1] == b.shape[1], lambda: f"Inner dimension mismatch: {a.shape} vs {b.shape}")
+            out_shape = (b.shape[0], a.shape[0], b.shape[2])
+        else:
+            utils.check(False, lambda: f"Unexpected shape combination: {a.shape} and {b.shape}")
+
+    # Validate scale tensors
+    # Scale tensors are typically 1D with shape matching the number of groups
+    # or they can be scalars
+    utils.check(
+        scale_a.ndim <= 1,
+        lambda: f"Expected scale_a to be a scalar or 1D tensor, got shape {scale_a.shape}",
+    )
+    utils.check(
+        scale_b.ndim <= 1,
+        lambda: f"Expected scale_b to be a scalar or 1D tensor, got shape {scale_b.shape}",
+    )
+
+    # Validate bias if provided
+    if bias is not None:
+        utils.check_type(bias, TensorProxy)
+        utils.check_same_device(a, bias)
+        utils.check_same_dtype(a, bias)
+
+    # Validate scale_result if provided
+    if scale_result is not None:
+        utils.check_type(scale_result, TensorProxy)
+        utils.check(
+            scale_result.ndim <= 1,
+            lambda: f"Expected scale_result to be a scalar or 1D tensor, got shape {scale_result.shape}",
+        )
+
+    utils.check_same_dtype(a, b)
+    utils.check(a.dtype in dtypes.float_math_dtypes, lambda: f"`a` must be 16-bit float or higher, got {a.dtype}")
+    if offsets is not None:
+        utils.check(utils.is_integer_dtype(offsets.dtype), lambda: f"`offsets` must be integers, got {offsets.dtype}")
+
+    utils.check_same_device(a, b, scale_a, scale_b)
+    if offsets is not None:
+        utils.check_same_device(a, offsets)
+
+    # Determine output dtype
+    result_dtype = out_dtype if out_dtype is not None else a.dtype
+
+    return TensorProxy(like=a, shape=out_shape, dtype=result_dtype)
+
+
+scaled_grouped_mm = make_prim(
+    PrimIDs.SCALED_GROUPED_MM,
+    "scaled_grouped_mm",
+    meta=scaled_grouped_mm_meta,
+    tags=(OpTags.MATMUL_OP,),
+)
+
+
 def transpose_meta(a: TensorProxy, /, permutation: tuple[int, ...]) -> TensorProxy:
     utils.check_type(a, TensorProxy)
     utils.check_type(permutation, tuple)
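
For a quick sanity check of the shape rules above without constructing TensorProxy objects, here is a minimal standalone mirror of the offsets-branch logic (a hypothetical helper; the real meta function also validates dtypes, devices, and group counts):

# Hypothetical mirror of the output-shape rules in scaled_grouped_mm_meta's
# offsets branch; shapes are plain tuples.
def scaled_grouped_mm_out_shape(a_shape, b_shape, num_groups):
    if len(a_shape) == 2 and len(b_shape) == 2:  # (m, k) x (k, n) -> (groups, m, n)
        return (num_groups, a_shape[0], b_shape[1])
    if len(a_shape) == 3 and len(b_shape) == 2:  # (groups, m, k) x (k, n) -> (m, n)
        return (a_shape[1], b_shape[1])
    if len(a_shape) == 2 and len(b_shape) == 3:  # (m, k) x (groups, k, n) -> (m, n)
        return (a_shape[0], b_shape[2])
    raise ValueError(f"Unexpected shape combination: {a_shape} and {b_shape}")

assert scaled_grouped_mm_out_shape((16, 32), (32, 8), 4) == (4, 16, 8)
assert scaled_grouped_mm_out_shape((4, 16, 32), (32, 8), 4) == (16, 8)
assert scaled_grouped_mm_out_shape((16, 32), (4, 32, 8), 4) == (16, 8)
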

thunder/executors/nvfuserex_impl.py

Lines changed: 72 additions & 0 deletions
@@ -3231,6 +3231,78 @@ def _grouped_mm_transform(
 register_supported(DTensorPrimIDs._GROUPED_MM, _grouped_mm_transform, _grouped_mm_check)


+def _scaled_grouped_mm_check(
+    a: TensorProxy,
+    b: TensorProxy,
+    scale_a: TensorProxy,
+    scale_b: TensorProxy,
+    offsets: None | TensorProxy = None,
+    bias: None | TensorProxy = None,
+    scale_result: None | TensorProxy = None,
+    out_dtype: None | dtypes.dtype = None,
+) -> bool:
+    # Version gate: scaled_grouped_mm is assumed to require at least the same nvFuser version as grouped_mm
+    if nvfuser_version() < LooseVersion("0.2.28"):
+        return False
+
+    # Check all required tensors are supported
+    if not are_supported_tensors(a, b, scale_a, scale_b):
+        return False
+
+    # Check optional tensors if provided
+    if offsets is not None and not is_supported_tensor(offsets):
+        return False
+    if bias is not None and not is_supported_tensor(bias):
+        return False
+    if scale_result is not None and not is_supported_tensor(scale_result):
+        return False
+
+    # Check that nvfp4 is supported if used
+    if a.dtype == dtypes.float4_e2m1fn_x2 or b.dtype == dtypes.float4_e2m1fn_x2:
+        # nvfp4 requires nvFuser 0.2.28+ (already checked above)
+        # Additionally check device capability for fp8/fp4 support
+        if not device_supports_fp8():
+            return False
+
+    return True
+
+
+def _scaled_grouped_mm_transform(
+    a: TensorProxy,
+    b: TensorProxy,
+    scale_a: TensorProxy,
+    scale_b: TensorProxy,
+    offsets: None | TensorProxy = None,
+    bias: None | TensorProxy = None,
+    scale_result: None | TensorProxy = None,
+    out_dtype: None | dtypes.dtype = None,
+    *,
+    fd: FusionDefinition,
+    lc_to_nv_map: dict,
+) -> Any:
+    nva = getnv(a, fd, lc_to_nv_map)
+    nvb = getnv(b, fd, lc_to_nv_map)
+    nv_scale_a = getnv(scale_a, fd, lc_to_nv_map)
+    nv_scale_b = getnv(scale_b, fd, lc_to_nv_map)
+    nv_offsets = getnv(offsets, fd, lc_to_nv_map) if offsets is not None else None
+    nv_bias = getnv(bias, fd, lc_to_nv_map) if bias is not None else None
+    nv_scale_result = getnv(scale_result, fd, lc_to_nv_map) if scale_result is not None else None
+
+    # Translate out_dtype to nvFuser dtype if provided
+    nv_out_dtype = None
+    if out_dtype is not None:
+        nv_out_dtype = lcdtype_to_nvdtype(out_dtype)
+
+    # Call nvFuser's scaled_grouped_mm operation
+    # The API signature may vary, but typically includes all parameters
+    return fd.ops.scaled_grouped_mm(
+        nva, nvb, nv_scale_a, nv_scale_b, nv_offsets, nv_bias, nv_scale_result, nv_out_dtype
+    )
+
+
+register_supported(prims.scaled_grouped_mm, _scaled_grouped_mm_transform, _scaled_grouped_mm_check)
+
+
 def _cumsum_check(a: TensorProxy, dim: int, /, dtype: dtypes.dtype | None = None) -> bool:
     if nvfuser_version() < LooseVersion("0.2.33") and a.ndim != 1:
         return False
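
Since the commit message flags "needs testing", a smoke test might look roughly like the following. This is only a sketch under stated assumptions: that prims symbols are callable under thunder.jit as in other thunder tests, that nvFuser >= 0.2.28 and an fp8-capable CUDA device are available, and that bfloat16 inputs satisfy the meta checks.

import torch
import thunder
import thunder.core.prims as prims

# Assumed callable under thunder.jit; shapes follow case 3:
# (m, k) x (groups, k, n) -> (m, n), with offsets as cumulative end rows.
def fn(a, b, scale_a, scale_b, offsets):
    return prims.scaled_grouped_mm(a, b, scale_a, scale_b, offsets)

jfn = thunder.jit(fn)

groups, m, k, n = 4, 16, 32, 8
a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
b = torch.randn(groups, k, n, device="cuda", dtype=torch.bfloat16)
scale_a = torch.ones(groups, device="cuda")
scale_b = torch.ones(groups, device="cuda")
offsets = torch.tensor([4, 8, 12, 16], device="cuda", dtype=torch.int32)

out = jfn(a, b, scale_a, scale_b, offsets)
assert out.shape == (m, n)
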

0 commit comments