Skip to content

Commit 1e5f3b2

Browse files
committed
mandate weight in layer|rms norm
Signed-off-by: Masaki Kozuki <[email protected]>
1 parent b4408d1 commit 1e5f3b2

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

thunder/executors/cutlass_dsl_ex.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def quack_layer_norm_checker(
373373
) -> bool:
374374
if (
375375
a.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32}
376-
or weight.ndim != 1
376+
or (weight is None or weight.ndim != 1)
377377
or a.shape[-1] != weight.shape[0]
378378
or weight.dtype not in {dtypes.float32}
379379
):
@@ -463,7 +463,7 @@ def quack_rms_norm_checker(
463463
eps: float | None = None,
464464
) -> bool:
465465
if (
466-
weight.ndim != 1
466+
(weight is None or weight.ndim != 1)
467467
or a.shape[-1] != weight.shape[0]
468468
or a.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32}
469469
or weight.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32}

0 commit comments

Comments
 (0)