Decouple batch size and number of negatives #263
base: main
Conversation
@cla-bot check

Thanks for tagging me. I looked for a signed form under your signature again, and updated the status on this PR. If the check was successful, no further action is needed. If the check was unsuccessful, please see the instructions in my first comment.
Pull Request Overview
This PR introduces the ability to decouple the batch size from the number of negative samples in contrastive learning by adding a new `num_negatives` parameter to both the `Loader` and `CEBRA` classes. This allows for more stable training behavior by providing additional negative examples to the InfoNCE loss, independent of the batch size (a usage sketch follows the summary list below).
- Adds `num_negatives` parameter to the `Loader` base class and `CEBRA` class APIs
- Updates all loader implementations to use `num_negatives` instead of duplicating `batch_size`
- Modifies the goodness of fit computation to use `num_negatives` instead of `batch_size`
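For illustration, a minimal usage sketch through the sklearn API (assuming a build that includes this PR; the data and all parameter values here are made up):

```python
import numpy as np

import cebra

# Synthetic data, for illustration only: 1000 timesteps, 30 channels.
neural_data = np.random.normal(size=(1000, 30)).astype(np.float32)

# With num_negatives=None (the default), the loader draws batch_size
# negatives, matching the previous behavior. Setting it explicitly
# decouples the number of negatives in the InfoNCE loss from batch_size.
model = cebra.CEBRA(
    batch_size=512,
    num_negatives=2048,  # the parameter added by this PR
    max_iterations=100,
)
model.fit(neural_data)
embedding = model.transform(neural_data)
```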
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `cebra/data/base.py` | Adds `num_negatives` field to the base `Loader` class with validation and deprecates the `num_samples` parameter |
| `cebra/data/single_session.py` | Updates single-session loaders to use `num_negatives` and removes the `num_samples` parameter |
| `cebra/data/multi_session.py` | Updates multi-session loaders to use `num_negatives` and adds validation for the unified loader |
| `cebra/data/multiobjective.py` | Updates multiobjective loaders to use `num_negatives` instead of `batch_size` |
| `cebra/integrations/sklearn/cebra.py` | Adds `num_negatives` parameter to the `CEBRA` class and passes it to the loaders |
| `cebra/integrations/sklearn/metrics.py` | Updates goodness of fit computation to use `num_negatives_` instead of `batch_size` |
| `tests/test_loader.py` | Comprehensive test updates to validate `num_negatives` functionality across all loader types |
| `tests/test_sklearn.py` | Adds a basic test for the `num_negatives` parameter in the `CEBRA` class |
| `tests/test_sklearn_metrics.py` | Updates goodness of fit tests to validate `num_negatives` behavior |
```diff
@@ -154,13 +155,15 @@ def get_indices(self, num_samples: int) -> BatchIndex:
         Args:
             num_samples: The number of samples (batch size) of the returned
                 :py:class:`cebra.data.datatypes.BatchIndex`.
+            num_negatives: The number of negative samples. If None, defaults to num_samples.
+
         Returns:
             Indices for reference, positive and negatives samples.
```
The docstring mentions 'num_samples' but this parameter no longer exists in the function signature. The docstring should be updated to reflect that num_negatives is now an instance attribute.
Suggested change:

```diff
-            Indices for reference, positive and negatives samples.
+            The number of reference samples (batch size) and the number of negative samples
+            are determined by the instance attributes ``batch_size`` and ``num_negatives``, respectively.
+
+        Returns:
+            Indices for reference, positive and negative samples.
```
```diff
@@ -319,6 +323,7 @@ def get_indices(self, num_samples: int) -> BatchIndex:
         Args:
             num_samples: The number of samples (batch size) of the returned
                 :py:class:`cebra.data.datatypes.BatchIndex`.
+            num_negatives: The number of negative samples. If None, defaults to num_samples.
 
         Returns:
             Indices for reference, positive and negatives samples.
```
Similar to the previous comment, the docstring references 'num_samples' which no longer exists. This should be updated to reflect the current implementation.
Suggested change:

```diff
-            Indices for reference, positive and negatives samples.
+            The number of reference samples (batch size) is determined by the
+            instance's ``batch_size`` attribute. The number of negative samples
+            is determined by the instance's ``num_negatives`` attribute.
+
+        Returns:
+            Indices for reference, positive and negatives samples as a
+            :py:class:`cebra.data.datatypes.BatchIndex`.
```
```diff
@@ -436,6 +444,7 @@ def get_indices(self, num_samples: int) -> BatchIndex:
         Args:
             num_samples: The number of samples (batch size) of the returned
                 :py:class:`cebra.data.datatypes.BatchIndex`.
+            num_negatives: The number of negative samples. If None, defaults to num_samples.
```
Another instance where the docstring incorrectly references 'num_samples'. This should be updated to reflect that num_negatives is an instance attribute.
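Unlike the two comments above, no concrete suggestion is attached here; a docstring rewrite in the same spirit could look like the following (an assumption mirroring the earlier suggestions, not text from the PR):

```python
from cebra.data.datatypes import BatchIndex


def get_indices(self) -> BatchIndex:
    """Sample a set of indices for one batch.

    The number of reference samples (batch size) and the number of
    negative samples are determined by the instance attributes
    ``batch_size`` and ``num_negatives``, respectively.

    Returns:
        Indices for reference, positive and negative samples as a
        :py:class:`cebra.data.datatypes.BatchIndex`.
    """
```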
```diff
+        Note:
+            From version 0.7.0 onwards, specifying the ``num_samples``
+            directly is deprecated and will be removed in version 0.8.0.
```
The method signature still accepts 'num_samples' parameter even though it's deprecated. Consider removing this parameter entirely or making it keyword-only to prevent accidental usage.
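A minimal sketch of the keyword-only variant the comment proposes (the signature and warning text are assumptions, not code from this PR):

```python
import warnings


class Loader:
    """Sketch: only the deprecation handling is shown."""

    def get_indices(self, *, num_samples=None):
        # Keyword-only: a positional call like loader.get_indices(512)
        # now raises a TypeError instead of silently using the
        # deprecated parameter.
        if num_samples is not None:
            warnings.warn(
                "`num_samples` is deprecated since 0.7.0 and will be "
                "removed in 0.8.0; configure `batch_size` and "
                "`num_negatives` on the loader instead.",
                DeprecationWarning,
                stacklevel=2,
            )
        ...
```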
```diff
@@ -228,19 +228,15 @@ def infonce_to_goodness_of_fit(
     )
     if not hasattr(model, "state_dict_"):
         raise RuntimeError("Fit the CEBRA model first.")
-    if model.batch_size is None:
+    if model.num_negatives_ is None:
```
This check is duplicated - the same condition and error message appear at lines 231-236. The second check (lines 239-243) should be removed as it's redundant.
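For reference, a sketch of the validation with the duplicate removed (surrounding logic elided; the error message is an assumption):

```python
def infonce_to_goodness_of_fit(infonce, model):
    """Sketch of the validation with the duplicate check removed."""
    if not hasattr(model, "state_dict_"):
        raise RuntimeError("Fit the CEBRA model first.")
    # A single occurrence of this check suffices; the redundant copy
    # flagged by the review is dropped.
    if model.num_negatives_ is None:
        raise ValueError(
            "Computing the goodness of fit requires `num_negatives_` "
            "to be set on the fitted model."
        )
    ...  # remainder of the computation elided
```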
This PR adds a new argument, `num_negatives`, to both the `Loader` and `CEBRA` classes (torch and sklearn API). This allows stabilizing training behavior by providing additional negative examples to the InfoNCE loss, independent of the batch size. We leveraged this logic for the models trained in DCL.

Behavior

If `num_negatives = None` (the default), the previous behavior is obtained for backwards compatibility and `Loader.batch_size` negative examples are drawn. If a different value is set, the number of negative examples in all loaders will be set to `num_negatives` instead of `Loader.batch_size`.

The goodness of fit computation was also adapted to use `num_negatives`.
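A minimal sketch of this fallback, assuming a simplified loader configuration (the actual validation lives in `cebra/data/base.py` per the file summary above):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LoaderConfig:
    """Sketch of the relevant loader fields, not the actual class."""

    batch_size: int
    num_negatives: Optional[int] = None

    def __post_init__(self):
        # Backwards compatibility: with num_negatives=None, draw
        # batch_size negatives, exactly as before this change.
        if self.num_negatives is None:
            self.num_negatives = self.batch_size
        if self.num_negatives <= 0:
            raise ValueError("num_negatives must be a positive integer.")
```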
API modification

While implementing this functionality, I noticed an inconsistency between the `single_session` and `multi_session` solvers. In the single-session case, we passed `self.batch_size` through the `get_indices` function, while in `multi_session` we use `self.batch_size` directly. The second behavior makes more sense in the context of the general class design. I deprecated passing the `num_samples` parameter and adapted the samplers accordingly.
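A sketch of the resulting convention, where `get_indices` reads both counts from instance attributes instead of a `num_samples` argument (simplified; the real loaders return a `BatchIndex` with reference, positive, and negative indices):

```python
import torch


class LoaderSketch:
    """Simplified loader following the unified convention."""

    def __init__(self, num_timesteps, batch_size, num_negatives):
        self.num_timesteps = num_timesteps
        self.batch_size = batch_size
        self.num_negatives = num_negatives

    def get_indices(self):
        # Both counts come from instance attributes, as in the
        # multi-session loaders, instead of a num_samples argument.
        reference = torch.randint(0, self.num_timesteps, (self.batch_size,))
        negative = torch.randint(0, self.num_timesteps, (self.num_negatives,))
        # Positive sampling is conditional on the reference indices in
        # the real loaders and is elided here.
        return reference, negative
```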