Skip to content

Commit 08fd6f8

Browse files
authored
Fix incorrect variable usage in EMG example (#136)
* Fix incorrect variable usage in the EMG example (bind the previously bound `samples`, not the raw `signal`, with the timestamp hypervectors).
* Torch 2.0.0 now implements `prod` for bfloat16 on CPU, so bfloat16 is removed from the expected-failure dtypes in the tests.
1 parent 7267c47 commit 08fd6f8

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

examples/emg_hand_gestures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(self, out_features, timestamps, channels):
3939
def forward(self, input: torch.Tensor) -> torch.Tensor:
4040
signal = self.signals(input)
4141
samples = torchhd.bind(signal, self.channels.weight.unsqueeze(0))
42-
samples = torchhd.bind(signal, self.timestamps.weight.unsqueeze(1))
42+
samples = torchhd.bind(samples, self.timestamps.weight.unsqueeze(1))
4343

4444
samples = torchhd.multiset(samples)
4545
sample_hv = torchhd.ngrams(samples, n=N_GRAM_SIZE)

torchhd/functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1221,7 +1221,7 @@ def ngrams(input: VSATensor, n: int = 3) -> VSATensor:
12211221
\bigoplus_{i=0}^{m - n} \bigotimes_{j = 0}^{n - 1} \Pi^{n - j - 1}(V_{i + j})
12221222
12231223
.. note::
1224-
For :math:`n=1` use :func:`~torchhd.functional.multiset` instead and for :math:`n=m` use :func:`~torchhd.functional.bind_sequence` instead.
1224+
For :math:`n=1` use :func:`~torchhd.multiset` instead and for :math:`n=m` use :func:`~torchhd.bind_sequence` instead.
12251225
12261226
Args:
12271227
input (VSATensor): The value hypervectors.

torchhd/tests/test_encodings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def test_value(self, vsa, dtype):
140140
def test_dtype(self, dtype):
141141
hv = torch.zeros(23, 1000, dtype=dtype).as_subclass(MAPTensor)
142142

143-
if dtype in {torch.float16, torch.bfloat16}:
143+
if dtype in {torch.float16}:
144144
# torch.product is not implemented on CPU for these dtypes
145145
with pytest.raises(RuntimeError):
146146
functional.multibind(hv)
@@ -287,7 +287,7 @@ def test_value(self):
287287
def test_dtype(self, dtype):
288288
hv = torch.zeros(23, 1000, dtype=dtype).as_subclass(MAPTensor)
289289

290-
if dtype in {torch.float16, torch.bfloat16}:
290+
if dtype in {torch.float16}:
291291
# torch.product is not implemented on CPU for these dtypes
292292
with pytest.raises(RuntimeError):
293293
functional.multibind(hv)

0 commit comments

Comments (0)