Skip to content

Commit 3da459b

Browse files
committed
fix bug in linear
1 parent 77c480b commit 3da459b

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

top/layers/linear.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
import torch
23
from torch import nn
34
from top.functions import matmul
@@ -10,11 +11,15 @@ def __init__(
1011
M: int,
1112
N: int,
1213
K: int,
14+
device='cuda',
1315
dtype=torch.float16,
1416
tune=False,
1517
):
1618
super().__init__()
17-
self.weight = nn.Parameter(torch.Tensor(K, N), dtype=dtype)
19+
factory_kwargs = {"device": device, "dtype": dtype}
20+
self.weight = nn.Parameter(
21+
torch.empty((K, N), **factory_kwargs)
22+
)
1823
self.fn = matmul(
1924
M,
2025
N,

0 commit comments

Comments (0)