Skip to content

Commit 81691a6

Browse files
Allow custom dtype for some VSA similarity functions (#160)
* Allow custom dtype for some VSA similarity functions * [github-action] formatting fixes * Increase tolarance for test * [github-action] formatting fixes --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 27ae2da commit 81691a6

File tree

8 files changed

+61
-18
lines changed

8 files changed

+61
-18
lines changed

torchhd/functional.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -893,7 +893,7 @@ def hard_quantize(input: Tensor):
893893
return torch.where(input > 0, positive, negative)
894894

895895

896-
def dot_similarity(input: VSATensor, others: VSATensor) -> VSATensor:
896+
def dot_similarity(input: VSATensor, others: VSATensor, **kwargs) -> VSATensor:
897897
"""Dot product between the input vector and each vector in others.
898898
899899
Aliased as ``torchhd.dot``.
@@ -938,13 +938,13 @@ def dot_similarity(input: VSATensor, others: VSATensor) -> VSATensor:
938938
"""
939939
input = ensure_vsa_tensor(input)
940940
others = ensure_vsa_tensor(others)
941-
return input.dot_similarity(others)
941+
return input.dot_similarity(others, **kwargs)
942942

943943

944944
dot = dot_similarity
945945

946946

947-
def cosine_similarity(input: VSATensor, others: VSATensor) -> VSATensor:
947+
def cosine_similarity(input: VSATensor, others: VSATensor, **kwargs) -> VSATensor:
948948
"""Cosine similarity between the input vector and each vector in others.
949949
950950
Aliased as ``torchhd.cos``.
@@ -987,7 +987,7 @@ def cosine_similarity(input: VSATensor, others: VSATensor) -> VSATensor:
987987
"""
988988
input = ensure_vsa_tensor(input)
989989
others = ensure_vsa_tensor(others)
990-
return input.cosine_similarity(others)
990+
return input.cosine_similarity(others, **kwargs)
991991

992992

993993
cos = cosine_similarity

torchhd/tensors/bsbc.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -334,20 +334,21 @@ def permute(self, shifts: int = 1) -> "BSBCTensor":
334334
"""
335335
return torch.roll(self, shifts=shifts, dims=-1)
336336

337-
def dot_similarity(self, others: "BSBCTensor") -> Tensor:
337+
def dot_similarity(self, others: "BSBCTensor", *, dtype=None) -> Tensor:
338338
"""Inner product with other hypervectors"""
339-
dtype = torch.get_default_dtype()
339+
if dtype is None:
340+
dtype = torch.get_default_dtype()
340341

341342
if self.dim() > 1 and others.dim() > 1:
342343
equals = self.unsqueeze(-2) == others.unsqueeze(-3)
343344
return torch.sum(equals, dim=-1, dtype=dtype)
344345

345346
return torch.sum(self == others, dim=-1, dtype=dtype)
346347

347-
def cosine_similarity(self, others: "BSBCTensor") -> Tensor:
348+
def cosine_similarity(self, others: "BSBCTensor", *, dtype=None) -> Tensor:
348349
"""Cosine similarity with other hypervectors"""
349350
magnitude = self.size(-1)
350-
return self.dot_similarity(others) / magnitude
351+
return self.dot_similarity(others, dtype=dtype) / magnitude
351352

352353
@classmethod
353354
def __torch_function__(cls, func, types, args=(), kwargs=None):

torchhd/tensors/bsc.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -426,10 +426,11 @@ def permute(self, shifts: int = 1) -> "BSCTensor":
426426
"""
427427
return super().roll(shifts=shifts, dims=-1)
428428

429-
def dot_similarity(self, others: "BSCTensor") -> Tensor:
429+
def dot_similarity(self, others: "BSCTensor", *, dtype=None) -> Tensor:
430430
"""Inner product with other hypervectors."""
431-
dtype = torch.get_default_dtype()
432431
device = self.device
432+
if dtype is None:
433+
dtype = torch.get_default_dtype()
433434

434435
min_one = torch.tensor(-1.0, dtype=dtype, device=device)
435436
plus_one = torch.tensor(1.0, dtype=dtype, device=device)
@@ -441,7 +442,7 @@ def dot_similarity(self, others: "BSCTensor") -> Tensor:
441442
others_as_bipolar = others_as_bipolar.transpose(-2, -1)
442443
return torch.matmul(self_as_bipolar, others_as_bipolar)
443444

444-
def cosine_similarity(self, others: "BSCTensor") -> Tensor:
445+
def cosine_similarity(self, others: "BSCTensor", *, dtype=None) -> Tensor:
445446
"""Cosine similarity with other hypervectors."""
446447
d = self.size(-1)
447-
return self.dot_similarity(others) / d
448+
return self.dot_similarity(others, dtype=dtype) / d

torchhd/tensors/fhrr.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,7 @@ def dot_similarity(self, others: "FHRRTensor") -> Tensor:
379379
"""Inner product with other hypervectors"""
380380
if others.dim() >= 2:
381381
others = others.transpose(-2, -1)
382+
382383
return torch.real(torch.matmul(self, torch.conj(others)))
383384

384385
def cosine_similarity(self, others: "FHRRTensor", *, eps=1e-08) -> Tensor:

torchhd/tensors/hrr.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,7 @@ def dot_similarity(self, others: "HRRTensor") -> Tensor:
366366
"""Inner product with other hypervectors"""
367367
if others.dim() >= 2:
368368
others = others.transpose(-2, -1)
369+
369370
return torch.matmul(self, others)
370371

371372
def cosine_similarity(self, others: "HRRTensor", *, eps=1e-08) -> Tensor:

torchhd/tensors/map.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -340,16 +340,22 @@ def clipping(self, kappa) -> "MAPTensor":
340340

341341
return torch.clamp(self, min=-kappa, max=kappa)
342342

343-
def dot_similarity(self, others: "MAPTensor") -> Tensor:
343+
def dot_similarity(self, others: "MAPTensor", *, dtype=None) -> Tensor:
344344
"""Inner product with other hypervectors"""
345-
dtype = torch.get_default_dtype()
345+
if dtype is None:
346+
dtype = torch.get_default_dtype()
347+
346348
if others.dim() >= 2:
347349
others = others.transpose(-2, -1)
350+
348351
return torch.matmul(self.to(dtype), others.to(dtype))
349352

350-
def cosine_similarity(self, others: "MAPTensor", *, eps=1e-08) -> Tensor:
353+
def cosine_similarity(
354+
self, others: "MAPTensor", *, dtype=None, eps=1e-08
355+
) -> Tensor:
351356
"""Cosine similarity with other hypervectors"""
352-
dtype = torch.get_default_dtype()
357+
if dtype is None:
358+
dtype = torch.get_default_dtype()
353359

354360
self_dot = torch.sum(self * self, dim=-1, dtype=dtype)
355361
self_mag = torch.sqrt(self_dot)
@@ -363,4 +369,4 @@ def cosine_similarity(self, others: "MAPTensor", *, eps=1e-08) -> Tensor:
363369
magnitude = self_mag * others_mag
364370

365371
magnitude = torch.clamp(magnitude, min=eps)
366-
return self.dot_similarity(others) / magnitude
372+
return self.dot_similarity(others, dtype=dtype) / magnitude

torchhd/tests/basis_hv/test_circular_hv.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,9 @@ def test_value(self, dtype, vsa):
114114

115115
elif vsa == "FHRR":
116116
mag = hv.abs()
117-
assert torch.allclose(mag, torch.tensor(1.0, dtype=mag.dtype))
117+
assert torch.allclose(
118+
mag, torch.tensor(1.0, dtype=mag.dtype), rtol=0.0001, atol=0.0001
119+
)
118120

119121
elif vsa == "BSBC":
120122
assert torch.all((hv >= 0) | (hv < 1024)).item()

torchhd/tests/test_similarities.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,37 @@ def test_dtype(self, vsa, dtype):
198198
else:
199199
assert similarity.dtype == torch.get_default_dtype()
200200

201+
def test_custom_dtype(self):
202+
hv = functional.random(3, 100, "BSBC", block_size=1024)
203+
similarity = functional.dot_similarity(hv, hv)
204+
assert similarity.dtype == torch.get_default_dtype()
205+
206+
similarity = functional.dot_similarity(hv, hv, dtype=torch.float64)
207+
assert similarity.dtype == torch.float64
208+
209+
similarity = functional.dot_similarity(hv, hv, dtype=torch.int16)
210+
assert similarity.dtype == torch.int16
211+
212+
hv = functional.random(3, 100, "MAP")
213+
similarity = functional.dot_similarity(hv, hv)
214+
assert similarity.dtype == torch.get_default_dtype()
215+
216+
similarity = functional.dot_similarity(hv, hv, dtype=torch.float64)
217+
assert similarity.dtype == torch.float64
218+
219+
similarity = functional.dot_similarity(hv, hv, dtype=torch.int16)
220+
assert similarity.dtype == torch.int16
221+
222+
hv = functional.random(3, 100, "BSC")
223+
similarity = functional.dot_similarity(hv, hv)
224+
assert similarity.dtype == torch.get_default_dtype()
225+
226+
similarity = functional.dot_similarity(hv, hv, dtype=torch.float64)
227+
assert similarity.dtype == torch.float64
228+
229+
similarity = functional.dot_similarity(hv, hv, dtype=torch.int16)
230+
assert similarity.dtype == torch.int16
231+
201232
@pytest.mark.parametrize("vsa", vsa_tensors)
202233
@pytest.mark.parametrize("dtype", torch_dtypes)
203234
def test_device(self, vsa, dtype):

0 commit comments

Comments
 (0)