|
23 | 23 | #
|
24 | 24 | import pytest
|
25 | 25 | import torch
|
| 26 | +import math |
26 | 27 |
|
27 | 28 | import torchhd
|
28 | 29 | from torchhd import functional
|
29 | 30 | from torchhd import embeddings
|
30 | 31 | from torchhd.tensors.hrr import HRRTensor
|
| 32 | +from torchhd.tensors.fhrr import type_conversion as fhrr_type_conversion |
31 | 33 |
|
32 | 34 |
|
33 | 35 | from .utils import (
|
@@ -540,3 +542,111 @@ def test_value(self, vsa):
|
540 | 542 | )
|
541 | 543 | > 0.99
|
542 | 544 | )
|
| 545 | + |
| 546 | + |
| 547 | +class TestFractionalPower: |
| 548 | + @pytest.mark.parametrize("vsa", vsa_tensors) |
| 549 | + def test_default_dtype(self, vsa): |
| 550 | + dimensions = 1000 |
| 551 | + embedding = 10 |
| 552 | + |
| 553 | + if vsa not in {"HRR", "FHRR"}: |
| 554 | + with pytest.raises(ValueError): |
| 555 | + embeddings.FractionalPower(embedding, dimensions, vsa=vsa) |
| 556 | + |
| 557 | + return |
| 558 | + |
| 559 | + emb = embeddings.FractionalPower(embedding, dimensions, vsa=vsa) |
| 560 | + x = torch.randn(2, embedding) |
| 561 | + y = emb(x) |
| 562 | + assert y.shape == (2, dimensions) |
| 563 | + |
| 564 | + if vsa == "HRR": |
| 565 | + assert y.dtype == torch.float32 |
| 566 | + elif vsa == "FHRR": |
| 567 | + assert y.dtype == torch.complex64 |
| 568 | + else: |
| 569 | + return |
| 570 | + |
| 571 | + @pytest.mark.parametrize("dtype", torch_dtypes) |
| 572 | + def test_dtype(self, dtype): |
| 573 | + dimensions = 1456 |
| 574 | + embedding = 2 |
| 575 | + |
| 576 | + if dtype not in {torch.float32, torch.float64}: |
| 577 | + with pytest.raises(ValueError): |
| 578 | + embeddings.FractionalPower( |
| 579 | + embedding, dimensions, vsa="HRR", dtype=dtype |
| 580 | + ) |
| 581 | + else: |
| 582 | + emb = embeddings.FractionalPower( |
| 583 | + embedding, dimensions, vsa="HRR", dtype=dtype |
| 584 | + ) |
| 585 | + |
| 586 | + x = torch.randn(13, embedding, dtype=dtype) |
| 587 | + y = emb(x) |
| 588 | + assert y.shape == (13, dimensions) |
| 589 | + assert y.dtype == dtype |
| 590 | + |
| 591 | + if dtype not in {torch.complex64, torch.complex128}: |
| 592 | + with pytest.raises(ValueError): |
| 593 | + embeddings.FractionalPower( |
| 594 | + embedding, dimensions, vsa="FHRR", dtype=dtype |
| 595 | + ) |
| 596 | + else: |
| 597 | + emb = embeddings.FractionalPower( |
| 598 | + embedding, dimensions, vsa="FHRR", dtype=dtype |
| 599 | + ) |
| 600 | + |
| 601 | + x = torch.randn(13, embedding, dtype=fhrr_type_conversion[dtype]) |
| 602 | + y = emb(x) |
| 603 | + assert y.shape == (13, dimensions) |
| 604 | + assert y.dtype == dtype |
| 605 | + |
| 606 | + def test_device(self): |
| 607 | + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| 608 | + |
| 609 | + emb = embeddings.FractionalPower(35, 1000, "gaussian", device=device) |
| 610 | + |
| 611 | + x = torchhd.random(5, 35, device=device) |
| 612 | + y = emb(x) |
| 613 | + assert y.shape == (5, 1000) |
| 614 | + assert y.device.type == device.type |
| 615 | + |
| 616 | + def test_custom_dist_iid(self): |
| 617 | + kernel_shape = torch.distributions.Normal(0, 1) |
| 618 | + band = 3.0 |
| 619 | + |
| 620 | + emb = embeddings.FractionalPower(3, 1000, kernel_shape, band) |
| 621 | + x = torch.randn(1, 3) |
| 622 | + y = emb(x) |
| 623 | + assert y.shape == (1, 1000) |
| 624 | + |
| 625 | + def test_custom_dist_2d(self): |
| 626 | + # Phase distribution for periodic Sinc kernel |
| 627 | + class HexDisc(torch.distributions.Categorical): |
| 628 | + def __init__(self): |
| 629 | + super().__init__(torch.ones(6)) |
| 630 | + self.r = 1 |
| 631 | + self.side = self.r * math.sqrt(3) / 2 |
| 632 | + self.phases = torch.tensor( |
| 633 | + [ |
| 634 | + [-self.r, 0.0], |
| 635 | + [-self.r / 2, self.side], |
| 636 | + [self.r / 2, self.side], |
| 637 | + [self.r, 0.0], |
| 638 | + [self.r / 2, -self.side], |
| 639 | + [-self.r / 2, -self.side], |
| 640 | + ] |
| 641 | + ) |
| 642 | + |
| 643 | + def sample(self, sample_shape=torch.Size()): |
| 644 | + return self.phases[super().sample(sample_shape), :] |
| 645 | + |
| 646 | + kernel_shape = HexDisc() |
| 647 | + band = 3.0 |
| 648 | + |
| 649 | + emb = embeddings.FractionalPower(2, 1000, kernel_shape, band) |
| 650 | + x = torch.randn(5, 2) |
| 651 | + y = emb(x) |
| 652 | + assert y.shape == (5, 1000) |
0 commit comments