Commit 8c35951

Fix dtype validation error in the Fractional Power Encoding (#148)
* Fix dtype check and add PR template
* Update template
* Fix dtype checking
* [github-action] formatting fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 77c9b91 commit 8c35951
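
As a quick illustration of what this fix enables, here is a minimal sketch based on the new tests added in this commit (argument values and expected dtypes follow those tests):

```python
import torch
from torchhd import embeddings

# Leaving dtype unset previously tripped the validation check, because None is
# never in the model's supported_dtypes; after this fix it falls through to the default.
emb = embeddings.FractionalPower(10, 1000, vsa="FHRR")
y = emb(torch.randn(2, 10))
print(y.shape, y.dtype)  # torch.Size([2, 1000]) torch.complex64

# Passing an explicit complex dtype for FHRR is now accepted; the internal
# phase weights are created with the matching real dtype.
emb64 = embeddings.FractionalPower(10, 1000, vsa="FHRR", dtype=torch.complex64)
print(emb64(torch.randn(2, 10)).dtype)  # torch.complex64
```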

File tree

3 files changed (+130, -3 lines):

- .github/pull_request_template.md
- torchhd/embeddings.py
- torchhd/tests/test_embeddings.py

.github/pull_request_template.md

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+<!--- Provide a short summary of your changes in the Title above -->
+
+## Description
+<!--- Describe your changes in detail -->
+<!-- Link the issue (if any) that will be resolved by the changes -->
+
+
+
+## Checklist
+- [ ] I added/updated documentation for the changes.
+- [ ] I have thoroughly tested the changes.

torchhd/embeddings.py

Lines changed: 9 additions & 3 deletions
@@ -32,7 +32,7 @@
 import torchhd.functional as functional
 from torchhd.tensors.base import VSATensor
 from torchhd.tensors.map import MAPTensor
-from torchhd.tensors.fhrr import FHRRTensor
+from torchhd.tensors.fhrr import FHRRTensor, type_conversion as fhrr_type_conversion
 from torchhd.tensors.hrr import HRRTensor
 from torchhd.types import VSAOptions

@@ -1017,7 +1017,6 @@ def __init__(
         dtype=None,
         requires_grad: bool = False,
     ) -> None:
-        factory_kwargs = {"device": device, "dtype": dtype}
         super(FractionalPower, self).__init__()

         self.in_features = in_features  # data dimensions

@@ -1032,9 +1031,16 @@ def __init__(

         self.vsa_tensor = functional.get_vsa_tensor_class(vsa)

-        if dtype not in self.vsa_tensor.supported_dtypes:
+        # If a specific dtype is specified make sure it is supported by the VSA model
+        if dtype != None and dtype not in self.vsa_tensor.supported_dtypes:
             raise ValueError(f"dtype {dtype} not supported by {vsa}")

+        # The internal weights/phases are stored as floats even if the output is a complex tensor
+        if dtype != None and vsa == "FHRR":
+            dtype = fhrr_type_conversion[dtype]
+
+        factory_kwargs = {"device": device, "dtype": dtype}
+
         # If the distribution is a string use the presets in predefined_kernels
         if isinstance(distribution, str):
             try:
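
The essence of the change is that dtype validation and the FHRR complex-to-real conversion now both happen before factory_kwargs is built. Below is a standalone sketch of that logic; the contents of fhrr_type_conversion are an assumption here, inferred from the comment that the phases are stored as floats (the real table lives in torchhd/tensors/fhrr.py).

```python
import torch

# Assumed mapping (not shown in this diff): complex output dtypes -> real dtype
# used for the internally stored phase weights.
fhrr_type_conversion = {
    torch.complex64: torch.float32,
    torch.complex128: torch.float64,
}


def resolve_factory_dtype(dtype, vsa, supported_dtypes):
    """Illustrative mirror of the validation logic added in this commit."""
    # Only validate an explicitly requested dtype; None means "use the default".
    if dtype is not None and dtype not in supported_dtypes:
        raise ValueError(f"dtype {dtype} not supported by {vsa}")
    # FHRR outputs complex hypervectors but stores real-valued phases internally.
    if dtype is not None and vsa == "FHRR":
        dtype = fhrr_type_conversion[dtype]
    return dtype


# resolve_factory_dtype(None, "FHRR", {torch.complex64, torch.complex128})            -> None
# resolve_factory_dtype(torch.complex64, "FHRR", {torch.complex64, torch.complex128}) -> torch.float32
```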

torchhd/tests/test_embeddings.py

Lines changed: 110 additions & 0 deletions
@@ -23,11 +23,13 @@
 #
 import pytest
 import torch
+import math

 import torchhd
 from torchhd import functional
 from torchhd import embeddings
 from torchhd.tensors.hrr import HRRTensor
+from torchhd.tensors.fhrr import type_conversion as fhrr_type_conversion


 from .utils import (

@@ -540,3 +542,111 @@ def test_value(self, vsa):
             )
             > 0.99
         )
+
+
+class TestFractionalPower:
+    @pytest.mark.parametrize("vsa", vsa_tensors)
+    def test_default_dtype(self, vsa):
+        dimensions = 1000
+        embedding = 10
+
+        if vsa not in {"HRR", "FHRR"}:
+            with pytest.raises(ValueError):
+                embeddings.FractionalPower(embedding, dimensions, vsa=vsa)
+
+            return
+
+        emb = embeddings.FractionalPower(embedding, dimensions, vsa=vsa)
+        x = torch.randn(2, embedding)
+        y = emb(x)
+        assert y.shape == (2, dimensions)
+
+        if vsa == "HRR":
+            assert y.dtype == torch.float32
+        elif vsa == "FHRR":
+            assert y.dtype == torch.complex64
+        else:
+            return
+
+    @pytest.mark.parametrize("dtype", torch_dtypes)
+    def test_dtype(self, dtype):
+        dimensions = 1456
+        embedding = 2
+
+        if dtype not in {torch.float32, torch.float64}:
+            with pytest.raises(ValueError):
+                embeddings.FractionalPower(
+                    embedding, dimensions, vsa="HRR", dtype=dtype
+                )
+        else:
+            emb = embeddings.FractionalPower(
+                embedding, dimensions, vsa="HRR", dtype=dtype
+            )
+
+            x = torch.randn(13, embedding, dtype=dtype)
+            y = emb(x)
+            assert y.shape == (13, dimensions)
+            assert y.dtype == dtype
+
+        if dtype not in {torch.complex64, torch.complex128}:
+            with pytest.raises(ValueError):
+                embeddings.FractionalPower(
+                    embedding, dimensions, vsa="FHRR", dtype=dtype
+                )
+        else:
+            emb = embeddings.FractionalPower(
+                embedding, dimensions, vsa="FHRR", dtype=dtype
+            )
+
+            x = torch.randn(13, embedding, dtype=fhrr_type_conversion[dtype])
+            y = emb(x)
+            assert y.shape == (13, dimensions)
+            assert y.dtype == dtype
+
+    def test_device(self):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        emb = embeddings.FractionalPower(35, 1000, "gaussian", device=device)
+
+        x = torchhd.random(5, 35, device=device)
+        y = emb(x)
+        assert y.shape == (5, 1000)
+        assert y.device.type == device.type
+
+    def test_custom_dist_iid(self):
+        kernel_shape = torch.distributions.Normal(0, 1)
+        band = 3.0
+
+        emb = embeddings.FractionalPower(3, 1000, kernel_shape, band)
+        x = torch.randn(1, 3)
+        y = emb(x)
+        assert y.shape == (1, 1000)
+
+    def test_custom_dist_2d(self):
+        # Phase distribution for periodic Sinc kernel
+        class HexDisc(torch.distributions.Categorical):
+            def __init__(self):
+                super().__init__(torch.ones(6))
+                self.r = 1
+                self.side = self.r * math.sqrt(3) / 2
+                self.phases = torch.tensor(
+                    [
+                        [-self.r, 0.0],
+                        [-self.r / 2, self.side],
+                        [self.r / 2, self.side],
+                        [self.r, 0.0],
+                        [self.r / 2, -self.side],
+                        [-self.r / 2, -self.side],
+                    ]
+                )
+
+            def sample(self, sample_shape=torch.Size()):
+                return self.phases[super().sample(sample_shape), :]
+
+        kernel_shape = HexDisc()
+        band = 3.0
+
+        emb = embeddings.FractionalPower(2, 1000, kernel_shape, band)
+        x = torch.randn(5, 2)
+        y = emb(x)
+        assert y.shape == (5, 1000)
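
The last two tests also exercise custom kernel shapes passed as torch.distributions objects. A minimal usage sketch along the same lines (the Laplace distribution is my own choice, not part of this commit; the positional arguments mirror the tests above):

```python
import torch
from torchhd import embeddings

# Any torch.distributions object with a sample() method appears usable as the
# kernel shape; the tests use Normal(0, 1) and a custom hexagonal phase distribution.
kernel_shape = torch.distributions.Laplace(0.0, 1.0)
emb = embeddings.FractionalPower(3, 1000, kernel_shape, 3.0)  # in_features, out_features, kernel shape, bandwidth

x = torch.randn(5, 3)
y = emb(x)
print(y.shape)  # torch.Size([5, 1000])
```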
