Skip to content

Commit bbe64a3

Browse files
authored
Fix usage of new .mT to older .transpose() (#138)
* Refactor .mT to transpose(-2, -1) * Add VSA to setup description
1 parent cb21e28 commit bbe64a3

File tree

6 files changed

+11
-7
lines changed

6 files changed

+11
-7
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
setup(
1414
name="torch-hd", # use torch-hd on PyPi to install torchhd, torchhd is too similar according to PyPi
1515
version=version["__version__"],
16-
description="Torchhd is a Python library for Hyperdimensional Computing",
16+
description="Torchhd is a Python library for Hyperdimensional Computing and Vector Symbolic Architectures",
1717
long_description=open("README.md").read(),
1818
long_description_content_type="text/markdown",
1919
url="https://github.com/hyperdimensional-computing/torchhd",

torchhd/functional.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1562,7 +1562,7 @@ def resonator(input: VSATensor, estimates: VSATensor, domains: VSATensor) -> VSA
15621562
new_estimates = bind(input.unsqueeze(-2), inv_others)
15631563

15641564
similarity = dot_similarity(new_estimates.unsqueeze(-2), domains)
1565-
output = dot_similarity(similarity, domains.mT).squeeze(-2)
1565+
output = dot_similarity(similarity, domains.transpose(-2, -1)).squeeze(-2)
15661566

15671567
# normalize the output vector with a non-linearity
15681568
return output.sign()
@@ -1591,7 +1591,11 @@ def ridge_regression(
15911591

15921592
variance = alpha * torch.diag(torch.var(samples, -2))
15931593

1594-
return labels.mT @ samples @ torch.linalg.pinv(samples.mT @ samples + variance)
1594+
return (
1595+
labels.transpose(-2, -1)
1596+
@ samples
1597+
@ torch.linalg.pinv(samples.transpose(-2, -1) @ samples + variance)
1598+
)
15951599

15961600

15971601
def map_range(

torchhd/tensors/bsc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ def dot_similarity(self, others: "BSCTensor") -> Tensor:
438438
others_as_bipolar = torch.where(others.bool(), min_one, plus_one)
439439

440440
if others.dim() >= 2:
441-
others_as_bipolar = others_as_bipolar.mT
441+
others_as_bipolar = others_as_bipolar.transpose(-2, -1)
442442
return torch.matmul(self_as_bipolar, others_as_bipolar)
443443

444444
def cosine_similarity(self, others: "BSCTensor") -> Tensor:

torchhd/tensors/fhrr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ def permute(self, shifts: int = 1) -> "FHRRTensor":
378378
def dot_similarity(self, others: "FHRRTensor") -> Tensor:
379379
"""Inner product with other hypervectors"""
380380
if others.dim() >= 2:
381-
others = others.mT
381+
others = others.transpose(-2, -1)
382382
return torch.real(torch.matmul(self, torch.conj(others)))
383383

384384
def cosine_similarity(self, others: "FHRRTensor", *, eps=1e-08) -> Tensor:

torchhd/tensors/hrr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ def permute(self, shifts: int = 1) -> "HRRTensor":
365365
def dot_similarity(self, others: "HRRTensor") -> Tensor:
366366
"""Inner product with other hypervectors"""
367367
if others.dim() >= 2:
368-
others = others.mT
368+
others = others.transpose(-2, -1)
369369
return torch.matmul(self, others)
370370

371371
def cosine_similarity(self, others: "HRRTensor", *, eps=1e-08) -> Tensor:

torchhd/tensors/map.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def dot_similarity(self, others: "MAPTensor") -> Tensor:
344344
"""Inner product with other hypervectors"""
345345
dtype = torch.get_default_dtype()
346346
if others.dim() >= 2:
347-
others = others.mT
347+
others = others.transpose(-2, -1)
348348
return torch.matmul(self.to(dtype), others.to(dtype))
349349

350350
def cosine_similarity(self, others: "MAPTensor", *, eps=1e-08) -> Tensor:

0 commit comments

Comments
 (0)