
Commit a711a59

Implement Binary Sparse Block Codes VSA model (#146)
* WIP segmented sparse vectors
* [github-action] formatting fixes
* Add documentation and rename
* [github-action] formatting fixes
* Update comments
* Rename to Sparse Block Codes
* Add tests and change segment size to block size
* [github-action] formatting fixes
* Add missing block_size argument
* [github-action] formatting fixes
* Update docs
* [github-action] formatting fixes
* Update naming to Binary SBC
* [github-action] formatting fixes

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent f6fdce8 commit a711a59

17 files changed · +1010 −194 lines
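Binary Sparse Block Codes (BSBC) split a hypervector into fixed-size blocks with a single active element per block. A minimal usage sketch of the interface this commit introduces (the "BSBC" model string and the block_size keyword appear in the diff below; the concrete sizes and the top-level re-exports are assumptions):

import torchhd

# Hedged sketch: the dimension and block_size values are illustrative only.
keys = torchhd.random(5, 1024, "BSBC", block_size=64)
values = torchhd.random(5, 1024, "BSBC", block_size=64)

# Binding and the other VSA operations work as with the existing models.
pairs = torchhd.bind(keys, values)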

docs/torchhd.rst

Lines changed: 1 addition & 0 deletions
@@ -85,6 +85,7 @@ VSA Models
     MAPTensor
     HRRTensor
     FHRRTensor
+    BSBCTensor
     VTBTensor

torchhd/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,7 @@
 from torchhd.tensors.map import MAPTensor
 from torchhd.tensors.hrr import HRRTensor
 from torchhd.tensors.fhrr import FHRRTensor
+from torchhd.tensors.bsbc import BSBCTensor
 from torchhd.tensors.vtb import VTBTensor
 
 from torchhd.functional import (
@@ -85,6 +86,7 @@
     "MAPTensor",
     "HRRTensor",
     "FHRRTensor",
+    "BSBCTensor",
     "VTBTensor",
     "functional",
     "embeddings",

torchhd/embeddings.py

Lines changed: 51 additions & 12 deletions
@@ -108,6 +108,7 @@ def __init__(
         sparse: bool = False,
         device=None,
         dtype=None,
+        **kwargs,
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         # Have to call Module init explicitly in order not to use the Embedding init
@@ -116,6 +117,7 @@ def __init__(
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.vsa = vsa
+        self.vsa_kwargs = kwargs
 
         if padding_idx is not None:
             if padding_idx > 0:
@@ -135,7 +137,7 @@ def __init__(
         self.sparse = sparse
 
         embeddings = functional.empty(
-            num_embeddings, embedding_dim, self.vsa, **factory_kwargs
+            num_embeddings, embedding_dim, self.vsa, **factory_kwargs, **self.vsa_kwargs
         )
         # Have to provide requires grad at the creation of the parameters to
         # prevent errors when instantiating a non-float embedding
@@ -148,7 +150,11 @@ def reset_parameters(self) -> None:
 
         with torch.no_grad():
             embeddings = functional.empty(
-                self.num_embeddings, self.embedding_dim, self.vsa, **factory_kwargs
+                self.num_embeddings,
+                self.embedding_dim,
+                self.vsa,
+                **factory_kwargs,
+                **self.vsa_kwargs,
             )
             self.weight.copy_(embeddings)
 
@@ -214,6 +220,7 @@ def __init__(
         sparse: bool = False,
         device=None,
         dtype=None,
+        **kwargs,
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         # Have to call Module init explicitly in order not to use the Embedding init
@@ -222,6 +229,7 @@ def __init__(
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.vsa = vsa
+        self.vsa_kwargs = kwargs
 
         if padding_idx is not None:
             if padding_idx > 0:
@@ -241,7 +249,7 @@ def __init__(
         self.sparse = sparse
 
         embeddings = functional.identity(
-            num_embeddings, embedding_dim, self.vsa, **factory_kwargs
+            num_embeddings, embedding_dim, self.vsa, **factory_kwargs, **self.vsa_kwargs
         )
         # Have to provide requires grad at the creation of the parameters to
         # prevent errors when instantiating a non-float embedding
@@ -254,7 +262,11 @@ def reset_parameters(self) -> None:
 
         with torch.no_grad():
             embeddings = functional.identity(
-                self.num_embeddings, self.embedding_dim, self.vsa, **factory_kwargs
+                self.num_embeddings,
+                self.embedding_dim,
+                self.vsa,
+                **factory_kwargs,
+                **self.vsa_kwargs,
             )
             self.weight.copy_(embeddings)
 
@@ -266,7 +278,7 @@ def _fill_padding_idx_with_empty(self) -> None:
         if self.padding_idx is not None:
             with torch.no_grad():
                 empty = functional.empty(
-                    1, self.embedding_dim, self.vsa, **factory_kwargs
+                    1, self.embedding_dim, self.vsa, **factory_kwargs, **self.vsa_kwargs
                 )
                 self.weight[self.padding_idx].copy_(empty.squeeze(0))
 
@@ -332,6 +344,7 @@ def __init__(
         sparse: bool = False,
         device=None,
         dtype=None,
+        **kwargs,
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         # Have to call Module init explicitly in order not to use the Embedding init
@@ -340,6 +353,7 @@ def __init__(
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.vsa = vsa
+        self.vsa_kwargs = kwargs
 
         if padding_idx is not None:
             if padding_idx > 0:
@@ -359,7 +373,7 @@ def __init__(
         self.sparse = sparse
 
         embeddings = functional.random(
-            num_embeddings, embedding_dim, self.vsa, **factory_kwargs
+            num_embeddings, embedding_dim, self.vsa, **factory_kwargs, **self.vsa_kwargs
         )
         # Have to provide requires grad at the creation of the parameters to
         # prevent errors when instantiating a non-float embedding
@@ -372,7 +386,11 @@ def reset_parameters(self) -> None:
 
         with torch.no_grad():
             embeddings = functional.random(
-                self.num_embeddings, self.embedding_dim, self.vsa, **factory_kwargs
+                self.num_embeddings,
+                self.embedding_dim,
+                self.vsa,
+                **factory_kwargs,
+                **self.vsa_kwargs,
             )
             self.weight.copy_(embeddings)
 
@@ -384,7 +402,7 @@ def _fill_padding_idx_with_empty(self) -> None:
         if self.padding_idx is not None:
             with torch.no_grad():
                 empty = functional.empty(
-                    1, self.embedding_dim, self.vsa, **factory_kwargs
+                    1, self.embedding_dim, self.vsa, **factory_kwargs, **self.vsa_kwargs
                 )
                 self.weight[self.padding_idx].copy_(empty.squeeze(0))
 
@@ -469,6 +487,7 @@ def __init__(
         sparse: bool = False,
         device=None,
         dtype=None,
+        **kwargs,
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         # Have to call Module init explicitly in order not to use the Embedding init
@@ -477,6 +496,7 @@ def __init__(
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.vsa = vsa
+        self.vsa_kwargs = kwargs
         self.low = low
         self.high = high
         self.randomness = randomness
@@ -493,6 +513,7 @@ def __init__(
             self.vsa,
             randomness=randomness,
             **factory_kwargs,
+            **self.vsa_kwargs,
         )
         # Have to provide requires grad at the creation of the parameters to
         # prevent errors when instantiating a non-float embedding
@@ -508,6 +529,7 @@ def reset_parameters(self) -> None:
                 self.vsa,
                 randomness=self.randomness,
                 **factory_kwargs,
+                **self.vsa_kwargs,
             )
             self.weight.copy_(embeddings)
 
@@ -592,6 +614,7 @@ def __init__(
         sparse: bool = False,
         device=None,
         dtype=None,
+        **kwargs,
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         # Have to call Module init explicitly in order not to use the Embedding init
@@ -600,6 +623,7 @@ def __init__(
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.vsa = vsa
+        self.vsa_kwargs = kwargs
         self.low = low
         self.high = high
 
@@ -610,7 +634,7 @@ def __init__(
         self.sparse = sparse
 
         embeddings = functional.thermometer(
-            num_embeddings, embedding_dim, self.vsa, **factory_kwargs
+            num_embeddings, embedding_dim, self.vsa, **factory_kwargs, **self.vsa_kwargs
         )
         # Have to provide requires grad at the creation of the parameters to
         # prevent errors when instantiating a non-float embedding
@@ -621,7 +645,11 @@ def reset_parameters(self) -> None:
 
         with torch.no_grad():
             embeddings = functional.thermometer(
-                self.num_embeddings, self.embedding_dim, self.vsa, **factory_kwargs
+                self.num_embeddings,
+                self.embedding_dim,
+                self.vsa,
+                **factory_kwargs,
+                **self.vsa_kwargs,
             )
             self.weight.copy_(embeddings)
 
@@ -704,6 +732,7 @@ def __init__(
         sparse: bool = False,
         device=None,
         dtype=None,
+        **kwargs,
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         # Have to call Module init explicitly in order not to use the Embedding init
@@ -712,6 +741,7 @@ def __init__(
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.vsa = vsa
+        self.vsa_kwargs = kwargs
         self.phase = phase
         self.period = period
         self.randomness = randomness
@@ -728,6 +758,7 @@ def __init__(
             self.vsa,
             randomness=randomness,
             **factory_kwargs,
+            **self.vsa_kwargs,
         )
         # Have to provide requires grad at the creation of the parameters to
         # prevent errors when instantiating a non-float embedding
@@ -743,6 +774,7 @@ def reset_parameters(self) -> None:
                 self.vsa,
                 randomness=self.randomness,
                 **factory_kwargs,
+                **self.vsa_kwargs,
             )
             self.weight.copy_(embeddings)
 
@@ -945,6 +977,7 @@ def __init__(
         device=None,
         dtype=None,
         requires_grad: bool = False,
+        **kwargs,
     ):
         factory_kwargs = {
             "device": device,
@@ -954,10 +987,16 @@ def __init__(
         super(Density, self).__init__()
 
         # A set of random vectors used as unique IDs for features of the dataset.
-        self.key = Random(in_features, out_features, vsa, **factory_kwargs)
+        self.key = Random(in_features, out_features, vsa, **factory_kwargs, **kwargs)
         # Thermometer encoding used for transforming input data.
         self.density_encoding = Thermometer(
-            out_features + 1, out_features, vsa, low=low, high=high, **factory_kwargs
+            out_features + 1,
+            out_features,
+            vsa,
+            low=low,
+            high=high,
+            **factory_kwargs,
+            **kwargs,
         )
 
     def reset_parameters(self) -> None:
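The recurring change above threads a **kwargs parameter through every embedding class into self.vsa_kwargs, so model-specific options such as block_size reach the functional constructors (empty, identity, random, level, thermometer). A hedged sketch of how this path might be exercised (the argument values are illustrative, not taken from the diff):

from torchhd import embeddings

# Assumed usage: block_size is forwarded via **kwargs to the underlying
# functional.random / functional.level calls for the BSBC model.
keys = embeddings.Random(100, 2048, vsa="BSBC", block_size=128)
levels = embeddings.Level(16, 2048, vsa="BSBC", block_size=128, low=0.0, high=1.0)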

torchhd/functional.py

Lines changed: 14 additions & 5 deletions
@@ -32,6 +32,7 @@
 from torchhd.tensors.map import MAPTensor
 from torchhd.tensors.hrr import HRRTensor
 from torchhd.tensors.fhrr import FHRRTensor
+from torchhd.tensors.bsbc import BSBCTensor
 from torchhd.tensors.vtb import VTBTensor
 from torchhd.types import VSAOptions
 
@@ -83,6 +84,8 @@ def get_vsa_tensor_class(vsa: VSAOptions) -> Type[VSATensor]:
         return HRRTensor
     elif vsa == "FHRR":
         return FHRRTensor
+    elif vsa == "BSBC":
+        return BSBCTensor
     elif vsa == "VTB":
         return VTBTensor
 
@@ -351,7 +354,10 @@ def level(
         dimensions,
         dtype=span_hv.dtype,
         device=span_hv.device,
-    )
+    ).as_subclass(vsa_tensor)
+
+    if vsa == "BSBC":
+        hv.block_size = span_hv.block_size
 
     for i in range(num_vectors):
         span_idx = int(i // levels_per_span)
@@ -372,7 +378,7 @@ def level(
         hv[i] = torch.where(threshold_v[span_idx] < t, span_start_hv, span_end_hv)
 
     hv.requires_grad = requires_grad
-    return hv.as_subclass(vsa_tensor)
+    return hv
 
 
 def thermometer(
@@ -461,7 +467,7 @@ def thermometer(
             device=rand_hv.device,
         )
     else:
-        raise ValueError(f"{vsa_tensor} HD/VSA model is not defined.")
+        raise ValueError(f"{vsa_tensor} HD/VSA model is not (yet) supported.")
 
     # Create hypervectors using the obtained step
     for i in range(1, num_vectors):
@@ -575,7 +581,10 @@ def circular(
         dimensions,
         dtype=span_hv.dtype,
         device=span_hv.device,
-    )
+    ).as_subclass(vsa_tensor)
+
+    if vsa == "BSBC":
+        hv.block_size = span_hv.block_size
 
     mutation_history = deque()
 
@@ -618,7 +627,7 @@ def circular(
         hv[i // 2] = mutation_hv
 
     hv.requires_grad = requires_grad
-    return hv.as_subclass(vsa_tensor)
+    return hv
 
 
 def bind(input: VSATensor, other: VSATensor) -> VSATensor:
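Taken together, the functional.py changes register the "BSBC" string with get_vsa_tensor_class and make level() and circular() allocate their result with as_subclass(vsa_tensor), so that for BSBC the block_size of the seed hypervectors is copied onto the result before it is filled in. A small sketch of the expected behavior (treating the block_size keyword and attribute as described by this diff; not a verified example):

from torchhd import functional
from torchhd.tensors.bsbc import BSBCTensor

# "BSBC" now dispatches to the new tensor class.
assert functional.get_vsa_tensor_class("BSBC") is BSBCTensor

# level() preallocates the result as a BSBCTensor and copies block_size
# from its seed span hypervectors before filling in the levels.
lv = functional.level(10, 1024, "BSBC", block_size=64)
assert isinstance(lv, BSBCTensor)
assert lv.block_size == 64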
