@@ -105,10 +105,10 @@ def is_last_dim_divisible(dtype: dtypes.dtype, last_dim_size: int) -> bool:
 
 def quack_softmax_impl(a: torch.Tensor) -> torch.Tensor:
     original_shape = a.shape
-    if requires_reshpae := a.ndim > 2:
+    if requires_reshape := a.ndim > 2:
         a = a.view(-1, original_shape[-1])
     ret = softmax_fwd(a)
-    if requires_reshpae:
+    if requires_reshape:
         ret = ret.view(original_shape)
     return ret
 
@@ -273,11 +273,11 @@ def quack_cross_entropy_checker(
     if label_smoothing != 0.0:
         return False
 
-    if (
-        a.ndim != 2
-        or a.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32}
+    if not (
+        a.ndim == 2
+        and a.dtype in {dtypes.float16, dtypes.bfloat16, dtypes.float32}
         and target.ndim == 1
-        and target.dytpe in {dtypes.int32, dtypes.int64}
+        and target.dtype in {dtypes.int32, dtypes.int64}
     ):
         return False
 
@@ -383,7 +383,7 @@ def quack_layer_norm_checker(
 ) -> bool:
     if (
         a.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32}
-        or weight.ndim != 1
+        or (weight is None or weight.ndim != 1)
         or a.shape[-1] != weight.shape[0]
         or weight.dtype not in {dtypes.float32}
     ):
@@ -473,7 +473,7 @@ def quack_rms_norm_checker(
     eps: float | None = None,
 ) -> bool:
     if (
-        weight.ndim != 1
+        (weight is None or weight.ndim != 1)
         or a.shape[-1] != weight.shape[0]
         or a.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32}
        or weight.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32}
0 commit comments