Skip to content

Commit 452120f

Browse files
VolodyaCO, Vladimir Vargas Calderón, thisac, kevinchern
authored
Add discrete variational autoencoder (#8)
* [feature] autoencoder architecture with gumbel-trick model * autoencoder architecture * dwave-hybrid is needed for the rbm to be imported * encoder and decoder architectures. main training loop from scratch * typing fixes * spins -> discrete renaming * kl divergence * updated requirements for examples to work * when encoding data into spin strings, an arbitrary number of spin strings (n_samples) per data point is now allowed * decoder is now aware of many spin strings per datapoint * annealing learning rate * fixing shapes for generating images * fixing shapes into and out of decoder when generating images * improved docstrings * remove mmd loss * Deleting example but keeping it locally * improved docstrings * renaming file. consistently using DiscreteAutoEncoder. citing paper. docstrings added * importing from all * licence header * test autoencoder * type hint fixed * docstrings improvement * using functional BCE with logits definition * docstrings improve,memt * cited DVAE paper * updated previous GRBM to new one * Remove example requirements * Remove unrequired arg in objective call * Add unittests * Rename DVAE and remove numpy from tests * Remove duplicate files include author in filename * Address minor PR review comments * Apply suggestions from code review Co-authored-by: Theodor Isacsson <[email protected]> * Apply suggestions from code review --------- Co-authored-by: Vladimir Vargas Calderón <[email protected]> Co-authored-by: Theodor Isacsson <[email protected]> Co-authored-by: kchern <[email protected]> Co-authored-by: Kevin Chern <[email protected]>
1 parent 144c9bc commit 452120f

File tree

9 files changed

+443
-24
lines changed

9 files changed

+443
-24
lines changed

dwave/plugins/torch/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@
1313
# limitations under the License.
1414
#
1515

16+
from dwave.plugins.torch.models.discrete_variational_autoencoder import *
1617
from dwave.plugins.torch.models.boltzmann_machine import *

dwave/plugins/torch/models/boltzmann_machine.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -133,15 +133,21 @@ def _setup_hidden(self):
133133
"""Preprocess some indexes to enable vectorized computation of effective fields of hidden
134134
units."""
135135
self._connected_hidden = any(
136-
a in self.hidden_nodes and b in self.hidden_nodes for a, b in self.edges)
136+
a in self.hidden_nodes and b in self.hidden_nodes for a, b in self.edges
137+
)
137138
if self._connected_hidden:
138-
err_message = "Current implementation does not support intrahidden-unit connections."
139+
err_message = (
140+
"Current implementation does not support intrahidden-unit connections."
141+
)
139142
raise NotImplementedError(err_message)
140143

141-
visible_idx = torch.tensor([self._node_to_idx[v]
142-
for v in self._nodes if v not in self.hidden_nodes], dtype=int)
144+
visible_idx = torch.tensor(
145+
[self._node_to_idx[v] for v in self._nodes if v not in self.hidden_nodes],
146+
dtype=int,
147+
)
143148
hidden_idx = torch.tensor(
144-
[i for i in torch.arange(self._n_nodes) if i not in visible_idx], dtype=int)
149+
[i for i in torch.arange(self._n_nodes) if i not in visible_idx], dtype=int
150+
)
145151
self.register_buffer("_visible_idx", visible_idx)
146152
self.register_buffer("_hidden_idx", hidden_idx)
147153

@@ -317,7 +323,8 @@ def sample(
317323
return sample_set
318324

319325
def sampleset_to_tensor(
320-
self, sample_set: SampleSet, device: Optional[torch.device] = None) -> torch.Tensor:
326+
self, sample_set: SampleSet, device: Optional[torch.device] = None
327+
) -> torch.Tensor:
321328
"""Converts a ``dimod.SampleSet`` to a ``torch.Tensor`` using the node order of the class.
322329
323330
Args:
@@ -342,7 +349,7 @@ def quasi_objective(
342349
linear_range: Optional[tuple[float, float]] = None,
343350
quadratic_range: Optional[tuple[float, float]] = None,
344351
sampler: Optional[Sampler] = None,
345-
sample_kwargs: Optional[dict] = None
352+
sample_kwargs: Optional[dict] = None,
346353
) -> torch.Tensor:
347354
"""A quasi-objective function with gradients equivalent to the gradients of the
348355
negative log likelihood.
@@ -432,20 +439,20 @@ def _compute_effective_field(self, padded: torch.Tensor) -> torch.Tensor:
432439
contribution = padded[:, self._flat_adj] * self._quadratic[self._flat_j_idx]
433440
cumulative_contribution = contribution.cumsum(1)
434441
# Don't forget to add the linear fields!
435-
h_eff = self._linear[self.hidden_idx] + cumulative_contribution[:, self._bin_idx].diff(
436-
dim=1, prepend=torch.zeros(bs, device=padded.device).unsqueeze(1)
437-
)
442+
h_eff = self._linear[self.hidden_idx] + cumulative_contribution[
443+
:, self._bin_idx
444+
].diff(dim=1, prepend=torch.zeros(bs, device=padded.device).unsqueeze(1))
438445

439446
return h_eff
440447

441448
def _approximate_expectation_sampling(
442-
self,
443-
obs: torch.Tensor,
444-
sampler: Sampler,
445-
prefactor: float,
446-
linear_range: Optional[tuple[float, float]] = None,
447-
quadratic_range: Optional[tuple[float, float]] = None,
448-
sample_kwargs: Optional[dict] = None
449+
self,
450+
obs: torch.Tensor,
451+
sampler: Sampler,
452+
prefactor: float,
453+
linear_range: Optional[tuple[float, float]] = None,
454+
quadratic_range: Optional[tuple[float, float]] = None,
455+
sample_kwargs: Optional[dict] = None,
449456
) -> torch.Tensor:
450457
"""Approximate expectation of hidden units via sampling.
451458
@@ -471,8 +478,11 @@ def _approximate_expectation_sampling(
471478
"""
472479
# Create the BQM and remove visible units
473480
bqm = BinaryQuadraticModel.from_ising(
474-
*self.to_ising(prefactor, linear_range, quadratic_range))
475-
bqm.remove_variables_from([self.idx_to_node[vidx] for vidx in self.visible_idx.tolist()])
481+
*self.to_ising(prefactor, linear_range, quadratic_range)
482+
)
483+
bqm.remove_variables_from(
484+
[self.idx_to_node[vidx] for vidx in self.visible_idx.tolist()]
485+
)
476486

477487
# Compute the effective fields for hidden units
478488
padded = self._pad(obs)
@@ -524,8 +534,10 @@ def _compute_expectation_disconnected(self, obs: torch.Tensor) -> torch.Tensor:
524534
variables in the model, i.e., number of hidden and visible units.
525535
"""
526536
if self._connected_hidden:
527-
err_msg = ("`_compute_expectation_disconnected` is not applicable when edges exist "
528-
"between hidden units.")
537+
err_msg = (
538+
"`_compute_expectation_disconnected` is not applicable when edges exist "
539+
"between hidden units."
540+
)
529541
raise ValueError(err_msg)
530542
m = self._pad(obs)
531543
h_eff = self._compute_effective_field(m)
@@ -592,8 +604,12 @@ def sufficient_statistics(self, x: torch.Tensor) -> torch.Tensor:
592604
interactions = self.interactions(x)
593605
return torch.cat([x, interactions], 1)
594606

595-
def to_ising(self, prefactor: float, linear_range: Optional[tuple[float, float]] = None,
596-
quadratic_range: Optional[tuple[float, float]] = None) -> tuple[dict, dict]:
607+
def to_ising(
608+
self,
609+
prefactor: float,
610+
linear_range: Optional[tuple[float, float]] = None,
611+
quadratic_range: Optional[tuple[float, float]] = None,
612+
) -> tuple[dict, dict]:
597613
"""Convert the model to Ising format.
598614
599615
Convert the model to Ising format with scaling (``prefactor``) followed by clipping (if
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# Copyright 2025 D-Wave
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# The use of the discrete autoencoder implementations below (including the
16+
# DiscreteVariationalAutoencoder) with a quantum computing system is
17+
# protected by the intellectual property rights of D-Wave Quantum Inc.
18+
# and its affiliates.
19+
#
20+
# The use of the discrete autoencoder implementations below (including the
21+
# DiscreteVariationalAutoencoder) with D-Wave's quantum computing
22+
# system will require access to D-Wave’s LeapTM quantum cloud service and
23+
# will be governed by the Leap Cloud Subscription Agreement available at:
24+
# https://cloud.dwavesys.com/leap/legal/cloud_subscription_agreement/
25+
#
26+
27+
from collections.abc import Callable
28+
from typing import Optional
29+
30+
import torch
31+
32+
__all__ = ["DiscreteVariationalAutoencoder"]
33+
34+
35+
class DiscreteVariationalAutoencoder(torch.nn.Module):
36+
"""Discrete variational autoencoder amenable to training discrete models as priors.
37+
See https://iopscience.iop.org/article/10.1088/2632-2153/aba220
38+
39+
Such discrete models include spin-variable models amenable for the QPU. This
40+
architecture is a modification of the standard autoencoder architecture, where
41+
the encoder outputs a latent representation of the data, and the decoder
42+
reconstructs the data from the latent representation. In our case, there is an
43+
additional step where the latent representation is mapped to a discrete
44+
representation, which is then passed to the decoder.
45+
46+
Args:
47+
encoder (torch.nn.Module): The encoder must output latents that are later on
48+
passed to ``latent_to_discrete``. An encoder has signature (x) -> l. x has
49+
shape (batch_size, f1, f2, ...) and l has shape (batch_size, l1, l2, ...).
50+
decoder (torch.nn.Module): Decodes discrete tensors into data tensors. A decoder
51+
has signature (d) -> x'. d has shape (batch_size, n, d1, d2, ...) and x' has
52+
shape (batch_size, f'1, f'2, ...); if x' is the reconstructed data then
53+
fi=f'i, but x' might be another representation of the data (e.g. in a
54+
text-to-image model, x is a sequence of tokens, and x' is an image). Note
55+
that the decoder input is of shape (batch_size, n, d1, d2, ...), where n is
56+
a number of discrete representations to be created from a single latent
57+
representation of a single initial data point.
58+
latent_to_discrete (Callable[[torch.Tensor, int], torch.Tensor] | None): A
59+
stochastic and differentiable function that maps the output of the encoder
60+
to a discrete representation (a function is deterministic by definition;
61+
here "stochastic" means the function implicitly takes additional noise
62+
variables as input). Importantly, since the function is stochastic, it
63+
allows for the creation of multiple discrete representations from the latent
64+
representation of a single data point. Thus, the signature of this function
65+
is (l, n) -> d, where l is the output of the encoder and has shape
66+
(batch_size, l1, l2, ...), n is the number of discrete representations per
67+
data point, and d has shape (batch_size, n, d1, d2, ...), which will be the
68+
input to the decoder. If None, the gumbel softmax function is used for
69+
stochasticity. Defaults to None.
70+
"""
71+
72+
def __init__(
73+
self,
74+
encoder: torch.nn.Module,
75+
decoder: torch.nn.Module,
76+
latent_to_discrete: Optional[Callable[[torch.Tensor, int], torch.Tensor]] = None,
77+
):
78+
super().__init__()
79+
self._encoder = encoder
80+
self._decoder = decoder
81+
if latent_to_discrete is None:
82+
83+
def latent_to_discrete(
84+
logits: torch.Tensor, n_samples: int
85+
) -> torch.Tensor:
86+
# Logits is of shape (batch_size, n_discrete), we assume these logits
87+
# refer to the probability of each discrete variable being 1. To use the
88+
# gumbel softmax function we need to reshape the logits to (batch_size,
89+
# n_discrete, 1), and then stack the logits to a zeros tensor of the
90+
# same shape. This is done to ensure that the gumbel softmax function
91+
# works correctly.
92+
93+
logits = logits.unsqueeze(-1)
94+
logits = torch.cat((logits, torch.zeros_like(logits)), dim=-1)
95+
# We now create a new leading dimension and repeat the logits n_samples
96+
# times:
97+
logits = logits.unsqueeze(1).repeat(1, n_samples, 1, 1)
98+
one_hots = torch.nn.functional.gumbel_softmax(
99+
logits, tau=1 / 7, hard=True
100+
)
101+
# The constant 1/7 is used because it was used in
102+
# https://iopscience.iop.org/article/10.1088/2632-2153/aba220
103+
104+
# one_hots is of shape (batch_size, n_samples, n_discrete, 2), we need
105+
# to take the first element of the last dimension and convert it to spin
106+
# variables to make the latent space compatible with QPU models.
107+
return one_hots[..., 0] * 2 - 1
108+
109+
self._latent_to_discrete = latent_to_discrete
110+
111+
@property
112+
def encoder(self):
113+
"""Encoder network that maps input data to latent representations."""
114+
return self._encoder
115+
116+
@property
117+
def decoder(self):
118+
"""Decoder network that maps discrete representations to data tensors."""
119+
return self._decoder
120+
121+
@property
122+
def latent_to_discrete(self):
123+
"""Function that maps the output of the encoder to a discrete representation."""
124+
return self._latent_to_discrete
125+
126+
def forward(
127+
self, x: torch.Tensor, n_samples: int = 1
128+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
129+
"""Ingests data into the :class:`DiscreteVariationalAutoencoder`.
130+
131+
Args:
132+
x (torch.Tensor): Input data of shape (batch_size, ...).
133+
n_samples (int, optional): Since the ``latent_to_discrete`` map is, in
134+
general, stochastic (see :class:`DiscreteVariationalAutoencoder` for more on this),
135+
several different discrete samples can be obtained by applying this map
136+
to the same encoded data point. This argument specifies how many such
137+
samples are obtained. Defaults to 1.
138+
139+
Returns:
140+
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The logits, which are the
141+
encoded data of shape (batch_size, ...), the discrete representation(s) of the
142+
encoded data with the shape (batch_size, n_samples, ...), and the
143+
reconstructed data of shape (batch_size, n_samples, ...).
144+
"""
145+
latents = self.encoder(x)
146+
discretes = self.latent_to_discrete(latents, n_samples)
147+
xhat = self.decoder(discretes)
148+
return latents, discretes, xhat
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2025 D-Wave
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from dwave.plugins.torch.models.losses.kl_divergence import *
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2025 D-Wave
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Optional
16+
17+
import torch
18+
from dimod import Sampler
19+
20+
from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine
21+
22+
__all__ = ["pseudo_kl_divergence_loss"]
23+
24+
25+
def pseudo_kl_divergence_loss(
26+
spins: torch.Tensor,
27+
logits: torch.Tensor,
28+
boltzmann_machine: GraphRestrictedBoltzmannMachine,
29+
sampler: Sampler,
30+
sample_kwargs: dict,
31+
prefactor: Optional[float] = None,
32+
linear_range: Optional[tuple[float, float]] = None,
33+
quadratic_range: Optional[tuple[float, float]] = None,
34+
):
35+
"""A pseudo Kullback-Leibler divergence loss function for a discrete autoencoder with a
36+
Boltzmann machine prior.
37+
38+
This is not the true KL divergence, but the gradient of this function is the same as
39+
the KL divergence gradient. See https://arxiv.org/abs/1609.02200 for more details.
40+
41+
Args:
42+
spins (torch.Tensor): A tensor of spins of shape (batch_size, n_spins) or shape
43+
(batch_size, n_samples, n_spins) obtained from a stochastic function that
44+
maps the output of the encoder (logit representation) to a spin
45+
representation.
46+
logits (torch.Tensor): A tensor of logits of shape (batch_size, n_spins). These
47+
logits are the raw output of the encoder.
48+
boltzmann_machine (GraphRestrictedBoltzmannMachine): An instance of a Boltzmann
49+
machine.
50+
sampler (Sampler): A sampler used for generating samples.
51+
sample_kwargs (dict): Additional keyword arguments for the ``sampler.sample``
52+
method.
53+
prefactor (float, optional): A scaling applied to the Hamiltonian weights
54+
(linear and quadratic weights). When None, no scaling is applied. Defaults
55+
to None.
56+
linear_range (tuple[float, float], optional): Linear weights are clipped to
57+
``linear_range`` prior to sampling. This clipping occurs after the
58+
``prefactor`` scaling has been applied. When None, no clipping is applied.
59+
Defaults to None.
60+
quadratic_range (tuple[float, float], optional): Quadratic weights are clipped
61+
to ``quadratic_range`` prior to sampling. This clipping occurs after the
62+
``prefactor`` scaling has been applied. When None, no clipping is applied.
63+
Defaults to None.
64+
65+
Returns:
66+
torch.Tensor: The computed pseudo KL divergence loss.
67+
"""
68+
samples = boltzmann_machine.sample(
69+
sampler=sampler,
70+
device=spins.device,
71+
prefactor=prefactor if prefactor is not None else 1.0,
72+
linear_range=linear_range,
73+
quadratic_range=quadratic_range,
74+
sample_params=sample_kwargs,
75+
)
76+
probabilities = torch.sigmoid(logits)
77+
entropy = torch.nn.functional.binary_cross_entropy_with_logits(logits, probabilities)
78+
cross_entropy = boltzmann_machine.quasi_objective(spins, samples)
79+
pseudo_kl_divergence = cross_entropy - entropy
80+
return pseudo_kl_divergence

dwave/plugins/torch/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222

2323

2424
def sampleset_to_tensor(
25-
ordered_vars: list, sample_set: SampleSet, device: torch.device = None) -> torch.Tensor:
25+
ordered_vars: list, sample_set: SampleSet, device: Optional[torch.device] = None
26+
) -> torch.Tensor:
2627
"""Converts a ``dimod.SampleSet`` to a ``torch.Tensor``.
2728
2829
Args:

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ dependencies = [
3131
"networkx",
3232
"dimod",
3333
"dwave-system",
34+
"dwave-hybrid",
3435
]
3536

3637
[project.readme]

tests/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
coverage
22
codecov
3+
parameterized

0 commit comments

Comments
 (0)