Skip to content

Commit 2daba0b

Browse files
authored
Add Hopfield and Attention memory models (#134)
* Add hopfield networks * Add docs * Torch 2.0.0 now implements prod for bfloat16 on CPU * Fix documentation references * Add Hopfield memory class * Fix documentation links * Add basic memory tests
1 parent 08fd6f8 commit 2daba0b

File tree

7 files changed

+473
-72
lines changed

7 files changed

+473
-72
lines changed

docs/conf.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@
4545
# This pattern also affects html_static_path and html_extra_path.
4646
exclude_patterns = []
4747

48+
# maps functions with a class name that is indistinguishable when case is
49+
# ignore to another filename
50+
autosummary_filename_map = {
51+
"torchhd.memory.Hopfield": "hopfield-class",
52+
"torchhd.memory.hopfield": "hopfield-function",
53+
}
54+
4855

4956
# -- Options for HTML output -------------------------------------------------
5057

docs/getting_started.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ Similar behavior can be achieved using the classes in the :ref:`embeddings` modu
5555
5656
weight = torch.tensor([149.0])
5757
# explicit mapping of the fruit weight to an index
58-
w_i = torchhd.functional.value_to_index(weight, 0, 200, 10)
58+
w_i = torchhd.value_to_index(weight, 0, 200, 10)
5959
weights[w_i] # select representation of 149
6060
6161
whereas the :ref:`embeddings<embeddings>` have this common behavior built-in:

docs/memory.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,12 @@ torchhd.memory
1010
:template: class.rst
1111

1212
SparseDistributed
13+
Hopfield
14+
15+
.. autosummary::
16+
:toctree: generated/
17+
:template: function.rst
18+
19+
hopfield
20+
modern_hopfield
21+
attention

torchhd/functional.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,7 @@ def __call__(self, input: VSATensor, shifts: int = 1) -> VSATensor:
771771
def inverse(input: VSATensor) -> VSATensor:
772772
r"""Inverse for the binding operation.
773773
774-
See :func:`~torchhd.functional.bind`.
774+
See :func:`~torchhd.bind`.
775775
776776
Args:
777777
input (VSATensor): input hypervector
@@ -796,7 +796,7 @@ def inverse(input: VSATensor) -> VSATensor:
796796
def negative(input: VSATensor) -> VSATensor:
797797
r"""Inverse for the bundling operation.
798798
799-
See :func:`~torchhd.functional.bundle`.
799+
See :func:`~torchhd.bundle`.
800800
801801
Args:
802802
input (VSATensor): input hypervector

0 commit comments

Comments
 (0)