@@ -270,6 +270,7 @@ class PrimIDs(Enum):
270270 # Linear algebra prims (Mostly experimental)
271271 MATMUL = auto ()
272272 _GROUPED_MM = auto () # Used for grouped matmuls
273+ SCALED_GROUPED_MM = auto () # Used for scaled grouped matmuls
273274 # NN prims (Experimental!)
274275 CONVOLUTION = auto ()
275276 EMBEDDING = auto ()
@@ -3792,6 +3793,128 @@ def _grouped_mm_meta(a: TensorProxy, b: TensorProxy, offsets: TensorProxy) -> Te
37923793)
37933794
37943795
3796+ def scaled_grouped_mm_meta (
3797+ a : TensorProxy ,
3798+ b : TensorProxy ,
3799+ scale_a : TensorProxy ,
3800+ scale_b : TensorProxy ,
3801+ offsets : None | TensorProxy = None ,
3802+ bias : None | TensorProxy = None ,
3803+ scale_result : None | TensorProxy = None ,
3804+ out_dtype : None | dtypes .dtype = None ,
3805+ ) -> TensorProxy :
3806+ """Meta function for scaled_grouped_mm primitive.
3807+
3808+ Similar to _grouped_mm but with scale tensors for quantization/dequantization.
3809+ Accepts the following shape combinations:
3810+ 1. (m, k) x (k, n) -> (groups, m, n)
3811+ 2. (groups, m, k) x (k, n) -> (m, n)
3812+ 3. (m, k) x (groups, k, n) -> (m, n)
3813+
3814+ Args:
3815+ a: Input tensor of shape (groups, m, k) or (m, k)
3816+ b: Input tensor of shape (groups, k, n) or (k, n)
3817+ scale_a: Scale tensor for a
3818+ scale_b: Scale tensor for b
3819+ offsets: Optional offset tensor of shape (groups,)
3820+ bias: Optional bias tensor
3821+ scale_result: Optional scale tensor for result
3822+ out_dtype: Optional output dtype
3823+
3824+ Returns:
3825+ TensorProxy with shape (groups, m, n) or (m, n)
3826+ """
3827+ # Validate types
3828+ utils .check_type (a , TensorProxy )
3829+ utils .check_type (b , TensorProxy )
3830+ utils .check_type (scale_a , TensorProxy )
3831+ utils .check_type (scale_b , TensorProxy )
3832+
3833+ # Accept 2D or 3D tensors
3834+ utils .check (a .ndim in (2 , 3 ), lambda : f"Expected a to have 2 or 3 dimensions, got { a .ndim } " )
3835+ utils .check (b .ndim in (2 , 3 ), lambda : f"Expected b to have 2 or 3 dimensions, got { b .ndim } " )
3836+
3837+ # Compute output shape using same logic as _grouped_mm
3838+ if offsets is not None :
3839+ utils .check_type (offsets , TensorProxy )
3840+ utils .check (offsets .ndim == 1 , lambda : f"`offsets` must be a vector, got shape { offsets .shape } " )
3841+
3842+ if a .ndim == 2 and b .ndim == 2 :
3843+ utils .check (a .shape [1 ] == b .shape [0 ], lambda : f"Inner dimension mismatch: { a .shape } vs { b .shape } " )
3844+ out_shape = (offsets .shape [0 ], a .shape [0 ], b .shape [1 ])
3845+ elif a .ndim == 3 and b .ndim == 2 :
3846+ utils .check (a .shape [2 ] == b .shape [1 ], lambda : f"Inner dimension mismatch: { a .shape } vs { b .shape } " )
3847+ utils .check (a .shape [0 ] == offsets .shape [0 ], lambda : f"Group count mismatch: { a .shape } vs { offsets .shape } " )
3848+ out_shape = (a .shape [1 ], b .shape [1 ])
3849+ elif a .ndim == 2 and b .ndim == 3 :
3850+ utils .check (a .shape [1 ] == b .shape [1 ], lambda : f"Inner dimension mismatch: { a .shape } vs { b .shape } " )
3851+ utils .check (b .shape [0 ] == offsets .shape [0 ], lambda : f"Group count mismatch: { b .shape } vs { offsets .shape } " )
3852+ out_shape = (a .shape [0 ], b .shape [2 ])
3853+ else :
3854+ utils .check (False , lambda : f"Unexpected shape combination: { a .shape } and { b .shape } " )
3855+ else :
3856+ # Without offsets, fall back to standard matmul shape logic
3857+ if a .ndim == 2 and b .ndim == 2 :
3858+ utils .check (a .shape [1 ] == b .shape [0 ], lambda : f"Inner dimension mismatch: { a .shape } vs { b .shape } " )
3859+ out_shape = (a .shape [0 ], b .shape [1 ])
3860+ elif a .ndim == 3 and b .ndim == 2 :
3861+ utils .check (a .shape [2 ] == b .shape [1 ], lambda : f"Inner dimension mismatch: { a .shape } vs { b .shape } " )
3862+ out_shape = (a .shape [0 ], a .shape [1 ], b .shape [1 ])
3863+ elif a .ndim == 2 and b .ndim == 3 :
3864+ utils .check (a .shape [1 ] == b .shape [1 ], lambda : f"Inner dimension mismatch: { a .shape } vs { b .shape } " )
3865+ out_shape = (b .shape [0 ], a .shape [0 ], b .shape [2 ])
3866+ else :
3867+ utils .check (False , lambda : f"Unexpected shape combination: { a .shape } and { b .shape } " )
3868+
3869+ # Validate scale tensors
3870+ # Scale tensors are typically 1D with shape matching the number of groups
3871+ # or they can be scalars
3872+ utils .check (
3873+ scale_a .ndim <= 1 ,
3874+ lambda : f"Expected scale_a to be a scalar or 1D tensor, got shape { scale_a .shape } " ,
3875+ )
3876+ utils .check (
3877+ scale_b .ndim <= 1 ,
3878+ lambda : f"Expected scale_b to be a scalar or 1D tensor, got shape { scale_b .shape } " ,
3879+ )
3880+
3881+ # Validate bias if provided
3882+ if bias is not None :
3883+ utils .check_type (bias , TensorProxy )
3884+ utils .check_same_device (a , bias )
3885+ utils .check_same_dtype (a , bias )
3886+
3887+ # Validate scale_result if provided
3888+ if scale_result is not None :
3889+ utils .check_type (scale_result , TensorProxy )
3890+ utils .check (
3891+ scale_result .ndim <= 1 ,
3892+ lambda : f"Expected scale_result to be a scalar or 1D tensor, got shape { scale_result .shape } " ,
3893+ )
3894+
3895+ utils .check_same_dtype (a , b )
3896+ utils .check (a .dtype in dtypes .float_math_dtypes , lambda : f"`a` must be 16-bit float or higher, got { a .dtype } " )
3897+ if offsets is not None :
3898+ utils .check (utils .is_integer_dtype (offsets .dtype ), lambda : f"`offsets` must be integers, got { offsets .dtype } " )
3899+
3900+ utils .check_same_device (a , b , scale_a , scale_b )
3901+ if offsets is not None :
3902+ utils .check_same_device (a , offsets )
3903+
3904+ # Determine output dtype
3905+ result_dtype = out_dtype if out_dtype is not None else a .dtype
3906+
3907+ return TensorProxy (like = a , shape = out_shape , dtype = result_dtype )
3908+
3909+
3910+ scaled_grouped_mm = make_prim (
3911+ PrimIDs .SCALED_GROUPED_MM ,
3912+ "scaled_grouped_mm" ,
3913+ meta = scaled_grouped_mm_meta ,
3914+ tags = (OpTags .MATMUL_OP ,),
3915+ )
3916+
3917+
37953918def transpose_meta (a : TensorProxy , / , permutation : tuple [int , ...]) -> TensorProxy :
37963919 utils .check_type (a , TensorProxy )
37973920 utils .check_type (permutation , tuple )
0 commit comments