Skip to content

Commit 91947dc

Browse files
authored
Merge pull request #5 from InfiniTensor/use-custom-gauss
Use custom `gauss`
2 parents 666c438 + 9aec4c9 commit 91947dc

File tree

3 files changed

+9
-8
lines changed

3 files changed

+9
-8
lines changed

tests/test_add.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
import random
2-
31
import pytest
42
import torch
53

64
import ntops
75
from tests.skippers import skip_if_cuda_not_available
8-
from tests.utils import generate_arguments
6+
from tests.utils import gauss, generate_arguments
97

108

119
@skip_if_cuda_not_available
@@ -15,7 +13,7 @@ def test_cuda(shape, dtype, atol, rtol):
1513

1614
input = torch.randn(shape, dtype=dtype, device=device)
1715
other = torch.randn(shape, dtype=dtype, device=device)
18-
alpha = random.gauss()
16+
alpha = gauss()
1917

2018
ninetoothed_output = ntops.add(input, other, alpha)
2119
reference_output = input + alpha * other

tests/test_addmm.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
import random
2-
31
import pytest
42
import torch
53

64
import ntops
75
from tests.skippers import skip_if_cuda_not_available
86
from tests.test_mm import generate_arguments
7+
from tests.utils import gauss
98

109

1110
@skip_if_cuda_not_available
@@ -16,8 +15,8 @@ def test_cuda(m, n, k, dtype, atol, rtol):
1615
input = torch.randn((m, n), dtype=dtype, device=device)
1716
x = torch.randn((m, k), dtype=dtype, device=device)
1817
y = torch.randn((k, n), dtype=dtype, device=device)
19-
beta = random.gauss()
20-
alpha = random.gauss()
18+
beta = gauss()
19+
alpha = gauss()
2120

2221
ninetoothed_output = ntops.addmm(input, x, y, beta=beta, alpha=alpha)
2322
reference_output = torch.addmm(input, x, y, beta=beta, alpha=alpha)

tests/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ def generate_arguments():
2121
return "shape, dtype, atol, rtol", arguments
2222

2323

24+
def gauss(mu=0.0, sigma=1.0):
25+
return random.gauss(mu, sigma)
26+
27+
2428
def _random_shape(ndim, min_num_elements=2**8, max_num_elements=2**10):
2529
num_elements = random.randint(min_num_elements, max_num_elements)
2630

0 commit comments

Comments
 (0)