 import torch
 from ninetoothed import Tensor
 
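+# Block size along the softmax dimension: a tunable meta-parameter instead of
+# hard-coding the full row length.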
+BLOCK_SIZE = ninetoothed.block_size()
+
 
 def arrangement(input, output, dim):
     assert input.ndim == output.ndim
 
     def create_axis_tile_shape(dim, dim_block):
-        return tuple(1 for _ in range(dim)) + (dim_block,) + tuple(1 for _ in range(input.ndim - dim - 1))
-
-    inner_block_shape = create_axis_tile_shape(dim, input.shape[dim])
+        return (
+            tuple(1 for _ in range(dim))
+            + (dim_block,)
+            + tuple(1 for _ in range(input.ndim - dim - 1))
+        )
+
+    inner_block_shape = create_axis_tile_shape(dim, BLOCK_SIZE)
     outer_block_shape = create_axis_tile_shape(dim, -1)
-
+
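+    # arrange tiles a tensor into BLOCK_SIZE-element blocks along dim, then
+    # gathers all of a row's blocks into a single outer tile.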
     def arrange(input):
         input_arranged = input.tile(inner_block_shape).tile(outer_block_shape)
 
@@ -25,25 +31,36 @@ def arrange(input):
             tuple(d for d in range(input.ndim) if d != dim)
         )
         return input_arranged
-
-    input_arranged = arrange(input)
-    output_arranged = arrange(output)
 
-    return input_arranged, output_arranged
+    return arrange(input), arrange(output)
+
+
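+# exp is evaluated in float32 when the data is float16, presumably to avoid the
+# accuracy loss of a half-precision exp, then cast back to the target dtype.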
+def _exp(x, dtype):
+    exp_dtype = dtype if dtype != ntl.float16 else ntl.float32
+    return ntl.cast(ntl.exp(ntl.cast(x, exp_dtype)), dtype)
 
 
 def application(input, output):
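+    # output.dtype is a block here, so .dtype.dtype reaches the element dtype.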
+    dtype = output.dtype.dtype
+    prev_max = ntl.cast(float("-inf"), dtype)
+    denominator = ntl.cast(0, dtype)
+
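+    # First pass (online softmax): track the running maximum and rescale the
+    # partial denominator by exp(prev_max - curr_max) whenever the maximum grows.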
+    for i in range(input.shape[0]):
+        input_i = ntl.cast(input[i], dtype)
+        curr_max = ntl.cast(ntl.maximum(prev_max, ntl.max(input_i)), dtype)
+        input_max_diff_exp = _exp(input_i - curr_max, dtype)
+        prev_curr_max_diff_exp = _exp(prev_max - curr_max, dtype)
+        denominator = denominator * prev_curr_max_diff_exp + ntl.sum(input_max_diff_exp)
+        prev_max = curr_max
+
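+    # Second pass: normalize every block with the final maximum and denominator.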
     for i in range(input.shape[0]):
-        input_i = input[i]
-        row_minus_max = input_i - ntl.max(input_i)
-        numerator = ntl.exp(ntl.cast(row_minus_max, ntl.float32))
-        denominator = ntl.sum(numerator)
-        output[i] = numerator / denominator  # noqa: F841
+        numerator = _exp(input[i] - prev_max, dtype)
+        output[i] = numerator / denominator
 
 
-def softmax(input, dim, output=None):
-    if output is None:
-        output = torch.empty_like(input)
+def softmax(input, dim, dtype=None):
+    tensor_dtype = dtype if dtype is not None else input.dtype
+    output = torch.empty_like(input, dtype=tensor_dtype)
 
     kernel = _make(input.ndim, dim)
 
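The change replaces the single-shot softmax, which loaded an entire row at once
and used the row length as the block size, with a two-pass online softmax over
BLOCK_SIZE-wide blocks; softmax() now also allocates its own output and accepts
an optional dtype for the result instead of a preallocated output tensor. A
minimal NumPy sketch of the same recurrence, useful as a cross-check of the
algebra (the block size of 4 and the random test row are arbitrary and not part
of the kernel):

    import numpy as np

    def online_softmax(row, block=4):
        m, d = float("-inf"), 0.0
        # First pass: fold each block into a running max and a rescaled denominator.
        for start in range(0, row.size, block):
            x = row[start:start + block]
            m_new = max(m, float(x.max()))
            d = d * np.exp(m - m_new) + float(np.exp(x - m_new).sum())
            m = m_new
        # Second pass: normalize with the final max and denominator.
        return np.exp(row - m) / d

    row = np.random.randn(16).astype(np.float32)
    reference = np.exp(row - row.max()) / np.exp(row - row.max()).sum()
    assert np.allclose(online_softmax(row), reference, atol=1e-6)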