
Commit 6830ef3

bump quack to 0.2.2
Signed-off-by: Masaki Kozuki <[email protected]>
1 parent: 43389de

2 files changed: 27 additions & 17 deletions


thunder/executors/cutlass_dsl_ex.py

Lines changed: 23 additions & 13 deletions
@@ -88,16 +88,26 @@ def is_last_dim_divisible(dtype: dtypes.dtype, last_dim_size: int) -> bool:
     return last_dim_size % (128 // 8 // dtype.bytes) == 0
 
 
+quack_version: LooseVersion
+try:
+    import quack
+except ImportError:
+    quack_version = LooseVersion("0.0.0")
+else:
+    quack_version = LooseVersion(quack.__version__)
+
 # Register [`quack`](https://github.com/Dao-AILab/quack) ops
-if find_spec("quack") is not None:
+if quack_version >= LooseVersion("0.2.2"):
+    import quack
+
     # softmax
-    from quack.softmax import _softmax_fwd, _softmax_backward
+    from quack.softmax import softmax_fwd, softmax_bwd
 
     def quack_softmax_impl(a: torch.Tensor) -> torch.Tensor:
         original_shape = a.shape
         if requires_reshpae := a.ndim > 2:
             a = a.view(-1, original_shape[-1])
-        ret = _softmax_fwd(a)
+        ret = softmax_fwd(a)
         if requires_reshpae:
             ret = ret.view(original_shape)
         return ret
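
Note: the version gate above degrades gracefully. When quack is not importable, quack_version is pinned to the sentinel "0.0.0", so the >= 0.2.2 check fails and no quack ops are registered. A standalone illustration of that behavior, assuming LooseVersion comes from the looseversion package (its import sits outside this hunk); quack_is_new_enough is a hypothetical helper used only for this sketch:

from looseversion import LooseVersion

def quack_is_new_enough(installed: str | None) -> bool:
    # None models the ImportError branch: the "0.0.0" sentinel fails every version check.
    version = LooseVersion(installed) if installed is not None else LooseVersion("0.0.0")
    return version >= LooseVersion("0.2.2")

assert not quack_is_new_enough(None)     # quack missing: ops are not registered
assert not quack_is_new_enough("0.1.0")  # too old: ops are not registered
assert quack_is_new_enough("0.2.2")      # new enough: quack ops get registered
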
@@ -116,7 +126,7 @@ def quack_softmax_backward(g: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
         if requires_reshape := g.ndim > 2:
             g = g.view(-1, original_shape[-1])
             a = a.view(-1, original_shape[-1])
-        ret = _softmax_backward(g, a)
+        ret = softmax_bwd(g, a)
         if requires_reshape:
             ret = ret.view(original_shape)
         return ret
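
Note: for the softmax path the change is a pure rename, _softmax_fwd/_softmax_backward become softmax_fwd/softmax_bwd with the same argument order, and the wrappers keep their existing flatten-to-2D convention. A minimal forward-only sketch, assuming quack >= 0.2.2 and a CUDA device; the shapes are illustrative:

import torch
from quack.softmax import softmax_fwd

a = torch.randn(2, 8, 4096, device="cuda", dtype=torch.bfloat16)
a2d = a.view(-1, a.shape[-1])           # mirror quack_softmax_impl: flatten to 2D first
out = softmax_fwd(a2d).view(a.shape)    # run the kernel, then restore the original shape
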
@@ -196,13 +206,13 @@ def quack_softmax_transform(
     )
 
     # crossentropy
-    from quack.cross_entropy import _cross_entropy, _cross_entropy_backward
+    from quack.cross_entropy import cross_entropy_fwd, cross_entropy_bwd
 
     def quack_cross_entropy_forward_impl(
         x: torch.Tensor,
         target: torch.Tensor,
     ) -> torch.Tensor:
-        return _cross_entropy(x, target, return_lse=False)
+        return cross_entropy_fwd(x, target, return_lse=False)
 
     def quack_cross_entropy_forward_meta(x: TensorProxy, target: TensorProxy) -> TensorProxy:
         return TensorProxy(like=x, shape=(x.shape[0],), dtype=dtypes.float32)
@@ -219,7 +229,7 @@ def quack_cross_entropy_backward_impl(
         grad: torch.Tensor,
         lse: torch.Tensor,
     ) -> torch.Tensor:
-        return _cross_entropy_backward(x, target, grad, lse, False)
+        return cross_entropy_bwd(x, target, grad, lse, False)
 
     def quack_cross_entropy_backward_meta(
         x: TensorProxy,
@@ -290,7 +300,7 @@ def quack_cross_entropy_aug_forward_impl(
         x: torch.Tensor,
         target: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        return _cross_entropy(x, target, return_lse=True)
+        return cross_entropy_fwd(x, target, return_lse=True)
 
     def quack_cross_entropy_aug_forward_meta(a: TensorProxy, target: TensorProxy) -> tuple[TensorProxy, TensorProxy]:
         return (
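
Note: the cross-entropy wrappers follow the same rename, _cross_entropy/_cross_entropy_backward become cross_entropy_fwd/cross_entropy_bwd. A hedged sketch of the two forward variants and the backward call as used above; the sizes are illustrative, and the trailing positional False is carried over verbatim from the wrapper (its meaning is not spelled out in this diff):

import torch
from quack.cross_entropy import cross_entropy_fwd, cross_entropy_bwd

x = torch.randn(32, 50_304, device="cuda", dtype=torch.bfloat16)   # (batch, vocab)
target = torch.randint(0, 50_304, (32,), device="cuda")

loss = cross_entropy_fwd(x, target, return_lse=False)        # per-sample float32 losses, shape (batch,)
loss, lse = cross_entropy_fwd(x, target, return_lse=True)    # augmented forward also returns the log-sum-exp
grad = torch.ones_like(loss)
dx = cross_entropy_bwd(x, target, grad, lse, False)          # backward, flag passed as in the wrapper
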
@@ -378,7 +388,7 @@ def quack_layer_norm_checker(
             or weight.dtype not in {dtypes.float32}
         ):
             return False
-        return is_device_quack_compat()
+        return is_device_quack_compat() and is_last_dim_divisible(a.dtype, a.shape[-1])
 
     def quack_layer_norm_transform(
         a: TensorProxy,
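
Note: the extra is_last_dim_divisible guard reuses the helper at the top of this diff: the last dimension must hold a whole number of 128-bit (16-byte) vectors for the given dtype. A quick worked check:

# 128 // 8 = 16 bytes per vector; divide by the element size for the required multiple.
# float32 (4 bytes): 128 // 8 // 4 = 4  -> last dim must be a multiple of 4
# bfloat16 (2 bytes): 128 // 8 // 2 = 8 -> last dim must be a multiple of 8
assert 4096 % (128 // 8 // 2) == 0   # 4096 passes for bfloat16
assert 4100 % (128 // 8 // 2) != 0   # 4100 does not
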
@@ -397,7 +407,7 @@ def quack_layer_norm_transform(
     )
 
     # rmsnorm
-    from quack.rmsnorm import _rmsnorm_fwd, _rmsnorm_backward
+    from quack.rmsnorm import rmsnorm_fwd, rmsnorm_bwd
 
     def quack_rms_norm_forward_impl(
         x: torch.Tensor,
@@ -407,7 +417,7 @@ def quack_rms_norm_forward_impl(
         original_shape = x.shape
         if requires_reshape := x.ndim > 2:
             x = x.view(-1, original_shape[-1])
-        ret = _rmsnorm_fwd(x, weight, eps, return_rstd=False)
+        ret = rmsnorm_fwd(x, weight, eps=eps, store_rstd=False)[0]
         if requires_reshape:
             ret = ret.view(original_shape)
         return ret
@@ -435,7 +445,7 @@ def quack_rms_norm_backward_impl(
         if requires_reshape := grad.ndim > 2:
             grad = grad.view(-1, original_shape[-1])
             x = x.view(-1, original_shape[-1])
-        ret = _rmsnorm_backward(x, weight, grad, rstd)
+        ret = rmsnorm_bwd(x, weight, grad, rstd)
         if requires_reshape:
             ret = ret.view(original_shape)
         return ret
@@ -479,7 +489,7 @@ def quack_rms_norm_aug_forward_impl(
         original_shape = x.shape
         if requires_reshape := x.ndim > 2:
             x = x.view(-1, original_shape[-1])
-        fwd, rstd = _rmsnorm_fwd(x, weight, eps, return_rstd=True)
+        fwd, _, rstd = rmsnorm_fwd(x, weight, eps=eps, store_rstd=True)
         if requires_reshape:
             fwd = fwd.view(original_shape)
         return fwd, rstd
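
Note: the rmsnorm hunks carry the one real signature change in this commit: rmsnorm_fwd now takes eps as a keyword, the switch is named store_rstd, and the call returns a tuple, so the forward-only path indexes [0] while the augmented forward unpacks three values and keeps the first and last. A short sketch mirroring the wrappers above; the discarded middle element is unnamed in the diff and simply ignored here, and the shapes and eps are illustrative:

import torch
from quack.rmsnorm import rmsnorm_fwd, rmsnorm_bwd

x = torch.randn(4, 8192, device="cuda", dtype=torch.bfloat16)
weight = torch.ones(8192, device="cuda", dtype=torch.float32)
eps = 1e-5

out = rmsnorm_fwd(x, weight, eps=eps, store_rstd=False)[0]       # forward only: take the first element
out, _, rstd = rmsnorm_fwd(x, weight, eps=eps, store_rstd=True)  # augmented forward: also keep rstd
grad_out = torch.randn_like(out)
ret = rmsnorm_bwd(x, weight, grad_out, rstd)                     # backward, same argument order as the wrapper
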

thunder/tests/test_cutlass_dsl_ex.py

Lines changed: 4 additions & 4 deletions
@@ -78,12 +78,12 @@ def test_quack_softmax(dtype: torch.dtype, shape: tuple[int, ...]):
 @pytest.mark.parametrize("dtype", _DTYPES, ids=_DTYPE_IDS)
 def test_quack_layernorm(dtype: torch.dtype, shape: tuple[int, ...]):
     x = torch.randn(shape, dtype=dtype, requires_grad=True)
-    ref_x = x.clone().detach().to(torch.float32)
+    ref_x = x.clone().detach()
 
-    module = nn.LayerNorm(shape[-1]).cuda()
+    module = nn.LayerNorm(shape[-1]).to(device="cuda", dtype=dtype)
     jitted = jit_with_cutlass_dsl_ex(module)
 
-    expected = module(ref_x).to(dtype)
+    expected = module(ref_x)
     actual = jitted(x)
     torch.testing.assert_close(expected, actual)
 
@@ -95,7 +95,7 @@ def test_quack_rmsnorm(dtype: torch.dtype, shape: tuple[int, ...]):
     x = torch.randn(shape, dtype=dtype, requires_grad=True)
     ref_x = x.clone().detach()
 
-    module = nn.RMSNorm(shape[-1]).cuda()
+    module = nn.RMSNorm(shape[-1]).to(device="cuda", dtype=dtype)
     jitted = jit_with_cutlass_dsl_ex(module)
 
     expected = module(ref_x)
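
Note on the test change: the eager reference module is now built in the same dtype (and on CUDA) as the jitted module, so expected and actual are compared in like precision instead of computing the reference in float32 and downcasting. A sketch of the resulting setup for the layer-norm case, assuming a CUDA device and an illustrative shape and dtype:

import torch
from torch import nn

dtype, shape = torch.bfloat16, (4, 1024)
x = torch.randn(shape, device="cuda", dtype=dtype, requires_grad=True)
ref_x = x.clone().detach()

module = nn.LayerNorm(shape[-1]).to(device="cuda", dtype=dtype)  # reference runs in the test dtype
expected = module(ref_x)                                         # no float32 detour, no downcast
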
