
Conversation

@kevinchern (Collaborator) commented Nov 12, 2025

TODOs:

  • document
  • test
  • release note
  • wrap with store_config

@kevinchern kevinchern marked this pull request as draft November 17, 2025 21:25
@VolodyaCO VolodyaCO marked this pull request as ready for review November 28, 2025 16:28
@kevinchern (Collaborator, Author) commented Nov 29, 2025

@VolodyaCO what's the motivation for the following changes?

  1. Removal of MMD as a module, and
  2. The addition of gradient-tracking in the get_bandwidth function.

@kevinchern (Collaborator, Author) commented Dec 1, 2025

> @VolodyaCO what's the motivation for the following changes?
>
> 1. Removal of MMD as a module, and
> 2. The addition of gradient-tracking in the `get_bandwidth` function.

Addressed in meeting. For posterity:

  • The module was removed for consistency with `pseudo_kl_divergence`, which is exposed as a function.
  • `get_bandwidth` was accidentally removed (but the L2 distance matrix is detached in the kernel function).

I will review ASAP

@kevinchern (Collaborator, Author) left a review:

Thanks Vlad!! Nicely implemented and documented.
Let's add MaximumMeanDiscrepancy as a module and this should be good to go.
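For reference, a minimal sketch of what such a module wrapper could look like. The functional form is assumed here to have the signature `mmd_loss(x, y, kernel)`; the actual names and signatures in this PR may differ.

```python
from typing import Callable

import torch
from torch import nn


class MaximumMeanDiscrepancy(nn.Module):
    """Thin nn.Module wrapper around the functional MMD loss (sketch only)."""

    def __init__(
        self,
        kernel: nn.Module,
        loss_fn: Callable[[torch.Tensor, torch.Tensor, nn.Module], torch.Tensor],
    ):
        super().__init__()
        self.kernel = kernel    # e.g. an RBF kernel instance
        self.loss_fn = loss_fn  # e.g. the package's mmd_loss

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Delegate to the functional form so the module and the function stay in sync.
        return self.loss_fn(x, y, self.kernel)
```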

return self._kernel(xy)


class RBFKernel(Kernel):
@kevinchern (Collaborator, Author):
Rename to RadialBasisFunction

self.register_buffer("bandwidth_multipliers", bandwidth_multipliers)
self.bandwidth = bandwidth

def get_bandwidth(
@kevinchern (Collaborator, Author):
Make this private and use `@torch.no_grad`.
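i.e. something along these lines (a sketch only: the fallback heuristic shown is a placeholder, not necessarily the one used in this PR):

```python
import torch
from torch import nn


class RBFKernel(nn.Module):
    def __init__(self, bandwidth: float | None = None):
        super().__init__()
        self.bandwidth = bandwidth

    @torch.no_grad()  # bandwidth selection should not enter the autograd graph
    def _get_bandwidth(self, l2_distance_matrix: torch.Tensor) -> torch.Tensor:
        if self.bandwidth is not None:
            return torch.as_tensor(self.bandwidth)
        # Placeholder heuristic: mean off-diagonal squared distance.
        n = l2_distance_matrix.shape[0]
        return l2_distance_matrix.sum() / (n * (n - 1))
```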

@kevinchern (Collaborator, Author):
We should cite the source for this bandwidth-selection heuristic(?). I thought it was the InfoVAE or the Generative Moment Matching Networks paper, but I can't find it in either now.

@thisac (Contributor) left a review:

Thanks @VolodyaCO and @kevinchern! Tests fail because the Python 3.9 jobs are still being run (3.9 support is being removed in #49, which should fix it).

The unit tests should be expanded: there should at least be a test class for the kernels (RBF) and more unit tests for the mmd function. Otherwise this looks good; just a few, mostly minor, comments.
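For instance, a test class along these lines (a sketch: it assumes `RBFKernel` is importable from the package, that the default constructor takes no required arguments, and that `forward(x, y)` returns an (n_x + n_y, n_x + n_y) matrix as documented):

```python
import pytest
import torch

# from <package>.mmd import RBFKernel  # adjust to the actual import path


class TestRBFKernel:
    @pytest.fixture
    def data(self):
        torch.manual_seed(0)
        return torch.randn(4, 3), torch.randn(6, 3)

    def test_output_shape(self, data):
        x, y = data
        kernel_matrix = RBFKernel()(x, y)
        assert kernel_matrix.shape == (10, 10)

    def test_symmetric_and_nonnegative(self, data):
        x, y = data
        kernel_matrix = RBFKernel()(x, y)
        assert torch.allclose(kernel_matrix, kernel_matrix.T, atol=1e-6)
        assert (kernel_matrix >= 0).all()  # RBF kernel values are non-negative
```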

Comment on lines +27 to +28
"""
Base class for kernels.
@thisac (Contributor):
Minor, but to follow conventions in this and other packages.

Suggested change:
- """
- Base class for kernels.
+ """Base class for kernels.

Same for other docstrings below.

Returns:
torch.Tensor: A (n, n) tensor.
"""
pass
@thisac (Contributor):
No need for pass if there is a docstring.

Suggested change:
- pass

Comment on lines +59 to +60
y (torch.Tensor): A (n_y, f1, f2, ...) tensor.
Returns:
@thisac (Contributor):
Suggested change:
- y (torch.Tensor): A (n_y, f1, f2, ...) tensor.
- Returns:
+ y (torch.Tensor): A (n_y, f1, f2, ...) tensor.
+ Returns:

Returns:
torch.Tensor: A (n_x + n_y, n_x + n_y) tensor.
"""
assert x.shape[1:] == y.shape[1:], (
@thisac (Contributor):
Should be an if-check with an exception raised instead of an assert.
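For example (a sketch: the helper name and the choice of ValueError are suggestions, not what the PR currently does):

```python
import torch


def _check_same_feature_shape(x: torch.Tensor, y: torch.Tensor) -> None:
    # Raising keeps the check active under `python -O` (asserts are stripped)
    # and gives callers a clear error message.
    if x.shape[1:] != y.shape[1:]:
        raise ValueError(
            f"x and y must have the same feature shape, "
            f"got {tuple(x.shape[1:])} and {tuple(y.shape[1:])}"
        )
```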

torch.Tensor | float: The base bandwidth parameter.
"""
if self.bandwidth is None:
assert l2_distance_matrix is not None
@thisac (Contributor):
Similarly here, raise an exception instead of assert.

__all__ = ["Kernel", "RBFKernel", "mmd_loss"]


class Kernel(nn.Module):
@thisac (Contributor):
Should Kernels (and RBF) be in a kernels.py instead of mmd.py?

soft = logits
result = hard - soft.detach() + soft
# Now we need to repeat the result n_samples times along a new dimension
return repeat(result, "b ... -> b n ...", n=n_samples)
@thisac (Contributor):
Do we absolutely need repeat here? Seems a bit cumbersome to add einops as a test dependency just for this test. 🤔
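For reference, a plain-torch equivalent of `repeat(result, "b ... -> b n ...", n=n_samples)` that would avoid the einops test dependency (sketch):

```python
import torch


def repeat_along_new_dim(result: torch.Tensor, n_samples: int) -> torch.Tensor:
    # Insert a new axis after the batch dimension and repeat it n_samples times.
    # expand() returns a broadcast view; add .contiguous() if a real copy is needed.
    return result.unsqueeze(1).expand(result.shape[0], n_samples, *result.shape[1:])
```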
