Commit 234edb9

Update tutorial with new (working) example code (#168)
Fixes #167
1 parent 549ad53 commit 234edb9

2 files changed: +51 -40 lines changed

docs/classification.rst

Lines changed: 49 additions & 40 deletions
@@ -4,7 +4,7 @@ HDC Learning
 After learning about representing and manipulating information in hyperspace, we can implement our first HDC classification model! We will use as an example the famous MNIST dataset that contains images of handwritten digits.


-We start by importing Torchhd and any other libraries we need:
+We start by importing Torchhd and the other libraries we need, in addition to specifying the training parameters:

 .. code-block:: python

@@ -13,11 +13,22 @@ We start by importing Torchhd and any other libraries we need:
     import torch.nn.functional as F
     import torchvision
     from torchvision.datasets import MNIST
+    # Note: this example requires the torchmetrics library: https://torchmetrics.readthedocs.io
     import torchmetrics
-
-    from torchhd import functional
+
+    import torchhd
+    from torchhd.models import Centroid
     from torchhd import embeddings

+    # Use the GPU if available
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    print("Using {} device".format(device))
+
+    DIMENSIONS = 10000
+    IMG_SIZE = 28
+    NUM_LEVELS = 1000
+    BATCH_SIZE = 1  # for GPUs with enough memory we can process multiple images at once
+
 Datasets
 --------

@@ -34,55 +45,46 @@ Next, we load the training and testing datasets:
     test_ld = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)


-In addition to the various datasets available in the Torch ecosystem, such as MNIST, the :ref:`datasets` module provides interface to several commonly used datasets in HDC. Such interfaces inherit from PyTorch's dataset class, ensuring interoperability with other datasets.
+In addition to the various datasets available in the Torch ecosystem, such as MNIST, the :ref:`datasets` module provides an interface to several commonly used datasets in HDC. Such interfaces inherit from PyTorch's dataset class, ensuring interoperability with other datasets.

 Training
 --------

-To perform the training, we start by defining a model. In addition to specifying the basis-hypervectors sets, the core part of the model is the encoding function. In the example below, we use random-hypervectors and level-hypervectors to encode the position and value of each pixel, respectively:
+To perform the training, we start by defining an encoding. In addition to specifying the basis-hypervectors sets, a core part of learning is the encoding function. In the example below, we use random-hypervectors and level-hypervectors to encode the position and value of each pixel, respectively:

 .. code-block:: python

-    class Model(nn.Module):
-        def __init__(self, num_classes, size):
-            super(Model, self).__init__()
-
-            self.flatten = torch.nn.Flatten()
-
-            self.position = embeddings.Random(size * size, DIMENSIONS)
-            self.value = embeddings.Level(NUM_LEVELS, DIMENSIONS)
-
-            self.classify = nn.Linear(DIMENSIONS, num_classes, bias=False)
-            self.classify.weight.data.fill_(0.0)
+    class Encoder(nn.Module):
+        def __init__(self, out_features, size, levels):
+            super(Encoder, self).__init__()
+            self.flatten = torch.nn.Flatten()
+            self.position = embeddings.Random(size * size, out_features)
+            self.value = embeddings.Level(levels, out_features)

-        def encode(self, x):
-            x = self.flatten(x)
-            sample_hv = functional.bind(self.position.weight, self.value(x))
-            sample_hv = functional.multiset(sample_hv)
-            return functional.hard_quantize(sample_hv)
+        def forward(self, x):
+            x = self.flatten(x)
+            sample_hv = torchhd.bind(self.position.weight, self.value(x))
+            sample_hv = torchhd.multiset(sample_hv)
+            return torchhd.hard_quantize(sample_hv)

-        def forward(self, x):
-            enc = self.encode(x)
-            logit = self.classify(enc)
-            return logit
+    encode = Encoder(DIMENSIONS, IMG_SIZE, NUM_LEVELS)
+    encode = encode.to(device)

-
-    model = Model(len(train_ds.classes), IMG_SIZE)
+    num_classes = len(train_ds.classes)
+    model = Centroid(DIMENSIONS, num_classes)
     model = model.to(device)

-
 Having defined the model, we iterate over the training samples to create the class-vectors:

 .. code-block:: python

-    for samples, labels in train_ld:
-        samples = samples.to(device)
-        labels = labels.to(device)
-
-        samples_hv = model.encode(samples)
-        model.classify.weight[labels] += samples_hv
+    with torch.no_grad():
+        for samples, labels in tqdm(train_ld, desc="Training"):
+            samples = samples.to(device)
+            labels = labels.to(device)

-    model.classify.weight[:] = F.normalize(model.classify.weight)
+            samples_hv = encode(samples)
+            model.add(samples_hv, labels)

 Testing
 -------
@@ -91,9 +93,16 @@ With the model trained, we can classify the testing samples by encoding them and

 .. code-block:: python

-    for samples, labels in test_ld:
-        samples = samples.to(device)
+    accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes)
+
+    with torch.no_grad():
+        model.normalize()
+
+        for samples, labels in tqdm(test_ld, desc="Testing"):
+            samples = samples.to(device)
+
+            samples_hv = encode(samples)
+            outputs = model(samples_hv, dot=True)
+            accuracy.update(outputs.cpu(), labels)

-        outputs = model(samples)
-        predictions = torch.argmax(outputs, dim=-1)
-        accuracy.update(predictions.cpu(), labels)
+    print(f"Testing accuracy of {(accuracy.compute().item() * 100):.3f}%")

docs/docutils.conf

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+[restructuredtext parser]
+tab_width: 4

0 commit comments
