30 changes: 25 additions & 5 deletions cebra/data/base.py
@@ -22,6 +22,7 @@
"""Base classes for datasets and loaders."""

import abc
from typing import Iterator

import literate_dataclasses as dataclasses
import torch
@@ -239,6 +240,12 @@ class Loader(abc.ABC, cebra.io.HasDevice):
batch_size: int = dataclasses.field(default=None,
doc="""The total batch size.""")

num_negatives: int = dataclasses.field(
default=None,
doc=("The number of negative samples to draw for each reference. "
"If not specified, the batch size is used."),
)

def __post_init__(self):
if self.num_steps is None or self.num_steps <= 0:
raise ValueError(
@@ -248,28 +255,41 @@ def __post_init__(self):
raise ValueError(
f"Batch size has to be None, or a non-negative value. Got {self.batch_size}."
)
if self.num_negatives is not None and self.num_negatives <= 0:
raise ValueError(
f"Number of negatives has to be None, or a non-negative value. Got {self.num_negatives}."
)

if self.num_negatives is None:
self.num_negatives = self.batch_size

def __len__(self):
"""The number of batches returned when calling as an iterator."""
return self.num_steps

def __iter__(self) -> Batch:
def __iter__(self) -> Iterator[Batch]:
for _ in range(len(self)):
index = self.get_indices(num_samples=self.batch_size)
index = self.get_indices()
yield self.dataset.load_batch(index)

@abc.abstractmethod
def get_indices(self, num_samples: int):
def get_indices(self, num_samples: int = None):

[Copilot AI commented, Aug 11, 2025]
The get_indices method signature change introduces a potential breaking change by making num_samples optional with a default of None. While the deprecation is documented, the method should handle the case where num_samples is passed but should not be used, ideally issuing a deprecation warning to guide users toward the new API.

"""Sample and return the specified number of indices.

The elements of the returned `BatchIndex` will be used to index the
`dataset` of this data loader.

Args:
num_samples: The size of each of the reference, positive and
negative samples.
num_samples: Deprecated. Use ``batch_size`` on the instance level
instead.

Returns:
batch indices for the reference, positive and negative sample.

Note:
From version 0.7.0 onwards, specifying the ``num_samples``
directly is deprecated and will be removed in version 0.8.0.
Please set ``batch_size`` and ``num_negatives`` on the instance
level instead.
"""
raise NotImplementedError()

[Copilot AI commented, Aug 11, 2025]
The method signature still accepts the num_samples parameter even though it is deprecated. Consider removing the parameter entirely or making it keyword-only to prevent accidental usage.
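
A minimal sketch (not part of this diff) of the deprecation warning suggested in the review above, keeping the get_indices signature from this PR; the exact message text is illustrative:

    import warnings

    # Sketch of the abstract method body with the suggested warning.
    def get_indices(self, num_samples: int = None):
        if num_samples is not None:
            warnings.warn(
                "Passing num_samples to get_indices() is deprecated since "
                "version 0.7.0 and will be removed in 0.8.0; set batch_size "
                "and num_negatives on the loader instead.",
                DeprecationWarning,
                stacklevel=2,  # point the warning at the caller
            )
        ...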
22 changes: 18 additions & 4 deletions cebra/data/multi_session.py
@@ -155,10 +155,14 @@ def __post_init__(self):
super().__post_init__()
self.sampler = cebra.distributions.MultisessionSampler(
self.dataset, self.time_offset)
if self.num_negatives is None:
self.num_negatives = self.batch_size

def get_indices(self, num_samples: int) -> List[BatchIndex]:
# NOTE(stes): In the longer run, we need to unify the API here; the num_samples
# argument is not used in the multi-session case, which differs from the
# single-session samplers.
def get_indices(self) -> List[BatchIndex]:
ref_idx = self.sampler.sample_prior(self.batch_size)
neg_idx = self.sampler.sample_prior(self.batch_size)
neg_idx = self.sampler.sample_prior(self.num_negatives)
pos_idx, idx, idx_rev = self.sampler.sample_conditional(ref_idx)

ref_idx = torch.from_numpy(ref_idx)
@@ -192,8 +196,11 @@ class DiscreteMultiSessionDataLoader(MultiSessionLoader):
# Overwrite sampler with the discrete implementation
# Generalize MultisessionSampler to avoid doing this?
def __post_init__(self):
# NOTE(stes): __post_init__ from superclass is intentionally not called.
self.sampler = cebra.distributions.DiscreteMultisessionSampler(
self.dataset)
if self.num_negatives is None:
self.num_negatives = self.batch_size

@property
def index(self):
@@ -229,7 +236,14 @@ def __post_init__(self):
self.sampler = cebra.distributions.UnifiedSampler(
self.dataset, self.time_offset)

def get_indices(self, num_samples: int) -> BatchIndex:
if self.batch_size is not None and self.batch_size < 2:
raise ValueError("UnifiedLoader does not support batch_size < 2.")

if self.num_negatives is not None and self.num_negatives < 2:
raise ValueError(
"UnifiedLoader does not support num_negatives < 2.")

def get_indices(self) -> BatchIndex:
"""Sample and return the specified number of indices.

The elements of the returned ``BatchIndex`` will be used to index the
@@ -251,7 +265,7 @@ def get_indices(self, num_samples: int) -> BatchIndex:
Batch indices for the reference, positive and negative samples.
"""
ref_idx = self.sampler.sample_prior(self.batch_size)
neg_idx = self.sampler.sample_prior(self.batch_size)
neg_idx = self.sampler.sample_prior(self.num_negatives)

pos_idx = self.sampler.sample_conditional(ref_idx)

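A hypothetical construction/iteration sketch for the new field (the dataset object and constructor defaults are assumed; names follow this diff):

    import cebra.data

    # num_negatives decouples the negative-sample count from batch_size;
    # when left as None, __post_init__ falls back to batch_size.
    loader = cebra.data.MultiSessionLoader(
        dataset=multi_session_dataset,  # assumed: a cebra.data.DatasetCollection
        num_steps=1000,
        batch_size=512,
        num_negatives=1024,
    )
    for batch in loader:  # len(loader) == num_steps batches
        ...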
22 changes: 13 additions & 9 deletions cebra/data/multiobjective.py
@@ -20,10 +20,13 @@
# limitations under the License.
#

from typing import Iterator

import literate_dataclasses as dataclasses

import cebra.data as cebra_data
import cebra.distributions
from cebra.data.datatypes import Batch
from cebra.data.datatypes import BatchIndex
from cebra.distributions.continuous import Prior

@@ -71,9 +74,9 @@ def __post_init__(self):
def add_config(self, config):
self.labels.append(config['label'])

def get_indices(self, num_samples: int):
def get_indices(self) -> BatchIndex:
if self.sampling_mode_supervised == "ref_shared":
reference_idx = self.prior.sample_prior(num_samples)
reference_idx = self.prior.sample_prior(self.batch_size)
else:
raise ValueError(
f"Sampling mode {self.sampling_mode_supervised} is not implemented."
@@ -87,9 +90,9 @@ def get_indices(self, num_samples: int):

return batch_index

def __iter__(self):
def __iter__(self) -> Iterator[Batch]:
for _ in range(len(self)):
index = self.get_indices(num_samples=self.batch_size)
index = self.get_indices()
yield self.dataset.load_batch_supervised(index, self.labels)


@@ -142,13 +145,14 @@ def add_config(self, config):

self.distributions.append(distribution)

def get_indices(self, num_samples: int):
def get_indices(self) -> BatchIndex:
"""Sample and return the specified number of indices."""

if self.sampling_mode_contrastive == "refneg_shared":
ref_and_neg = self.prior.sample_prior(num_samples * 2)
reference_idx = ref_and_neg[:num_samples]
negative_idx = ref_and_neg[num_samples:]
ref_and_neg = self.prior.sample_prior(self.batch_size +
self.num_negatives)
reference_idx = ref_and_neg[:self.batch_size]
negative_idx = ref_and_neg[self.batch_size:]

positives_idx = []
for distribution in self.distributions:
@@ -169,5 +173,5 @@ def get_indices(self, num_samples: int):

def __iter__(self):
for _ in range(len(self)):
index = self.get_indices(num_samples=self.batch_size)
index = self.get_indices()
yield self.dataset.load_batch_contrastive(index)
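
The refneg_shared branch above draws references and negatives from a single prior call and splits the result; a standalone sketch of that indexing with stand-in tensors:

    import torch

    batch_size, num_negatives = 4, 6
    # stand-in for self.prior.sample_prior(batch_size + num_negatives)
    ref_and_neg = torch.randperm(100)[:batch_size + num_negatives]
    reference_idx = ref_and_neg[:batch_size]  # first batch_size indices
    negative_idx = ref_and_neg[batch_size:]   # remaining num_negatives indices
    assert negative_idx.shape[0] == num_negatives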
57 changes: 35 additions & 22 deletions cebra/data/single_session.py
@@ -27,6 +27,7 @@

import abc
import warnings
from typing import Iterator

import literate_dataclasses as dataclasses
import torch
@@ -138,7 +139,7 @@ def _init_distribution(self):
f"Invalid choice of prior distribution. Got '{self.prior}', but "
f"only accept 'uniform' or 'empirical' as potential values.")

def get_indices(self, num_samples: int) -> BatchIndex:
def get_indices(self) -> BatchIndex:
"""Samples indices for reference, positive and negative examples.

The reference samples will be sampled from the empirical or uniform prior
@@ -154,13 +155,15 @@ def get_indices(self, num_samples: int) -> BatchIndex:
Args:
num_samples: The number of samples (batch size) of the returned
:py:class:`cebra.data.datatypes.BatchIndex`.
num_negatives: The number of negative samples. If None, defaults to num_samples.

Returns:
Indices for reference, positive and negatives samples.
"""

[Copilot AI commented, Aug 11, 2025]
The docstring mentions num_samples, but this parameter no longer exists in the function signature; it should be updated to reflect that num_negatives is now an instance attribute. Suggested change:

-            Indices for reference, positive and negatives samples.
+            The number of reference samples (batch size) and the number of
+            negative samples are determined by the instance attributes
+            ``batch_size`` and ``num_negatives``, respectively.
+
+        Returns:
+            Indices for reference, positive and negative samples.

reference_idx = self.distribution.sample_prior(num_samples * 2)
negative_idx = reference_idx[num_samples:]
reference_idx = reference_idx[:num_samples]
reference_idx = self.distribution.sample_prior(self.batch_size +
self.num_negatives)
negative_idx = reference_idx[self.batch_size:]
reference_idx = reference_idx[:self.batch_size]
reference = self.index[reference_idx]
positive_idx = self.distribution.sample_conditional(reference)
return BatchIndex(reference=reference_idx,
@@ -246,7 +249,7 @@ def _init_distribution(self):
else:
raise ValueError(self.conditional)

def get_indices(self, num_samples: int) -> BatchIndex:
def get_indices(self) -> BatchIndex:
"""Samples indices for reference, positive and negative examples.

The reference and negative samples will be sampled uniformly from
Expand All @@ -262,9 +265,10 @@ def get_indices(self, num_samples: int) -> BatchIndex:
Returns:
Indices for reference, positive and negatives samples.
"""
reference_idx = self.distribution.sample_prior(num_samples * 2)
negative_idx = reference_idx[num_samples:]
reference_idx = reference_idx[:num_samples]
reference_idx = self.distribution.sample_prior(self.batch_size +
self.num_negatives)
negative_idx = reference_idx[self.batch_size:]
reference_idx = reference_idx[:self.batch_size]
positive_idx = self.distribution.sample_conditional(reference_idx)
return BatchIndex(reference=reference_idx,
positive=positive_idx,
@@ -305,7 +309,7 @@ def __post_init__(self):
continuous=self.cindex,
time_delta=self.time_offset)

def get_indices(self, num_samples: int) -> BatchIndex:
def get_indices(self) -> BatchIndex:
"""Samples indices for reference, positive and negative examples.

The reference and negative samples will be sampled uniformly from
@@ -319,6 +323,7 @@ def get_indices(self, num_samples: int) -> BatchIndex:
Args:
num_samples: The number of samples (batch size) of the returned
:py:class:`cebra.data.datatypes.BatchIndex`.
num_negatives: The number of negative samples. If None, defaults to num_samples.

Returns:
Indices for reference, positive and negatives samples.

[Copilot AI commented, Aug 11, 2025]
Similar to the previous comment, the docstring references num_samples, which no longer exists; it should be updated to reflect the current implementation. Suggested change:

-            Indices for reference, positive and negatives samples.
+            The number of reference samples (batch size) is determined by the
+            instance's ``batch_size`` attribute. The number of negative samples
+            is determined by the instance's ``num_negatives`` attribute.
+
+        Returns:
+            Indices for reference, positive and negatives samples as a
+            :py:class:`cebra.data.datatypes.BatchIndex`.

@@ -328,10 +333,13 @@
class.
- Sample the negatives with matching discrete variable
"""
reference_idx = self.distribution.sample_prior(num_samples)
reference_idx = self.distribution.sample_prior(self.batch_size +
self.num_negatives)
negative_idx = reference_idx[self.batch_size:]
reference_idx = reference_idx[:self.batch_size]
return BatchIndex(
reference=reference_idx,
negative=self.distribution.sample_prior(num_samples),
negative=negative_idx,
positive=self.distribution.sample_conditional(reference_idx),
)

@@ -421,11 +429,11 @@ def _init_time_distribution(self):
else:
raise ValueError

def get_indices(self, num_samples: int) -> BatchIndex:
def get_indices(self) -> BatchIndex:
"""Samples indices for reference, positive and negative examples.

The reference and negative samples will be sampled uniformly from
all available time steps, and a total of ``2*num_samples`` will be
all available time steps, and a total of ``num_samples + num_negatives`` will be
returned for both.

For the positive samples, ``num_samples`` are sampled according to the
@@ -436,6 +444,7 @@ def get_indices(self, num_samples: int) -> BatchIndex:
Args:
num_samples: The number of samples (batch size) of the returned
:py:class:`cebra.data.datatypes.BatchIndex`.
num_negatives: The number of negative samples. If None, defaults to num_samples.

[Copilot AI commented, Aug 11, 2025]
Another instance where the docstring incorrectly references num_samples; this should be updated to reflect that num_negatives is an instance attribute.


Returns:
Indices for reference, positive and negatives samples.
@@ -444,9 +453,10 @@ def get_indices(self, num_samples: int) -> BatchIndex:
Add the ``empirical`` vs. ``discrete`` sampling modes to this
class.
"""
reference_idx = self.time_distribution.sample_prior(num_samples * 2)
negative_idx = reference_idx[num_samples:]
reference_idx = reference_idx[:num_samples]
reference_idx = self.time_distribution.sample_prior(self.batch_size +
self.num_negatives)
negative_idx = reference_idx[self.batch_size:]
reference_idx = reference_idx[:self.batch_size]
behavior_positive_idx = self.behavior_distribution.sample_conditional(
reference_idx)
time_positive_idx = self.time_distribution.sample_conditional(
@@ -464,13 +474,18 @@ class FullDataLoader(ContinuousDataLoader):

def __post_init__(self):
super().__post_init__()
self.batch_size = None

if self.batch_size is not None:

[Copilot AI commented, Aug 11, 2025]
The batch_size condition check in FullDataLoader.__post_init__() is incorrect. Since FullDataLoader inherits from ContinuousDataLoader and the base Loader class sets batch_size = None by default when not specified, this check will always be true when batch_size is explicitly set to None in the constructor call. The check should be if self.batch_size is not None: to properly validate that batch_size was not set.

raise ValueError("Batch size cannot be set for FullDataLoader.")
if self.num_negatives is not None:
raise ValueError(
"Number of negatives cannot be set for FullDataLoader.")

@property
def offset(self):
return self.dataset.offset

def get_indices(self, num_samples=None) -> BatchIndex:
def get_indices(self) -> BatchIndex:
"""Samples indices for reference, positive and negative examples.

The reference indices are all available (valid, according to the
@@ -490,7 +505,6 @@ def get_indices(self, num_samples=None) -> BatchIndex:
Add the ``empirical`` vs. ``discrete`` sampling modes to this
class.
"""
assert num_samples is None

reference_idx = torch.arange(
self.offset.left,
@@ -504,7 +518,6 @@ def __iter__(self):
positive=positive_idx,
negative=negative_idx)

def __iter__(self):
def __iter__(self) -> Iterator[BatchIndex]:
for _ in range(len(self)):
index = self.get_indices(num_samples=self.batch_size)
yield index
yield self.get_indices()
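
Given the stricter validation above, a hypothetical FullDataLoader usage sketch (the dataset object and the cebra.data export path are assumptions):

    import cebra.data

    # batch_size and num_negatives must stay None: the loader always
    # trains on all valid time steps of the dataset.
    loader = cebra.data.FullDataLoader(dataset=dataset, num_steps=10)
    batch_index = next(iter(loader))  # __iter__ now yields BatchIndex directly

    # cebra.data.FullDataLoader(dataset=dataset, num_steps=10, batch_size=512)
    # would raise: ValueError("Batch size cannot be set for FullDataLoader.")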
12 changes: 12 additions & 0 deletions cebra/integrations/sklearn/cebra.py
@@ -501,6 +501,9 @@ class CEBRA(TransformerMixin, BaseEstimator):
A Tuple of masking types and their corresponding required masking values. The keys are the
names of the Mask instances and formatting should be ``((key, value), (key, value))``.
|Default:| ``None``.
num_negatives (int):
The number of negative samples to use for training. If ``None``, the number of negative samples
will be set to the batch size. |Default:| ``None``.

Example:

@@ -576,6 +579,7 @@ def __init__(
),
masking_kwargs: Tuple[Tuple[str, Union[float, List[float],
Tuple[float, ...]]], ...] = None,
num_negatives: int = None,
):
self.__dict__.update(locals())

@@ -592,6 +596,13 @@ def num_sessions(self) -> Optional[int]:
"""
return self.num_sessions_

@property
def num_negatives_(self) -> int:
"""The number of negative examples."""
if self.num_negatives is None:
return self.batch_size
return self.num_negatives

@property
def state_dict_(self) -> dict:
return self.solver_.state_dict()
@@ -728,6 +739,7 @@ def _prepare_loader(self, dataset: cebra.data.Dataset, max_iterations: int,
dataset=dataset,
batch_size=self.batch_size,
num_steps=max_iterations,
num_negatives=self.num_negatives,
),
extra_kwargs=dict(
time_offsets=self.time_offsets,
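An end-to-end sketch of the new estimator argument with synthetic data (hypothetical values; names follow this diff):

    import numpy as np
    from cebra import CEBRA

    X = np.random.normal(size=(1000, 30)).astype("float32")
    model = CEBRA(max_iterations=100, batch_size=512, num_negatives=2048)
    model.fit(X)
    # num_negatives_ falls back to batch_size when num_negatives is None
    assert model.num_negatives_ == 2048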