-
Notifications
You must be signed in to change notification settings - Fork 9
Add maximum mean discrepancy and radial basis #35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@VolodyaCO what's the motivation for the following changes?
|
Addressed in meeting. For posterity:
I will review ASAP |
kevinchern
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Vlad!! Nicely implemented and documented.
Let's add the MaximumMeanDiscrepancy a a module and this should be good to go.
| return self._kernel(xy) | ||
|
|
||
|
|
||
| class RBFKernel(Kernel): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rename to RadialBasisFunction
| self.register_buffer("bandwidth_multipliers", bandwidth_multipliers) | ||
| self.bandwidth = bandwidth | ||
|
|
||
| def get_bandwidth( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
make private and use @torch.nograd
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should cite the source that uses this bandwidth-selection heuristic(?)... I thought it was in the InfoVAE or Generative moment matching networks paper but can't find in either now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
here's a candidate https://arxiv.org/abs/1707.07269
thisac
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @VolodyaCO and @kevinchern! Test fail due to Python 3.9 tests being run (removing 3.9 support in #49, which should fix it).
Unit tests should be expanded. There should at least be test classes for kernels (RBF) and more unit tests for the mmd function. Otherwise, looks good. Just a few, mostly minor, comments.
| """ | ||
| Base class for kernels. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor, but to follow conventions in this and other packages.
| """ | |
| Base class for kernels. | |
| """Base class for kernels. |
Same for other docstrings below.
| Returns: | ||
| torch.Tensor: A (n, n) tensor. | ||
| """ | ||
| pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need for pass if there is a docstring.
| pass |
| y (torch.Tensor): A (n_y, f1, f2, ...) tensor. | ||
| Returns: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| y (torch.Tensor): A (n_y, f1, f2, ...) tensor. | |
| Returns: | |
| y (torch.Tensor): A (n_y, f1, f2, ...) tensor. | |
| Returns: |
| Returns: | ||
| torch.Tensor: A (n_x + n_y, n_x + n_y) tensor. | ||
| """ | ||
| assert x.shape[1:] == y.shape[1:], ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should be an if-check with an exception raised instead of an assert.
| torch.Tensor | float: The base bandwidth parameter. | ||
| """ | ||
| if self.bandwidth is None: | ||
| assert l2_distance_matrix is not None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similarly here, raise an exception instead of assert.
| __all__ = ["Kernel", "RBFKernel", "mmd_loss"] | ||
|
|
||
|
|
||
| class Kernel(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should Kernels (and RBF) be in a kernels.py instead of mmd.py?
| soft = logits | ||
| result = hard - soft.detach() + soft | ||
| # Now we need to repeat the result n_samples times along a new dimension | ||
| return repeat(result, "b ... -> b n ...", n=n_samples) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we absolutely need repeat here? Seems a bit cumbersome to add einops as a test dependency just for this test. 🤔
TODOs:
store_config