Skip to content

Commit b8f0460

Browse files
mikeheddesperevergesigordeoliveiranunes
authored
Support explicit VSA models enabling customization of operations (#91)
* Implement VSA Model class * Update README * Add tests for hypervector creation * test operations * Update tests * Create documentation outline * Fix short circuit tests * Update examples * Add automatic model conversion * WIP documentation * Refactor as_vsa_model function for automatic conversion * Update model references * Use backward supported typing * WIP BSC documentation * add documentation for the MAP VSA model * doc fhrr * Fix pytorch version install and empty BSC * Lower minimum torch version * Update examples * Increment version * Update examples * Deterministic tests Co-authored-by: pereverges <[email protected]> Co-authored-by: igordeoliveiranunes <[email protected]>
1 parent daeb6f0 commit b8f0460

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+3033
-1686
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,19 +65,19 @@ usd, mxn = torchhd.random_hv(2, d) # US Dollar and Mexican Peso
6565

6666
# create country representations
6767
us_values = torch.stack([usa, wdc, usd])
68-
us = torchhd.functional.hash_table(keys, us_values)
68+
us = torchhd.hash_table(keys, us_values)
6969

7070
mx_values = torch.stack([mex, mxc, mxn])
71-
mx = torchhd.functional.hash_table(keys, mx_values)
71+
mx = torchhd.hash_table(keys, mx_values)
7272

7373
# combine all the associated information
74-
mx_us = torchhd.bind(us, mx)
74+
mx_us = torchhd.bind(torchhd.inverse(us), mx)
7575

7676
# query for the dollar of mexico
7777
usd_of_mex = torchhd.bind(mx_us, usd)
7878

7979
memory = torch.cat([keys, us_values, mx_values], dim=0)
80-
torchhd.functional.cosine_similarity(usd_of_mex, memory)
80+
torchhd.cos_similarity(usd_of_mex, memory)
8181
# tensor([-0.0062, 0.0123, -0.0057, -0.0019, -0.0084, -0.0078, 0.0102, 0.0057, 0.3292])
8282
# The hypervector for the Mexican Peso is the most similar.
8383
```

dev-requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
torch
1+
torch>=1.9.0
22
torchvision
33
pandas
44
requests

docs/_static/css/custom.css

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,4 +242,34 @@ body,
242242

243243
.wy-nav-content-wrap {
244244
background: none;
245+
}
246+
247+
.body,
248+
.rst-content .toctree-wrapper>p.caption,
249+
h1,
250+
h2,
251+
h3,
252+
h4,
253+
h5,
254+
h6,
255+
legend {
256+
font-family: -apple-system, BlinkMacSystemFont, arial, sans-serif !important;
257+
}
258+
259+
.body {
260+
font-synthesis: none;
261+
text-rendering: optimizeLegibility;
262+
-webkit-font-smoothing: antialiased;
263+
font-feature-settings: "liga", "case", "calt";
264+
}
265+
266+
.btn {
267+
border: None;
268+
box-shadow: None;
269+
padding: 8px 12px;
270+
}
271+
272+
.btn:active {
273+
box-shadow: None;
274+
padding: 8px 12px;
245275
}

docs/index.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Welcome to the Torchhd documentation!
22
=====================================
33

4-
Torchhd is a Python library dedicated to *Hyperdimensional Computing* (also knwon as *Vector Symbolic Architectures*).
4+
Torchhd is a Python library dedicated to *Hyperdimensional Computing* (also known as *Vector Symbolic Architectures*).
55

66
.. toctree::
77
:glob:
@@ -15,7 +15,7 @@ Torchhd is a Python library dedicated to *Hyperdimensional Computing* (also knwo
1515
:maxdepth: 2
1616
:caption: Package Reference:
1717

18-
functional
18+
torchhd
1919
embeddings
2020
structures
2121
datasets

docs/functional.rst renamed to docs/torchhd.rst

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
.. _functional:
22

3-
torchhd.functional
3+
torchhd
44
=========================
55

6-
.. currentmodule:: torchhd.functional
6+
.. currentmodule:: torchhd
77

88
This module consists of the basic hypervector generation functions and operations used on hypervectors.
99

@@ -14,6 +14,7 @@ Basis-hypervector sets
1414
:toctree: generated/
1515
:template: function.rst
1616

17+
empty_hv
1718
identity_hv
1819
random_hv
1920
level_hv
@@ -28,9 +29,10 @@ Operations
2829
:template: function.rst
2930

3031
bind
31-
unbind
3232
bundle
3333
permute
34+
inverse
35+
negative
3436
cleanup
3537
randsel
3638
multirandsel
@@ -44,9 +46,8 @@ Similarities
4446
.. autosummary::
4547
:toctree: generated/
4648
:template: function.rst
47-
4849

49-
cosine_similarity
50+
cos_similarity
5051
dot_similarity
5152
hamming_similarity
5253

@@ -68,14 +69,28 @@ Encodings
6869
graph
6970

7071

72+
VSA Models
73+
------------------------
74+
75+
.. autosummary::
76+
:toctree: generated/
77+
:template: class.rst
78+
79+
VSA_Model
80+
BSC
81+
MAP
82+
.. HRR
83+
FHRR
84+
85+
7186
Utilities
7287
------------------------
7388

7489
.. autosummary::
7590
:toctree: generated/
7691
:template: function.rst
77-
7892

93+
as_vsa_model
7994
map_range
8095
value_to_index
8196
index_to_value

examples/emg_hand_gestures.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torchmetrics
88
from tqdm import tqdm
99

10-
from torchhd import functional
10+
import torchhd
1111
from torchhd import embeddings
1212
from torchhd.datasets import EMGHandGestures
1313

@@ -40,12 +40,12 @@ def __init__(self, num_classes, timestamps, channels):
4040

4141
def encode(self, x: torch.Tensor) -> torch.Tensor:
4242
signal = self.signals(x)
43-
samples = functional.bind(signal, self.channels.weight.unsqueeze(0))
44-
samples = functional.bind(signal, self.timestamps.weight.unsqueeze(1))
43+
samples = torchhd.bind(signal, self.channels.weight.unsqueeze(0))
44+
samples = torchhd.bind(signal, self.timestamps.weight.unsqueeze(1))
4545

46-
samples = functional.multiset(samples)
47-
sample_hv = functional.ngrams(samples, n=N_GRAM_SIZE)
48-
return functional.hard_quantize(sample_hv)
46+
samples = torchhd.multiset(samples)
47+
sample_hv = torchhd.ngrams(samples, n=N_GRAM_SIZE)
48+
return torchhd.hard_quantize(sample_hv)
4949

5050
def forward(self, x: torch.Tensor) -> torch.Tensor:
5151
enc = self.encode(x)

examples/graphhd.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
from tqdm import tqdm
55

66
# Note: this example requires the torch_geometric library: https://pytorch-geometric.readthedocs.io
7-
import torch_geometric
7+
from torch_geometric.datasets import TUDataset
88

99
# Note: this example requires the torchmetrics library: https://torchmetrics.readthedocs.io
1010
import torchmetrics
1111

12-
from torchhd import functional
12+
import torchhd
1313
from torchhd import embeddings
1414

1515
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -20,7 +20,7 @@
2020
# for other available datasets see: https://pytorch-geometric.readthedocs.io/en/latest/notes/data_cheatsheet.html?highlight=tudatasets
2121
dataset = "MUTAG"
2222

23-
graphs = torch_geometric.datasets.TUDataset("data", dataset)
23+
graphs = TUDataset("../data", dataset)
2424
train_size = int(0.7 * len(graphs))
2525
test_size = len(graphs) - train_size
2626
train_ld, test_ld = torch.utils.data.random_split(graphs, [train_size, test_size])
@@ -98,8 +98,8 @@ def encode(self, x):
9898

9999
row, col = to_undirected(x.edge_index)
100100

101-
hvs = functional.bind(node_id_hvs[row], node_id_hvs[col])
102-
return functional.multiset(hvs)
101+
hvs = torchhd.bind(node_id_hvs[row], node_id_hvs[col])
102+
return torchhd.multiset(hvs)
103103

104104
def forward(self, x):
105105
enc = self.encode(x)

examples/hd_hashing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Note: this example requires the mmh3 library: https://github.com/hajimes/mmh3
44
import mmh3
55

6-
from torchhd import functional
6+
import torchhd
77

88

99
class HDHashing:
@@ -14,7 +14,7 @@ def __init__(self, levels: int, dimensions: int, device=None):
1414
self.dimensions = dimensions
1515
self.device = device
1616

17-
self.hvs = functional.circular_hv(levels, dimensions, device=device)
17+
self.hvs = torchhd.circular_hv(levels, dimensions, device=device)
1818
self.servers = []
1919
self.server_hvs = []
2020
self.weight_by_server = {}
@@ -35,7 +35,7 @@ def request(self, value: str):
3535
# The next three lines simulate associative memory in HDC
3636
# It returns the value at the memory location (server)
3737
# that is most similar to the requested location (request).
38-
similarity = functional.dot_similarity(hv, server_hvs)
38+
similarity = torchhd.dot_similarity(hv, server_hvs)
3939
server_idx = torch.argmax(similarity).item()
4040
return self.servers[server_idx]
4141

examples/language_recognition.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torchmetrics
88
from tqdm import tqdm
99

10-
from torchhd import functional
10+
import torchhd
1111
from torchhd import embeddings
1212
from torchhd.datasets import EuropeanLanguages as Languages
1313

@@ -65,8 +65,8 @@ def __init__(self, num_classes, size):
6565

6666
def encode(self, x):
6767
symbols = self.symbol(x)
68-
sample_hv = functional.ngrams(symbols, n=3)
69-
return functional.hard_quantize(sample_hv)
68+
sample_hv = torchhd.ngrams(symbols, n=3)
69+
return torchhd.hard_quantize(sample_hv)
7070

7171
def forward(self, x):
7272
enc = self.encode(x)

examples/mnist.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torchmetrics
99
from tqdm import tqdm
1010

11-
from torchhd import functional
11+
import torchhd
1212
from torchhd import embeddings
1313

1414

@@ -43,9 +43,9 @@ def __init__(self, num_classes, size):
4343

4444
def encode(self, x):
4545
x = self.flatten(x)
46-
sample_hv = functional.bind(self.position.weight, self.value(x))
47-
sample_hv = functional.multiset(sample_hv)
48-
return functional.hard_quantize(sample_hv)
46+
sample_hv = torchhd.bind(self.position.weight, self.value(x))
47+
sample_hv = torchhd.multiset(sample_hv)
48+
return torchhd.hard_quantize(sample_hv)
4949

5050
def forward(self, x):
5151
enc = self.encode(x)

0 commit comments

Comments
 (0)