Skip to content

Commit 0fed754

Browse files
[github-action] formatting fixes
1 parent b979599 commit 0fed754

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

torchhd/functional.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -893,12 +893,13 @@ def multiset(input: Tensor) -> Tensor:
893893

894894
return torch.sum(input, dim=dim, dtype=dtype)
895895

896+
896897
def randsel(
897898
input: Tensor, other: Tensor, *, p: float = 0.5, generator: torch.Generator = None
898899
) -> Tensor:
899900
r"""Bundles two hypervectors by selecting random elements.
900901
901-
A bundling operation is used to aggregate information into a single hypervector.
902+
A bundling operation is used to aggregate information into a single hypervector.
902903
The resulting hypervector has elements selected at random from input or other.
903904
904905
.. math::
@@ -979,6 +980,7 @@ def multirandsel(
979980
select.unsqueeze_(-2)
980981
return input.gather(-2, select).squeeze(-2)
981982

983+
982984
multibundle = multiset
983985

984986

torchhd/tests/test_operations.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,11 +299,15 @@ def test_value(self):
299299
generator.manual_seed(2147483644)
300300

301301
x = functional.random_hv(4, 1000, generator=generator)
302-
res = functional.multirandsel(x, p=torch.tensor([0.0,0.0,1.0,0.0]), generator=generator)
302+
res = functional.multirandsel(
303+
x, p=torch.tensor([0.0, 0.0, 1.0, 0.0]), generator=generator
304+
)
303305
assert torch.all(x[2] == res)
304306

305307
x = functional.random_hv(4, 1000, generator=generator)
306-
res = functional.multirandsel(x, p=torch.tensor([0.5,0.0,0.5,0.0]), generator=generator)
308+
res = functional.multirandsel(
309+
x, p=torch.tensor([0.5, 0.0, 0.5, 0.0]), generator=generator
310+
)
307311
assert torch.all((x[0] == res) | (x[2] == res))
308312

309313
x = functional.random_hv(4, 1000, generator=generator)
@@ -325,4 +329,4 @@ def test_device(self):
325329
assert res.dtype == x.dtype
326330
assert res.dim() == 1
327331
assert res.size(0) == 100
328-
assert res.device == device
332+
assert res.device == device

0 commit comments

Comments
 (0)