Skip to content

Commit 98a1b7d

Browse files
committed
improve check
Signed-off-by: Masaki Kozuki <[email protected]>
1 parent 6830ef3 commit 98a1b7d

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

thunder/executors/cutlass_dsl_ex.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,10 @@ def is_last_dim_divisible(dtype: dtypes.dtype, last_dim_size: int) -> bool:
105105

106106
def quack_softmax_impl(a: torch.Tensor) -> torch.Tensor:
107107
original_shape = a.shape
108-
if requires_reshpae := a.ndim > 2:
108+
if requires_reshape := a.ndim > 2:
109109
a = a.view(-1, original_shape[-1])
110110
ret = softmax_fwd(a)
111-
if requires_reshpae:
111+
if requires_reshape:
112112
ret = ret.view(original_shape)
113113
return ret
114114

@@ -273,11 +273,11 @@ def quack_cross_entropy_checker(
273273
if label_smoothing != 0.0:
274274
return False
275275

276-
if (
277-
a.ndim != 2
278-
or a.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32}
276+
if not (
277+
a.ndim == 2
278+
and a.dtype in {dtypes.float16, dtypes.bfloat16, dtypes.float32}
279279
and target.ndim == 1
280-
and target.dytpe in {dtypes.int32, dtypes.int64}
280+
and target.dtype in {dtypes.int32, dtypes.int64}
281281
):
282282
return False
283283

@@ -383,7 +383,7 @@ def quack_layer_norm_checker(
383383
) -> bool:
384384
if (
385385
a.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32}
386-
or weight.ndim != 1
386+
or (weight is None or weight.ndim != 1)
387387
or a.shape[-1] != weight.shape[0]
388388
or weight.dtype not in {dtypes.float32}
389389
):
@@ -473,7 +473,7 @@ def quack_rms_norm_checker(
473473
eps: float | None = None,
474474
) -> bool:
475475
if (
476-
weight.ndim != 1
476+
(weight is None or weight.ndim != 1)
477477
or a.shape[-1] != weight.shape[0]
478478
or a.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32}
479479
or weight.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32}

0 commit comments

Comments
 (0)