
Conversation


@stes stes commented Aug 2, 2025

This PR adds a new argument, num_negatives, to both the Loader and CEBRA classes (torch and sklearn APIs). It makes it possible to stabilize training by providing additional negative examples to the InfoNCE loss, independent of the batch size. We leveraged this logic for the models trained in DCL.

Behavior

If num_negatives = None (the default), the previous behavior is retained for backwards compatibility and Loader.batch_size negative examples are drawn. If a different value is set, the number of negative examples in all loaders is set to num_negatives instead of Loader.batch_size.

The goodness-of-fit computation was also adapted to use num_negatives.
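Assuming the conversion follows the usual InfoNCE-to-bits mapping with a chance-level baseline of log(N) over the negatives (an assumption about the implementation; the function name here is hypothetical), the adapted computation could look like:

```python
import math

def infonce_to_gof_bits(infonce: float, num_negatives: int) -> float:
    """Convert an InfoNCE loss (in nats) to a goodness-of-fit in bits.

    A chance-level model has loss log(num_negatives); the score measures
    how far below chance the achieved loss is, expressed in bits.
    Sketch only: the baseline uses the number of negatives rather than
    the batch size, as described in the PR.
    """
    return (math.log(num_negatives) - infonce) / math.log(2)
```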

API modification

While implementing this functionality, I noticed an inconsistency between the single_session and multi_session solvers. In the single-session solver, we passed self.batch_size through the get_indices function, while in the multi-session solver we use self.batch_size directly. The second behavior makes more sense in the context of the general class design. I deprecated passing the num_samples parameter and adapted the samplers accordingly.

@cla-bot cla-bot bot added the CLA signed label Aug 2, 2025
@stes stes self-assigned this Aug 2, 2025
@stes stes added the enhancement New feature or request label Aug 2, 2025
@stes stes marked this pull request as ready for review August 2, 2025 15:17
@stes (Member Author) commented Aug 2, 2025

@cla-bot check


cla-bot bot commented Aug 2, 2025

Thanks for tagging me. I looked for a signed form under your signature again, and updated the status on this PR. If the check was successful, no further action is needed. If the check was unsuccessful, please see the instructions in my first comment.

@MMathisLab MMathisLab requested a review from Copilot August 11, 2025 06:23
Copilot's earlier review comment was marked as outdated.


@Copilot Copilot AI left a comment


Pull Request Overview

This PR introduces the ability to decouple batch size from the number of negative samples in contrastive learning by adding a new num_negatives parameter to both the Loader and CEBRA classes. This allows for more stable training behavior by providing additional negative examples to the InfoNCE loss independent of the batch size.

  • Adds num_negatives parameter to Loader base class and CEBRA class APIs
  • Updates all loader implementations to use num_negatives instead of duplicating batch_size
  • Modifies goodness of fit computation to use num_negatives instead of batch_size

Reviewed Changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.

Summary per file:

  • cebra/data/base.py — Adds num_negatives field to the base Loader class with validation and deprecates the num_samples parameter
  • cebra/data/single_session.py — Updates single-session loaders to use num_negatives and removes the num_samples parameter
  • cebra/data/multi_session.py — Updates multi-session loaders to use num_negatives and adds validation for the unified loader
  • cebra/data/multiobjective.py — Updates multiobjective loaders to use num_negatives instead of batch_size
  • cebra/integrations/sklearn/cebra.py — Adds num_negatives parameter to the CEBRA class and passes it to loaders
  • cebra/integrations/sklearn/metrics.py — Updates goodness-of-fit computation to use num_negatives_ instead of batch_size
  • tests/test_loader.py — Comprehensive test updates validating num_negatives across all loader types
  • tests/test_sklearn.py — Adds a basic test for the num_negatives parameter in the CEBRA class
  • tests/test_sklearn_metrics.py — Updates goodness-of-fit tests to validate num_negatives behavior

@@ -154,13 +155,15 @@ def get_indices(self, num_samples: int) -> BatchIndex:
Args:
num_samples: The number of samples (batch size) of the returned
:py:class:`cebra.data.datatypes.BatchIndex`.
num_negatives: The number of negative samples. If None, defaults to num_samples.

Returns:
Indices for reference, positive and negatives samples.

Copilot AI Aug 11, 2025


The docstring mentions 'num_samples' but this parameter no longer exists in the function signature. The docstring should be updated to reflect that num_negatives is now an instance attribute.

Suggested change
Indices for reference, positive and negatives samples.
The number of reference samples (batch size) and the number of negative samples
are determined by the instance attributes ``batch_size`` and ``num_negatives``, respectively.
Returns:
Indices for reference, positive and negative samples.


@@ -319,6 +323,7 @@ def get_indices(self, num_samples: int) -> BatchIndex:
Args:
num_samples: The number of samples (batch size) of the returned
:py:class:`cebra.data.datatypes.BatchIndex`.
num_negatives: The number of negative samples. If None, defaults to num_samples.

Returns:
Indices for reference, positive and negatives samples.

Copilot AI Aug 11, 2025


Similar to the previous comment, the docstring references 'num_samples' which no longer exists. This should be updated to reflect the current implementation.

Suggested change
Indices for reference, positive and negatives samples.
The number of reference samples (batch size) is determined by the
instance's ``batch_size`` attribute. The number of negative samples
is determined by the instance's ``num_negatives`` attribute.
Returns:
Indices for reference, positive and negative samples as a
:py:class:`cebra.data.datatypes.BatchIndex`.


@@ -436,6 +444,7 @@ def get_indices(self, num_samples: int) -> BatchIndex:
Args:
num_samples: The number of samples (batch size) of the returned
:py:class:`cebra.data.datatypes.BatchIndex`.
num_negatives: The number of negative samples. If None, defaults to num_samples.

Copilot AI Aug 11, 2025


Another instance where the docstring incorrectly references 'num_samples'. This should be updated to reflect that num_negatives is an instance attribute.



Note:
From version 0.7.0 onwards, specifying the ``num_samples``
directly is deprecated and will be removed in version 0.8.0.

Copilot AI Aug 11, 2025


The method signature still accepts 'num_samples' parameter even though it's deprecated. Consider removing this parameter entirely or making it keyword-only to prevent accidental usage.


@@ -228,19 +228,15 @@ def infonce_to_goodness_of_fit(
)
if not hasattr(model, "state_dict_"):
raise RuntimeError("Fit the CEBRA model first.")
if model.batch_size is None:
if model.num_negatives_ is None:

Copilot AI Aug 11, 2025


This check is duplicated: the same condition and error message appear at lines 231-236. The second check (lines 239-243) should be removed as redundant.


@MMathisLab MMathisLab self-requested a review August 27, 2025 12:48