
Commit 4ce95ac

Accommodate softmax to handle arbitrary axis dim length
1 parent: 5169151

2 files changed (+36, -18 lines)

src/ntops/softmax.py

Lines changed: 33 additions & 16 deletions
@@ -5,16 +5,22 @@
 import torch
 from ninetoothed import Tensor
 
+BLOCK_SIZE = ninetoothed.block_size()
+
 
 def arrangement(input, output, dim):
     assert input.ndim == output.ndim
 
     def create_axis_tile_shape(dim, dim_block):
-        return tuple(1 for _ in range(dim)) + (dim_block,) + tuple(1 for _ in range(input.ndim - dim - 1))
-
-    inner_block_shape = create_axis_tile_shape(dim, input.shape[dim])
+        return (
+            tuple(1 for _ in range(dim))
+            + (dim_block,)
+            + tuple(1 for _ in range(input.ndim - dim - 1))
+        )
+
+    inner_block_shape = create_axis_tile_shape(dim, BLOCK_SIZE)
     outer_block_shape = create_axis_tile_shape(dim, -1)
-
+
     def arrange(input):
         input_arranged = input.tile(inner_block_shape).tile(outer_block_shape)
 
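The key change in this hunk: the inner tile along the softmax axis now has a fixed BLOCK_SIZE instead of the full input.shape[dim], so an axis of arbitrary length is split into blocks and streamed over. For orientation, here is a standalone sketch of what create_axis_tile_shape produces, rewritten outside the enclosing closure (the ndim, dim, and block-size values are made up for illustration):

def axis_tile_shape(ndim, dim, dim_block):
    # 1 along every axis except `dim`, which gets the block length.
    return tuple(1 for _ in range(dim)) + (dim_block,) + tuple(1 for _ in range(ndim - dim - 1))

ndim, dim, block_size = 3, 1, 128
print(axis_tile_shape(ndim, dim, block_size))  # (1, 128, 1): inner tiles of BLOCK_SIZE along the softmax axis
print(axis_tile_shape(ndim, dim, -1))          # (1, -1, 1): outer tile spanning all inner blocks along that axis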

@@ -25,25 +31,36 @@ def arrange(input):
             tuple(d for d in range(input.ndim) if d != dim)
         )
         return input_arranged
-
-    input_arranged = arrange(input)
-    output_arranged = arrange(output)
 
-    return input_arranged, output_arranged
+    return arrange(input), arrange(output)
+
+
+def _exp(x, dtype):
+    exp_dtype = dtype if dtype != ntl.float16 else ntl.float32
+    return ntl.cast(ntl.exp(ntl.cast(x, exp_dtype)), dtype)
 
 
 def application(input, output):
+    dtype = output.dtype.dtype
+    prev_max = ntl.cast(float("-inf"), dtype)
+    denominator = ntl.cast(0, dtype)
+
+    for i in range(input.shape[0]):
+        input_i = ntl.cast(input[i], dtype)
+        curr_max = ntl.cast(ntl.maximum(prev_max, ntl.max(input_i)), dtype)
+        input_max_diff_exp = _exp(input_i - curr_max, dtype)
+        prev_curr_max_diff_exp = _exp(prev_max - curr_max, dtype)
+        denominator = denominator * prev_curr_max_diff_exp + ntl.sum(input_max_diff_exp)
+        prev_max = curr_max
+
     for i in range(input.shape[0]):
-        input_i = input[i]
-        row_minus_max = input_i - ntl.max(input_i)
-        numerator = ntl.exp(ntl.cast(row_minus_max, ntl.float32))
-        denominator = ntl.sum(numerator)
-        output[i] = numerator / denominator  # noqa: F841
+        numerator = _exp(input[i] - prev_max, dtype)
+        output[i] = numerator / denominator
 
 
-def softmax(input, dim, output=None):
-    if output is None:
-        output = torch.empty_like(input)
+def softmax(input, dim, dtype=None):
+    tensor_dtype = dtype if dtype is not None else input.dtype
+    output = torch.empty_like(input, dtype=tensor_dtype)
 
     kernel = _make(input.ndim, dim)
 
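The rewritten application is an online (streaming) softmax over the BLOCK_SIZE tiles: the first loop maintains a running maximum and a denominator that is rescaled by exp(prev_max - curr_max) whenever the maximum grows, and the second loop normalizes each tile against the final values. A rough pure-PyTorch equivalent for a 1-D input (the block length here is arbitrary, and the float16-to-float32 promotion that _exp performs is left out):

import torch

def streaming_softmax(x, block=128):
    prev_max = torch.tensor(float("-inf"))
    denominator = torch.tensor(0.0)
    # Pass 1: fold each block into the running max and the rescaled denominator.
    for chunk in x.split(block):
        curr_max = torch.maximum(prev_max, chunk.max())
        denominator = denominator * torch.exp(prev_max - curr_max) + torch.exp(chunk - curr_max).sum()
        prev_max = curr_max
    # Pass 2: normalize every block against the final max and denominator.
    return torch.cat([torch.exp(chunk - prev_max) / denominator for chunk in x.split(block)])

x = torch.randn(1000)
assert torch.allclose(streaming_softmax(x), torch.softmax(x, dim=0), atol=1e-6)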

tests/test_softmax.py

Lines changed: 3 additions & 2 deletions
@@ -15,8 +15,9 @@ def test_cuda(shape, dtype, atol, rtol):
 
     input = torch.randn(shape, dtype=dtype, device=device)
     dim = random.randint(0, input.ndim - 1)
+    dtype = random.choice([torch.float16, torch.float32, torch.float64])
 
-    ninetoothed_output = ntops.softmax(input, dim)
-    reference_output = torch.nn.functional.softmax(input, dim=dim)
+    ninetoothed_output = ntops.softmax(input, dim, dtype)
+    reference_output = torch.nn.functional.softmax(input, dim=dim, dtype=dtype)
 
     assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
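
With the updated signature, the optional dtype argument selects the output dtype the same way torch.nn.functional.softmax does, and it defaults to the input dtype. A usage sketch (the tensor shape and device below are chosen arbitrarily):

import torch
import ntops

x = torch.randn(8, 1000, dtype=torch.float16, device="cuda")
y = ntops.softmax(x, dim=1, dtype=torch.float32)  # output allocated as float32
z = ntops.softmax(x, dim=1)                       # dtype defaults to the input dtype (float16)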
