Conversation

@anahitamansouri
Collaborator

This PR adds a Restricted Boltzmann Machine (RBM) model to the plugin. It includes:

  • Implementation of the RBM in boltzmann_machine.py
  • Tests in tests/test_boltzmann_machine.py
  • An example in examples/rbm_image_generation.py demonstrating the usage of this RBM for image generation on MNIST.

@anahitamansouri self-assigned this Nov 27, 2025
@anahitamansouri added the enhancement label Nov 27, 2025
Contributor

@thisac left a comment


Thanks @anahitamansouri. Overall, looks really great; very clean code, docstrings and tests. 🎉 A few comments:

  • Missing release notes entry.
  • Is this general enough to just be called an RBM, or should it be named something more specific? Also, the docstring could maybe be expanded a bit to explain what makes it different from the GRBM and what PCD is. The GRBM docstring could also be expanded a bit IMO, but that's a separate issue.
  • Both the init and the sampling methods differ quite significantly from the GRBM. Could/should they work similarly (e.g., name both sampling methods sample() and/or allow for more similar init signatures)?

I'll have a more thorough look at the tests but at a glance they look quite nice!

Comment on lines +688 to +700
# Initialize model parameters
# initialize weights
self._weights = torch.nn.Parameter(
    0.1 * torch.randn(n_visible, n_hidden)
)
# initialize visible units biases.
self._visible_biases = torch.nn.Parameter(
    0.5 * torch.ones(n_visible)
)
# initialize hidden units biases.
self._hidden_biases = torch.nn.Parameter(
    0.5 * torch.ones(n_hidden)
)
Contributor


Where do the 0.1, 0.5 and 0.5 values come from? Should they be hard-coded?

Collaborator


RE: Where do the 0.1, 0.5 and 0.5 values come from? Should they be hard-coded?

I set those values arbitrarily for GRBM. There should be a better initialization scheme for RBMs, e.g., section 8
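If that is section 8 of Hinton's "A Practical Guide to Training Restricted Boltzmann Machines", the scheme described there would look roughly like the sketch below. The helper name and the data_mean argument (the fraction of training vectors in which each visible unit is on) are illustrative, not code from this PR:

import torch


def hinton_init(n_visible: int, n_hidden: int, data_mean: torch.Tensor):
    """Rough sketch of the initialization from sec. 8 of Hinton's practical guide."""
    # Small zero-mean Gaussian weights with std 0.01.
    weights = torch.nn.Parameter(0.01 * torch.randn(n_visible, n_hidden))
    # Visible bias b_i = log(p_i / (1 - p_i)), where p_i is the fraction of
    # training vectors in which unit i is on; clamping avoids log(0).
    p = data_mean.clamp(1e-3, 1 - 1e-3)
    visible_biases = torch.nn.Parameter(torch.log(p / (1 - p)))
    # Hidden biases start at zero.
    hidden_biases = torch.nn.Parameter(torch.zeros(n_hidden))
    return weights, visible_biases, hidden_biases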

Collaborator Author


These values worked for the image generation example. Sure, I can experiment with 0.01 as suggested in the guide and will let you know how it affects the performance.

Collaborator


I've found the initialisation of the GRBM to be not great for my experiments, so I've had to pass initial linear and quadratic weights. @kevinchern in your experience, have you had to do the same? If so, should we change the default initialisation?

Collaborator

@kevinchern commented on Dec 1, 2025


@VolodyaCO good point. I've had similar experiences and found setting initial weights to 0 to be robust in general. Could you create an issue for this?

Edit: actually, I'll do it now.

Edit 2: here's the issue; please add more details as you see fit: #48

Comment on lines +725 to +753
@property
def visible_biases(self) -> torch.Tensor:
    """Visible biases of the RBM."""
    return self._visible_biases

@property
def hidden_biases(self) -> torch.Tensor:
    """Hidden biases of the RBM."""
    return self._hidden_biases

@property
def previous_visible_values(self) -> torch.Tensor:
    """Previous visible values used in Persistent Contrastive Divergence (PCD)."""
    return self._previous_visible_values

@property
def weight_momenta(self) -> torch.Tensor:
    """Weight momenta of the RBM."""
    return self._weight_momenta

@property
def visible_bias_momenta(self) -> torch.Tensor:
    """Visible bias momenta of the RBM."""
    return self._visible_bias_momenta

@property
def hidden_bias_momenta(self) -> torch.Tensor:
    """Hidden bias momenta of the RBM."""
    return self._hidden_bias_momenta
Contributor


Are all of these useful attributes to access for a user? Consider removing some of these properties if they're only used within the class and not useful for a general user.

Collaborator Author


Not really. I added them because I thought this would be consistent with the GRBM. Sure, I'll keep the first two then.

Comment on lines +783 to +788
def generate_sample(
    self,
    batch_size: int,
    gibbs_steps: int,
    start_visible: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
Contributor


Should this be named just sample instead, to conform with the GRBM class?

Comment on lines +829 to +832
"""
Perform one step of Contrastive Divergence (CD-k) with momentum and weight decay.
Uses Persistent Contrastive Divergence (PCD) by maintaining the last visible states
for Gibbs sampling across batches.
Contributor


Suggested change
"""
Perform one step of Contrastive Divergence (CD-k) with momentum and weight decay.
Uses Persistent Contrastive Divergence (PCD) by maintaining the last visible states
for Gibbs sampling across batches.
"""Perform one step of Contrastive Divergence (CD-k) with momentum and weight decay.
Uses Persistent Contrastive Divergence (PCD) by maintaining the last visible states
for Gibbs sampling across batches.

Comment on lines +71 to +75
# Datasets
data/*

# Generated images
samples/*
Contributor


I'm not sure about adding such generic folders to the gitignore. Are these only created when running the examples? In that case I'd either leave it up to the developer not to commit these or put them in, e.g., examples/_data/ and examples/_samples/.

Collaborator Author


Yes, these directories are only created when running the examples. I wasn't planning to add them to the .gitignore either, but I included this to get your input during the review. I'll remove them then. Thanks for the feedback!

Comment on lines +783 to +788
def generate_sample(
    self,
    batch_size: int,
    gibbs_steps: int,
    start_visible: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
Contributor


If start_visible != None then batch_size isn't required, right? You could make that optional as well, unless there's a reasonable default value to use (e.g., batch_size=1).

Similarly, would it make sense to have gibbs_steps default to 1? I noticed that a test was using

hidden = RBM._sample_hidden()

which could in that case be written as

_, hidden = RBM.generate_sample()

Collaborator Author


Regarding batch_size, you're right. I'll make it optional.
Regarding gibbs_steps, I'd prefer not to set a default value. I want users to make an explicit choice rather than unknowingly relying on a default of 1 (as we often need more steps in our experiments). That test example just shows that using 1 step with generate_sample is equivalent to a single _sample_hidden call.
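For reference, one possible shape for the agreed change, keeping gibbs_steps required, is sketched below (illustrative only, not the final signature):

def generate_sample(
    self,
    gibbs_steps: int,
    batch_size: int | None = None,
    start_visible: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    # batch_size can be inferred when a starting state is provided.
    if start_visible is not None:
        batch_size = start_visible.shape[0]
    elif batch_size is None:
        raise ValueError("Provide either batch_size or start_visible.")
    ...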

Comment on lines +774 to +775
hidden (torch.Tensor): Tensor of shape (batch_size, n_hidden)
representing the states of hidden units.
Contributor


Suggested change
hidden (torch.Tensor): Tensor of shape (batch_size, n_hidden)
representing the states of hidden units.
hidden (torch.Tensor): Tensor of shape (batch_size, n_hidden)
    representing the states of hidden units.

Comment on lines +759 to +765
Args:
    visible (torch.Tensor): Tensor of shape (batch_size, n_visible)
    representing the states of visible units.
Returns:
    torch.Tensor: Binary tensor of shape (batch_size, n_hidden) representing
        sampled hidden units.
Contributor


Args should have indented line breaks; returns should not.

Suggested change
Args:
    visible (torch.Tensor): Tensor of shape (batch_size, n_visible)
    representing the states of visible units.
Returns:
    torch.Tensor: Binary tensor of shape (batch_size, n_hidden) representing
        sampled hidden units.
Args:
    visible (torch.Tensor): Tensor of shape (batch_size, n_visible)
        representing the states of visible units.
Returns:
    torch.Tensor: Binary tensor of shape (batch_size, n_hidden) representing
    sampled hidden units.


    return error

def forward(self, visible: torch.Tensor) -> torch.Tensor:
Contributor


To keep it the same as in the GRBM.

Suggested change
def forward(self, visible: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:

Collaborator

@VolodyaCO left a comment


Looks pretty good overall. Thank you.
The only big change I would request is to use the {-1,1} convention instead of {0,1}.

Comment on lines +805 to +807
visible_values = torch.randn(
    batch_size, self.n_visible, device=self._weights.device
)
Collaborator


Why random normal and not random uniform?

Collaborator


Sorry, I meant: why random normal and not random Bernoulli?

Collaborator Author


Hinton's RBM paper and tutorial use Gaussian initialization, and it seems to improve stability. I actually experimented with Bernoulli and uniform initialization and they didn't work.
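For context, the two seeding options being compared would look like this (illustrative; shapes and attribute names follow the snippet above):

# Gaussian seed for the Gibbs chain (what the PR currently does):
visible_values = torch.randn(
    batch_size, self.n_visible, device=self._weights.device
)
# Bernoulli(0.5) seed, the alternative raised in this thread:
visible_values = torch.bernoulli(
    torch.full((batch_size, self.n_visible), 0.5, device=self._weights.device)
)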

hidden_bias_grads, dim=0
)

with torch.no_grad():
Collaborator


Shouldn't all calculations in this method be wrapped in a no grad context?

Collaborator Author


Well, the only part that needs to be inside torch.no_grad is the parameter update; the rest could also be wrapped in it.
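In other words, only something like the update below strictly needs the no-grad context; weight_update and the other update names are placeholders, not identifiers from the PR:

with torch.no_grad():
    # In-place updates of nn.Parameters must not be tracked by autograd.
    self._weights += learning_rate * weight_update
    self._visible_biases += learning_rate * visible_bias_update
    self._hidden_biases += learning_rate * hidden_bias_update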

Comment on lines +819 to +828
def _contrastive_divergence(
    self,
    batch: torch.Tensor,
    epoch: int,
    n_gibbs_steps: int,
    learning_rate: float,
    momentum_coefficient: float,
    weight_decay: float,
    n_epochs: int,
) -> torch.Tensor:
Collaborator


This should be a public method. Also, from the docstrings it was difficult for me to infer how to use this while training, because of epoch and n_epochs (it isn't clear from the docstring that they are used to compute a decayed learning rate). I think it would be a good addition to have an example in the docstring, something like

for epoch in range(n_epochs):
    for batch in dataloader:
        rbm.contrastive_divergence(batch, epoch, n_gibbs_steps, learning_rate, momentum_coefficient, weight_decay, n_epochs)

Collaborator Author

@anahitamansouri commented on Dec 2, 2025


You're right, I will make it public. And I guess it's good to add this to the docstring, or what about referring to examples/rbm_image_generation.py, where it's used in a real example?

class RestrictedBoltzmannMachine(torch.nn.Module):
    """A Restricted Boltzmann Machine (RBM) model.
    This class defines the parameterization and inference of a binary RBM.
Collaborator


We should use the {-1,1} convention instead of the {0,1} convention. I understand that {-1,1} is way more common in D-Wave's code. Am I wrong? @thisac

Contributor


I don't have a strong opinion here. Happy to hear opinions on this.

Collaborator Author


Since both Vlad and Kevin suggested this, I'll look into it.
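For reference, moving to the {-1,1} convention mostly changes the conditional sampling rule; below is a minimal sketch of ±1 hidden sampling (illustrative only, not the final implementation):

import torch

def sample_hidden_pm1(visible, weights, hidden_biases):
    """Sample hidden spins in {-1, +1} given visible spins in {-1, +1}.

    With local field a_j = (visible @ weights + hidden_biases)_j, the
    conditional is p(h_j = +1 | v) = sigmoid(2 * a_j).
    """
    prob_plus = torch.sigmoid(2.0 * (visible @ weights + hidden_biases))
    return torch.where(
        torch.rand_like(prob_plus) < prob_plus,
        torch.ones_like(prob_plus),
        -torch.ones_like(prob_plus),
    )

# Existing {0,1} states map onto spins via s = 2 * v - 1.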

Collaborator

@kevinchern left a comment


This is looking good!! It's clearly documented and has sound logic. I have two main requests. The first is to decouple the three objects: model (RBM), sampler (PCD), and optimizer (e.g., the momentum parameters). We shouldn't have to implement the optimizer and can rely on any built-in PyTorch optimizer. The second request is to implement Vlad's suggestion: convert the model to +/-1 values (instead of 0, 1). The underlying principle is "the RBM should be a drop-in replacement for the GRBM".

@anahitamansouri
Collaborator Author

Thanks for your review. These suggestions will drastically change the code. I will look into both of them.
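To make the decoupling concrete, a rough sketch is included below. All names and the toy data are illustrative, not the plugin's API: the model only holds parameters, PCD produces the negative-phase statistics, and a stock torch.optim.SGD supplies momentum and weight decay by consuming gradient estimates written into .grad.

import torch

n_visible, n_hidden, batch_size = 784, 128, 64
weights = torch.nn.Parameter(0.01 * torch.randn(n_visible, n_hidden))
visible_biases = torch.nn.Parameter(torch.zeros(n_visible))
hidden_biases = torch.nn.Parameter(torch.zeros(n_hidden))
optimizer = torch.optim.SGD(
    [weights, visible_biases, hidden_biases],
    lr=1e-3, momentum=0.5, weight_decay=1e-4,
)

# Persistent chain for PCD, seeded once and reused across batches.
persistent_v = torch.bernoulli(torch.full((batch_size, n_visible), 0.5))

for _ in range(10):  # stand-in for iterating over a real dataloader
    batch = torch.bernoulli(torch.full((batch_size, n_visible), 0.5))  # toy data
    with torch.no_grad():
        pos_h = torch.sigmoid(batch @ weights + hidden_biases)
        # One Gibbs step on the persistent chain (negative phase).
        h = torch.bernoulli(torch.sigmoid(persistent_v @ weights + hidden_biases))
        persistent_v = torch.bernoulli(torch.sigmoid(h @ weights.T + visible_biases))
        neg_h = torch.sigmoid(persistent_v @ weights + hidden_biases)
    optimizer.zero_grad()
    # Write the negated CD gradient estimates into .grad; SGD then applies
    # momentum and weight decay (negated because optimizers minimize).
    weights.grad = -(batch.T @ pos_h - persistent_v.T @ neg_h) / batch_size
    visible_biases.grad = -(batch - persistent_v).mean(dim=0)
    hidden_biases.grad = -(pos_h - neg_h).mean(dim=0)
    optimizer.step()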
