Skip to content

Commit 471e377

Browse files
committed
Add gelu operator
1 parent c502e18 commit 471e377

File tree

3 files changed

+75
-1
lines changed

3 files changed

+75
-1
lines changed

src/ntops/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ntops.addmm import addmm
44
from ntops.bmm import bmm
55
from ntops.div import div
6+
from ntops.gelu import gelu
67
from ntops.mm import mm
78

8-
__all__ = ["abs", "add", "addmm", "bmm", "div", "mm"]
9+
__all__ = ["abs", "add", "addmm", "bmm", "div", "gelu", "mm"]

src/ntops/gelu.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import functools
2+
import math
3+
4+
import ninetoothed
5+
import ninetoothed.language as ntl
6+
import torch
7+
from ninetoothed import Tensor
8+
9+
from ntops import element_wise
10+
11+
12+
def default_application(input, output):
    """Exact GELU: x * 0.5 * (1 + erf(x / sqrt(2))), element-wise.

    NOTE(review): this function is traced by ninetoothed rather than called
    directly; presumably the assignment to the `output` parameter is how the
    framework records the store into the output tensor — hence the noqa for
    the otherwise-unused binding. Confirm against ninetoothed docs.
    """
    output = input * 0.5 * (1 + ntl.erf(input / ntl.sqrt(2.0)))  # noqa: F841
14+
15+
16+
def tanh_application(input, output):
    """Tanh-approximated GELU, element-wise:

    0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))

    NOTE(review): traced by ninetoothed; assigning to the `output` parameter
    presumably records the write into the output tensor (hence the noqa on
    the unused binding) — confirm against ninetoothed docs.
    """
    # NOTE(review): the alias looks redundant, but it may matter to the
    # tracer (e.g. forcing a single load reused below) — keep as is.
    input_loaded = input

    output = (  # noqa: F841
        0.5
        * input_loaded
        * (
            1
            + ntl.tanh(
                ntl.sqrt(2 / math.pi) * (input_loaded + 0.044715 * input_loaded**3)
            )
        )
    )
29+
30+
31+
def gelu(input, approximate="none"):
    """Apply the GELU activation element-wise.

    Args:
        input: Input tensor.
        approximate: ``"none"`` for the exact erf formulation or ``"tanh"``
            for the tanh approximation (mirrors
            ``torch.nn.functional.gelu``).

    Returns:
        A new tensor with the same shape and dtype as ``input``.

    Raises:
        ValueError: If ``approximate`` is not ``"none"`` or ``"tanh"``.
    """
    # Previously any unrecognized value silently fell back to the exact
    # formulation; fail loudly instead, matching PyTorch's behavior.
    if approximate not in ("none", "tanh"):
        raise ValueError(f"invalid approximate value: {approximate!r}")

    output = torch.empty_like(input)

    # One kernel is compiled (and cached) per (ndim, approximate) pair.
    kernel = _make(input.ndim, approximate)

    kernel(input, output)

    return output
39+
40+
41+
@functools.cache
def _make(ndim, approximate):
    """Construct an element-wise GELU kernel for ``ndim``-dimensional tensors.

    ``functools.cache`` memoizes the result, so each (ndim, approximate)
    pair compiles exactly one kernel and repeated ``gelu`` calls reuse it.

    Args:
        ndim: Number of dimensions of the input/output tensors.
        approximate: ``"tanh"`` selects the tanh approximation; any other
            value selects the exact erf formulation.

    Returns:
        A callable kernel taking ``(input, output)`` tensors.
    """
    io_tensors = (Tensor(ndim), Tensor(ndim))

    chosen = tanh_application if approximate == "tanh" else default_application

    return ninetoothed.make(element_wise.arrangement, chosen, io_tensors)

tests/test_gelu.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pytest
2+
import torch
3+
import torch.nn.functional as F
4+
5+
import ntops
6+
from tests.skippers import skip_if_cuda_not_available
7+
from tests.utils import generate_arguments
8+
9+
10+
@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol):
    """Compare ntops.gelu against torch.nn.functional.gelu on CUDA.

    Covers both the exact ("none") and tanh-approximated variants.
    """
    device = "cuda"

    input = torch.randn(shape, dtype=dtype, device=device)

    for approximate in ("none", "tanh"):
        # Bug fix: `approximate` was previously never forwarded to either
        # call, so both iterations compared the exact formulation and the
        # tanh variant was never actually exercised.
        ninetoothed_output = ntops.gelu(input, approximate=approximate)
        reference_output = F.gelu(input, approximate=approximate)

        assert torch.allclose(
            ninetoothed_output, reference_output, atol=atol, rtol=rtol
        )

0 commit comments

Comments
 (0)