@@ -88,16 +88,26 @@ def is_last_dim_divisible(dtype: dtypes.dtype, last_dim_size: int) -> bool:
     return last_dim_size % (128 // 8 // dtype.bytes) == 0
 
 
+quack_version: LooseVersion
+try:
+    import quack
+except ImportError:
+    quack_version = LooseVersion("0.0.0")
+else:
+    quack_version = LooseVersion(quack.__version__)
+
 # Register [`quack`](https://github.com/Dao-AILab/quack) ops
-if find_spec("quack") is not None:
+if quack_version >= LooseVersion("0.2.2"):
+    import quack
+
     # softmax
-    from quack.softmax import _softmax_fwd, _softmax_backward
+    from quack.softmax import softmax_fwd, softmax_bwd
 
     def quack_softmax_impl(a: torch.Tensor) -> torch.Tensor:
         original_shape = a.shape
         if requires_reshape := a.ndim > 2:
             a = a.view(-1, original_shape[-1])
-        ret = _softmax_fwd(a)
+        ret = softmax_fwd(a)
         if requires_reshape:
             ret = ret.view(original_shape)
         return ret
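A note on the new gate: probing the installed version up front, rather than `find_spec`, is what lets this block require the renamed public entry points that only exist from quack 0.2.2 onward. Below is a minimal sketch of the same pattern with a hypothetical optional dependency `somelib`; the diff does not show where `LooseVersion` is imported from, so the sketch uses `packaging.version.Version`, which compares these release strings the same way.

    from packaging.version import Version  # stand-in for the file's LooseVersion

    somelib_version: Version
    try:
        import somelib  # hypothetical optional dependency
    except ImportError:
        # A missing package is treated as version 0.0.0 so a single
        # comparison covers both "not installed" and "installed but too old".
        somelib_version = Version("0.0.0")
    else:
        somelib_version = Version(somelib.__version__)

    if somelib_version >= Version("1.2.0"):
        ...  # register the ops that need the >= 1.2.0 API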
@@ -116,7 +126,7 @@ def quack_softmax_backward(g: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
         if requires_reshape := g.ndim > 2:
             g = g.view(-1, original_shape[-1])
             a = a.view(-1, original_shape[-1])
-        ret = _softmax_backward(g, a)
+        ret = softmax_bwd(g, a)
         if requires_reshape:
             ret = ret.view(original_shape)
         return ret
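Both softmax wrappers follow the same flatten-then-restore pattern: the kernel handles 2D `(rows, dim)` inputs, so higher-rank tensors are viewed as 2D over the last dimension and the original shape is restored afterwards. A self-contained sketch of that pattern, using `torch.softmax` as a stand-in for the 2D-only kernel:

    import torch

    def last_dim_op_2d(a: torch.Tensor) -> torch.Tensor:
        """Apply a row-wise 2D kernel to a tensor of any rank >= 2."""
        original_shape = a.shape
        if requires_reshape := a.ndim > 2:
            # Collapse all leading dims into a single row dimension.
            a = a.view(-1, original_shape[-1])
        ret = torch.softmax(a, dim=-1)  # stand-in for a kernel like softmax_fwd
        if requires_reshape:
            ret = ret.view(original_shape)
        return ret

    x = torch.randn(2, 3, 8)
    assert last_dim_op_2d(x).shape == (2, 3, 8)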
@@ -196,13 +206,13 @@ def quack_softmax_transform(
         )
 
     # crossentropy
-    from quack.cross_entropy import _cross_entropy, _cross_entropy_backward
+    from quack.cross_entropy import cross_entropy_fwd, cross_entropy_bwd
 
     def quack_cross_entropy_forward_impl(
         x: torch.Tensor,
         target: torch.Tensor,
     ) -> torch.Tensor:
-        return _cross_entropy(x, target, return_lse=False)
+        return cross_entropy_fwd(x, target, return_lse=False)
 
     def quack_cross_entropy_forward_meta(x: TensorProxy, target: TensorProxy) -> TensorProxy:
         return TensorProxy(like=x, shape=(x.shape[0],), dtype=dtypes.float32)
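As the meta function records, the forward reduces only over the class dimension: logits of shape `(N, V)` and integer targets of shape `(N,)` yield one float32 loss per row. A plain-PyTorch shape check of that contract (a reference for the shapes and dtype, not quack's kernel):

    import torch
    import torch.nn.functional as F

    N, V = 4, 16
    x = torch.randn(N, V)
    target = torch.randint(0, V, (N,))

    # Per-row loss, matching TensorProxy(like=x, shape=(x.shape[0],), dtype=float32).
    loss = F.cross_entropy(x.float(), target, reduction="none")
    assert loss.shape == (N,) and loss.dtype == torch.float32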
@@ -219,7 +229,7 @@ def quack_cross_entropy_backward_impl(
         grad: torch.Tensor,
         lse: torch.Tensor,
     ) -> torch.Tensor:
-        return _cross_entropy_backward(x, target, grad, lse, False)
+        return cross_entropy_bwd(x, target, grad, lse, False)
 
     def quack_cross_entropy_backward_meta(
         x: TensorProxy,
@@ -290,7 +300,7 @@ def quack_cross_entropy_aug_forward_impl(
         x: torch.Tensor,
         target: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        return _cross_entropy(x, target, return_lse=True)
+        return cross_entropy_fwd(x, target, return_lse=True)
 
     def quack_cross_entropy_aug_forward_meta(a: TensorProxy, target: TensorProxy) -> tuple[TensorProxy, TensorProxy]:
         return (
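The augmented forward returns the row-wise log-sum-exp so the backward can rebuild the softmax without re-reducing the logits: with `lse_i = log(sum_j exp(x_ij))`, the standard cross-entropy gradient is `g_i * (exp(x_ij - lse_i) - [j == target_i])`. A reference implementation of that formula, checked against autograd; this is an assumption about `cross_entropy_bwd`'s semantics inferred from its `(x, target, grad, lse, ...)` signature, not quack's code:

    import torch
    import torch.nn.functional as F

    def cross_entropy_bwd_reference(x, target, grad, lse):
        # Rebuild the softmax from the saved log-sum-exp: probs_ij = exp(x_ij - lse_i).
        probs = torch.exp(x - lse[:, None])
        # dL/dx_ij = grad_i * (probs_ij - [j == target_i])
        probs[torch.arange(x.shape[0]), target] -= 1.0
        return probs * grad[:, None]

    x = torch.randn(3, 5, requires_grad=True)
    t = torch.randint(0, 5, (3,))
    F.cross_entropy(x, t, reduction="none").backward(torch.ones(3))
    ref = cross_entropy_bwd_reference(x.detach(), t, torch.ones(3), torch.logsumexp(x.detach(), dim=-1))
    assert torch.allclose(x.grad, ref, atol=1e-6)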
@@ -378,7 +388,7 @@ def quack_layer_norm_checker(
             or weight.dtype not in {dtypes.float32}
         ):
             return False
-        return is_device_quack_compat()
+        return is_device_quack_compat() and is_last_dim_divisible(a.dtype, a.shape[-1])
 
     def quack_layer_norm_transform(
         a: TensorProxy,
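The condition added to the checker enforces 128-bit row alignment: `128 // 8 // dtype.bytes` is the number of elements in one 128-bit vector, so a 2-byte dtype (bfloat16/float16) needs a last dimension divisible by 8, and float32 needs divisibility by 4. The same arithmetic standalone:

    def elements_per_128_bits(itemsize: int) -> int:
        return 128 // 8 // itemsize  # 16 bytes per vector / bytes per element

    assert elements_per_128_bits(2) == 8  # bfloat16/float16: last dim % 8 == 0
    assert elements_per_128_bits(4) == 4  # float32: last dim % 4 == 0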
@@ -397,7 +407,7 @@ def quack_layer_norm_transform(
         )
 
     # rmsnorm
-    from quack.rmsnorm import _rmsnorm_fwd, _rmsnorm_backward
+    from quack.rmsnorm import rmsnorm_fwd, rmsnorm_bwd
 
     def quack_rms_norm_forward_impl(
         x: torch.Tensor,
@@ -407,7 +417,7 @@ def quack_rms_norm_forward_impl(
         original_shape = x.shape
         if requires_reshape := x.ndim > 2:
             x = x.view(-1, original_shape[-1])
-        ret = _rmsnorm_fwd(x, weight, eps, return_rstd=False)
+        ret = rmsnorm_fwd(x, weight, eps=eps, store_rstd=False)[0]
         if requires_reshape:
             ret = ret.view(original_shape)
         return ret
@@ -435,7 +445,7 @@ def quack_rms_norm_backward_impl(
         if requires_reshape := grad.ndim > 2:
             grad = grad.view(-1, original_shape[-1])
             x = x.view(-1, original_shape[-1])
-        ret = _rmsnorm_backward(x, weight, grad, rstd)
+        ret = rmsnorm_bwd(x, weight, grad, rstd)
         if requires_reshape:
             ret = ret.view(original_shape)
         return ret
@@ -479,7 +489,7 @@ def quack_rms_norm_aug_forward_impl(
         original_shape = x.shape
         if requires_reshape := x.ndim > 2:
             x = x.view(-1, original_shape[-1])
-        fwd, rstd = _rmsnorm_fwd(x, weight, eps, return_rstd=True)
+        fwd, _, rstd = rmsnorm_fwd(x, weight, eps=eps, store_rstd=True)
         if requires_reshape:
             fwd = fwd.view(original_shape)
         return fwd, rstd
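The rmsnorm call sites change because the new `rmsnorm_fwd` returns a tuple: the plain forward indexes `[0]` for the output, and the augmented forward unpacks three elements, discarding the middle one. A plain-PyTorch sketch of what the forward computes, including the reciprocal standard deviation the backward consumes; this mirrors standard RMSNorm and is an assumption about the kernel's math, not quack's code:

    import torch

    def rmsnorm_fwd_reference(x: torch.Tensor, weight: torch.Tensor, eps: float):
        # rstd_i = 1 / sqrt(mean_j(x_ij ** 2) + eps); saved so the backward
        # can renormalize rows without recomputing the reduction.
        rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1) + eps)
        out = x * rstd[:, None] * weight
        return out, rstd

    x = torch.randn(4, 8)
    w = torch.ones(8)
    out, rstd = rmsnorm_fwd_reference(x, w, eps=1e-6)
    assert out.shape == x.shape and rstd.shape == (4,)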