Commit 272321f

Fix wrong quantization example (#102)
* Fix wrong quantization example
* [github-action] formatting fixes

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 92d3b4a commit 272321f

1 file changed

torchhd/functional.py

Lines changed: 22 additions & 10 deletions
@@ -708,6 +708,9 @@ def negative(input: VSA_Model) -> VSA_Model:
 def soft_quantize(input: Tensor):
     """Applies the hyperbolic tanh function to all elements of the input tensor.
 
+    .. warning::
+        This function does not take the VSA model class into account.
+
     Args:
         input (Tensor): input tensor.
 
@@ -717,12 +720,15 @@ def soft_quantize(input: Tensor):
 
     Examples::
 
-        >>> x = functional.random_hv(2, 3)
-        >>> y = functional.bundle(x[0], x[1])
+        >>> x = torchhd.random_hv(2, 6)
+        >>> x
+        tensor([[ 1.,  1., -1.,  1.,  1.,  1.],
+                [ 1., -1., -1., -1.,  1., -1.]])
+        >>> y = torchhd.bundle(x[0], x[1])
         >>> y
-        tensor([0., 2., 0.])
-        >>> functional.soft_quantize(y)
-        tensor([0.0000, 0.9640, 0.0000])
+        tensor([ 2.,  0., -2.,  0.,  2.,  0.])
+        >>> torchhd.soft_quantize(y)
+        tensor([ 0.9640,  0.0000, -0.9640,  0.0000,  0.9640,  0.0000])
 
     """
     return torch.tanh(input)
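
To sanity-check the corrected doctest, here is a minimal runnable sketch. It assumes the torchhd API at this commit, where random_hv, bundle, and soft_quantize are importable from the top-level torchhd package (as the fixed example itself assumes), and that bundling two bipolar hypervectors is their element-wise sum.

import torch
import torchhd

x = torchhd.random_hv(2, 6)      # two random bipolar hypervectors of dimension 6
y = torchhd.bundle(x[0], x[1])   # element-wise sum, so entries of y lie in {-2, 0, 2}

# soft_quantize is torch.tanh (see the hunk above): tanh(+/-2) ~= +/-0.9640 and
# tanh(0) = 0, which matches the updated doctest output.
print(torchhd.soft_quantize(y))
print(torch.tanh(y))             # identical result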
@@ -731,6 +737,9 @@ def soft_quantize(input: Tensor):
 def hard_quantize(input: Tensor):
     """Applies binary quantization to all elements of the input tensor.
 
+    .. warning::
+        This function does not take the VSA model class into account.
+
     Args:
         input (Tensor): input tensor
 
@@ -740,12 +749,15 @@ def hard_quantize(input: Tensor):
 
     Examples::
 
-        >>> x = functional.random_hv(2, 3)
-        >>> y = functional.bundle(x[0], x[1])
+        >>> x = torchhd.random_hv(2, 6)
+        >>> x
+        tensor([[ 1.,  1., -1.,  1.,  1.,  1.],
+                [ 1., -1., -1., -1.,  1., -1.]])
+        >>> y = torchhd.bundle(x[0], x[1])
         >>> y
-        tensor([ 0., -2., -2.])
-        >>> functional.hard_quantize(y)
-        tensor([ 1., -1., -1.])
+        tensor([ 2.,  0., -2.,  0.,  2.,  0.])
+        >>> torchhd.hard_quantize(y)
+        tensor([ 1., -1., -1., -1.,  1., -1.])
 
     """
     # Make sure that the output tensor has the same dtype and device
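
The hunk cuts off before the body of hard_quantize, so its exact implementation is not shown here. The snippet below is only a sketch of the behavior implied by the corrected example and the trailing comment: strictly positive entries map to +1, everything else (including 0) maps to -1, and the output keeps the input's dtype and device. The helper name hard_quantize_sketch is hypothetical.

import torch

def hard_quantize_sketch(input: torch.Tensor) -> torch.Tensor:
    # Build the +1/-1 constants with the input's dtype and device so the
    # output matches them, as the trailing comment in the diff suggests.
    positive = torch.tensor(1.0, dtype=input.dtype, device=input.device)
    negative = torch.tensor(-1.0, dtype=input.dtype, device=input.device)
    return torch.where(input > 0, positive, negative)

y = torch.tensor([2., 0., -2., 0., 2., 0.])
print(hard_quantize_sketch(y))   # tensor([ 1., -1., -1., -1.,  1., -1.])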

0 commit comments
