diff --git a/cebra/data/base.py b/cebra/data/base.py index 51199ce..518f7b4 100644 --- a/cebra/data/base.py +++ b/cebra/data/base.py @@ -22,6 +22,7 @@ """Base classes for datasets and loaders.""" import abc +from typing import Iterator import literate_dataclasses as dataclasses import torch @@ -239,6 +240,12 @@ class Loader(abc.ABC, cebra.io.HasDevice): batch_size: int = dataclasses.field(default=None, doc="""The total batch size.""") + num_negatives: int = dataclasses.field( + default=None, + doc=("The number of negative samples to draw for each reference. " + "If not specified, the batch size is used."), + ) + def __post_init__(self): if self.num_steps is None or self.num_steps <= 0: raise ValueError( @@ -248,28 +255,41 @@ def __post_init__(self): raise ValueError( f"Batch size has to be None, or a non-negative value. Got {self.batch_size}." ) + if self.num_negatives is not None and self.num_negatives <= 0: + raise ValueError( + f"Number of negatives has to be None, or a non-negative value. Got {self.num_negatives}." + ) + + if self.num_negatives is None: + self.num_negatives = self.batch_size def __len__(self): """The number of batches returned when calling as an iterator.""" return self.num_steps - def __iter__(self) -> Batch: + def __iter__(self) -> Iterator[Batch]: for _ in range(len(self)): - index = self.get_indices(num_samples=self.batch_size) + index = self.get_indices() yield self.dataset.load_batch(index) @abc.abstractmethod - def get_indices(self, num_samples: int): + def get_indices(self, num_samples: int = None): """Sample and return the specified number of indices. The elements of the returned `BatchIndex` will be used to index the `dataset` of this data loader. Args: - num_samples: The size of each of the reference, positive and - negative samples. + num_samples: Deprecated. Use ``batch_size`` on the instance level + instead. Returns: batch indices for the reference, positive and negative sample. + + Note: + From version 0.7.0 onwards, specifying the ``num_samples`` + directly is deprecated and will be removed in version 0.8.0. + Please set ``batch_size`` and ``num_negatives`` on the instance + level instead. """ raise NotImplementedError() diff --git a/cebra/data/multi_session.py b/cebra/data/multi_session.py index f33ad6e..c6561ee 100644 --- a/cebra/data/multi_session.py +++ b/cebra/data/multi_session.py @@ -155,10 +155,14 @@ def __post_init__(self): super().__post_init__() self.sampler = cebra.distributions.MultisessionSampler( self.dataset, self.time_offset) + if self.num_negatives is None: + self.num_negatives = self.batch_size - def get_indices(self, num_samples: int) -> List[BatchIndex]: + # NOTE(stes): In the longer run, we need to unify the API here; the num_samples argument + # is not used in the multi-session case, which is different to the single session samples. + def get_indices(self) -> List[BatchIndex]: ref_idx = self.sampler.sample_prior(self.batch_size) - neg_idx = self.sampler.sample_prior(self.batch_size) + neg_idx = self.sampler.sample_prior(self.num_negatives) pos_idx, idx, idx_rev = self.sampler.sample_conditional(ref_idx) ref_idx = torch.from_numpy(ref_idx) @@ -192,8 +196,11 @@ class DiscreteMultiSessionDataLoader(MultiSessionLoader): # Overwrite sampler with the discrete implementation # Generalize MultisessionSampler to avoid doing this? def __post_init__(self): + # NOTE(stes): __post_init__ from superclass is intentionally not called. 
self.sampler = cebra.distributions.DiscreteMultisessionSampler( self.dataset) + if self.num_negatives is None: + self.num_negatives = self.batch_size @property def index(self): @@ -229,7 +236,14 @@ def __post_init__(self): self.sampler = cebra.distributions.UnifiedSampler( self.dataset, self.time_offset) - def get_indices(self, num_samples: int) -> BatchIndex: + if self.batch_size is not None and self.batch_size < 2: + raise ValueError("UnifiedLoader does not support batch_size < 2.") + + if self.num_negatives is not None and self.num_negatives < 2: + raise ValueError( + "UnifiedLoader does not support num_negatives < 2.") + + def get_indices(self) -> BatchIndex: """Sample and return the specified number of indices. The elements of the returned ``BatchIndex`` will be used to index the @@ -251,7 +265,7 @@ def get_indices(self, num_samples: int) -> BatchIndex: Batch indices for the reference, positive and negative samples. """ ref_idx = self.sampler.sample_prior(self.batch_size) - neg_idx = self.sampler.sample_prior(self.batch_size) + neg_idx = self.sampler.sample_prior(self.num_negatives) pos_idx = self.sampler.sample_conditional(ref_idx) diff --git a/cebra/data/multiobjective.py b/cebra/data/multiobjective.py index f700d1c..4ccfb63 100644 --- a/cebra/data/multiobjective.py +++ b/cebra/data/multiobjective.py @@ -20,10 +20,13 @@ # limitations under the License. # +from typing import Iterator + import literate_dataclasses as dataclasses import cebra.data as cebra_data import cebra.distributions +from cebra.data.datatypes import Batch from cebra.data.datatypes import BatchIndex from cebra.distributions.continuous import Prior @@ -71,9 +74,9 @@ def __post_init__(self): def add_config(self, config): self.labels.append(config['label']) - def get_indices(self, num_samples: int): + def get_indices(self) -> BatchIndex: if self.sampling_mode_supervised == "ref_shared": - reference_idx = self.prior.sample_prior(num_samples) + reference_idx = self.prior.sample_prior(self.batch_size) else: raise ValueError( f"Sampling mode {self.sampling_mode_supervised} is not implemented." 
@@ -87,9 +90,9 @@ def get_indices(self, num_samples: int): return batch_index - def __iter__(self): + def __iter__(self) -> Iterator[Batch]: for _ in range(len(self)): - index = self.get_indices(num_samples=self.batch_size) + index = self.get_indices() yield self.dataset.load_batch_supervised(index, self.labels) @@ -142,13 +145,14 @@ def add_config(self, config): self.distributions.append(distribution) - def get_indices(self, num_samples: int): + def get_indices(self) -> BatchIndex: """Sample and return the specified number of indices.""" if self.sampling_mode_contrastive == "refneg_shared": - ref_and_neg = self.prior.sample_prior(num_samples * 2) - reference_idx = ref_and_neg[:num_samples] - negative_idx = ref_and_neg[num_samples:] + ref_and_neg = self.prior.sample_prior(self.batch_size + + self.num_negatives) + reference_idx = ref_and_neg[:self.batch_size] + negative_idx = ref_and_neg[self.batch_size:] positives_idx = [] for distribution in self.distributions: @@ -169,5 +173,5 @@ def get_indices(self, num_samples: int): def __iter__(self): for _ in range(len(self)): - index = self.get_indices(num_samples=self.batch_size) + index = self.get_indices() yield self.dataset.load_batch_contrastive(index) diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index 7e4ad2f..2daef64 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -27,6 +27,7 @@ import abc import warnings +from typing import Iterator import literate_dataclasses as dataclasses import torch @@ -138,7 +139,7 @@ def _init_distribution(self): f"Invalid choice of prior distribution. Got '{self.prior}', but " f"only accept 'uniform' or 'empirical' as potential values.") - def get_indices(self, num_samples: int) -> BatchIndex: + def get_indices(self) -> BatchIndex: """Samples indices for reference, positive and negative examples. The reference samples will be sampled from the empirical or uniform prior @@ -154,13 +155,15 @@ def get_indices(self, num_samples: int) -> BatchIndex: Args: num_samples: The number of samples (batch size) of the returned :py:class:`cebra.data.datatypes.BatchIndex`. + num_negatives: The number of negative samples. If None, defaults to num_samples. Returns: Indices for reference, positive and negatives samples. """ - reference_idx = self.distribution.sample_prior(num_samples * 2) - negative_idx = reference_idx[num_samples:] - reference_idx = reference_idx[:num_samples] + reference_idx = self.distribution.sample_prior(self.batch_size + + self.num_negatives) + negative_idx = reference_idx[self.batch_size:] + reference_idx = reference_idx[:self.batch_size] reference = self.index[reference_idx] positive_idx = self.distribution.sample_conditional(reference) return BatchIndex(reference=reference_idx, @@ -246,7 +249,7 @@ def _init_distribution(self): else: raise ValueError(self.conditional) - def get_indices(self, num_samples: int) -> BatchIndex: + def get_indices(self) -> BatchIndex: """Samples indices for reference, positive and negative examples. The reference and negative samples will be sampled uniformly from @@ -262,9 +265,10 @@ def get_indices(self, num_samples: int) -> BatchIndex: Returns: Indices for reference, positive and negatives samples. 
""" - reference_idx = self.distribution.sample_prior(num_samples * 2) - negative_idx = reference_idx[num_samples:] - reference_idx = reference_idx[:num_samples] + reference_idx = self.distribution.sample_prior(self.batch_size + + self.num_negatives) + negative_idx = reference_idx[self.batch_size:] + reference_idx = reference_idx[:self.batch_size] positive_idx = self.distribution.sample_conditional(reference_idx) return BatchIndex(reference=reference_idx, positive=positive_idx, @@ -305,7 +309,7 @@ def __post_init__(self): continuous=self.cindex, time_delta=self.time_offset) - def get_indices(self, num_samples: int) -> BatchIndex: + def get_indices(self) -> BatchIndex: """Samples indices for reference, positive and negative examples. The reference and negative samples will be sampled uniformly from @@ -319,6 +323,7 @@ def get_indices(self, num_samples: int) -> BatchIndex: Args: num_samples: The number of samples (batch size) of the returned :py:class:`cebra.data.datatypes.BatchIndex`. + num_negatives: The number of negative samples. If None, defaults to num_samples. Returns: Indices for reference, positive and negatives samples. @@ -328,10 +333,13 @@ def get_indices(self, num_samples: int) -> BatchIndex: class. - Sample the negatives with matching discrete variable """ - reference_idx = self.distribution.sample_prior(num_samples) + reference_idx = self.distribution.sample_prior(self.batch_size + + self.num_negatives) + negative_idx = reference_idx[self.batch_size:] + reference_idx = reference_idx[:self.batch_size] return BatchIndex( reference=reference_idx, - negative=self.distribution.sample_prior(num_samples), + negative=negative_idx, positive=self.distribution.sample_conditional(reference_idx), ) @@ -421,11 +429,11 @@ def _init_time_distribution(self): else: raise ValueError - def get_indices(self, num_samples: int) -> BatchIndex: + def get_indices(self) -> BatchIndex: """Samples indices for reference, positive and negative examples. The reference and negative samples will be sampled uniformly from - all available time steps, and a total of ``2*num_samples`` will be + all available time steps, and a total of ``num_samples + num_negatives`` will be returned for both. For the positive samples, ``num_samples`` are sampled according to the @@ -436,6 +444,7 @@ def get_indices(self, num_samples: int) -> BatchIndex: Args: num_samples: The number of samples (batch size) of the returned :py:class:`cebra.data.datatypes.BatchIndex`. + num_negatives: The number of negative samples. If None, defaults to num_samples. Returns: Indices for reference, positive and negatives samples. @@ -444,9 +453,10 @@ def get_indices(self, num_samples: int) -> BatchIndex: Add the ``empirical`` vs. ``discrete`` sampling modes to this class. 
""" - reference_idx = self.time_distribution.sample_prior(num_samples * 2) - negative_idx = reference_idx[num_samples:] - reference_idx = reference_idx[:num_samples] + reference_idx = self.time_distribution.sample_prior(self.batch_size + + self.num_negatives) + negative_idx = reference_idx[self.batch_size:] + reference_idx = reference_idx[:self.batch_size] behavior_positive_idx = self.behavior_distribution.sample_conditional( reference_idx) time_positive_idx = self.time_distribution.sample_conditional( @@ -464,13 +474,18 @@ class FullDataLoader(ContinuousDataLoader): def __post_init__(self): super().__post_init__() - self.batch_size = None + + if self.batch_size is not None: + raise ValueError("Batch size cannot be set for FullDataLoader.") + if self.num_negatives is not None: + raise ValueError( + "Number of negatives cannot be set for FullDataLoader.") @property def offset(self): return self.dataset.offset - def get_indices(self, num_samples=None) -> BatchIndex: + def get_indices(self) -> BatchIndex: """Samples indices for reference, positive and negative examples. The reference indices are all available (valid, according to the @@ -490,7 +505,6 @@ def get_indices(self, num_samples=None) -> BatchIndex: Add the ``empirical`` vs. ``discrete`` sampling modes to this class. """ - assert num_samples is None reference_idx = torch.arange( self.offset.left, @@ -504,7 +518,6 @@ def get_indices(self, num_samples=None) -> BatchIndex: positive=positive_idx, negative=negative_idx) - def __iter__(self): + def __iter__(self) -> Iterator[BatchIndex]: for _ in range(len(self)): - index = self.get_indices(num_samples=self.batch_size) - yield index + yield self.get_indices() diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 98e5674..25ee6e0 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -501,6 +501,9 @@ class CEBRA(TransformerMixin, BaseEstimator): A Tuple of masking types and their corresponding required masking values. The keys are the names of the Mask instances and formatting should be ``((key, value), (key, value))``. |Default:| ``None``. + num_negatives (int): + The number of negative samples to use for training. If ``None``, the number of negative samples + will be set to the batch size. |Default:| ``None``. Example: @@ -576,6 +579,7 @@ def __init__( ), masking_kwargs: Tuple[Tuple[str, Union[float, List[float], Tuple[float, ...]]], ...] 
= None, + num_negatives: int = None, ): self.__dict__.update(locals()) @@ -592,6 +596,13 @@ def num_sessions(self) -> Optional[int]: """ return self.num_sessions_ + @property + def num_negatives_(self) -> int: + """The number of negative examples.""" + if self.num_negatives is None: + return self.batch_size + return self.num_negatives + @property def state_dict_(self) -> dict: return self.solver_.state_dict() @@ -728,6 +739,7 @@ def _prepare_loader(self, dataset: cebra.data.Dataset, max_iterations: int, dataset=dataset, batch_size=self.batch_size, num_steps=max_iterations, + num_negatives=self.num_negatives, ), extra_kwargs=dict( time_offsets=self.time_offsets, diff --git a/cebra/integrations/sklearn/metrics.py b/cebra/integrations/sklearn/metrics.py index d8fd791..d072b0a 100644 --- a/cebra/integrations/sklearn/metrics.py +++ b/cebra/integrations/sklearn/metrics.py @@ -100,12 +100,12 @@ def infonce_loss( solver.to(cebra_model.device_) avg_loss = solver.validation(loader=loader, session_id=session_id) if correct_by_batchsize: - if cebra_model.batch_size is None: + if cebra_model.num_negatives_ is None: raise ValueError( "Batch size is None, please provide a model with a batch size to correct the InfoNCE." ) else: - avg_loss = avg_loss - np.log(cebra_model.batch_size) + avg_loss = avg_loss - np.log(cebra_model.num_negatives_) return avg_loss @@ -211,7 +211,7 @@ def infonce_to_goodness_of_fit( Args: infonce: The InfoNCE loss, either a single value or an iterable of values. model: The trained CEBRA model. - batch_size: The batch size used to train the model. + batch_size: The batch size (or number of negatives, if different from the batch size) used to train the model. num_sessions: The number of sessions used to train the model. Returns: @@ -228,19 +228,15 @@ def infonce_to_goodness_of_fit( ) if not hasattr(model, "state_dict_"): raise RuntimeError("Fit the CEBRA model first.") - if model.batch_size is None: + if model.num_negatives_ is None: raise ValueError( "Computing the goodness of fit is not yet supported for " "models trained on the full dataset (batchsize = None). ") - batch_size = model.batch_size + batch_size = model.num_negatives_ num_sessions = model.num_sessions_ if num_sessions is None: num_sessions = 1 - if model.batch_size is None: - raise ValueError( - "Computing the goodness of fit is not yet supported for " - "models trained on the full dataset (batchsize = None). 
") else: if batch_size is None or num_sessions is None: raise ValueError( diff --git a/tests/test_loader.py b/tests/test_loader.py index cb6be9a..8eaa9f4 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -29,6 +29,43 @@ BATCH_SIZE = 32 NUMS_NEURAL = [3, 4, 5] +SINGLE_SESSION_LOADERS = [ + ("demo-discrete", cebra.data.DiscreteDataLoader), + ("demo-continuous", cebra.data.ContinuousDataLoader), + ("demo-mixed", cebra.data.MixedDataLoader), +] +MULTI_SESSION_LOADERS = [ + ("demo-continuous-multisession", + cebra.data.ContinuousMultiSessionDataLoader), + ("demo-discrete-multisession", cebra.data.DiscreteMultiSessionDataLoader), +] +LOADERS = SINGLE_SESSION_LOADERS + MULTI_SESSION_LOADERS + [ + ("demo-continuous-unified", cebra.data.UnifiedLoader), +] + + +def _setup_functional_loader_test(data_name, loader_initfunc, device, + batch_size, num_negatives): + data = cebra.datasets.init(data_name) + data.to(device) + if num_negatives == "do not pass": + loader = loader_initfunc(data, num_steps=10, batch_size=batch_size) + else: + loader = loader_initfunc(data, + num_steps=10, + batch_size=batch_size, + num_negatives=num_negatives) + + if num_negatives is None or num_negatives == "do not pass": + assert loader.num_negatives == batch_size + expected_num_negatives = batch_size + else: + assert loader.num_negatives == num_negatives + expected_num_negatives = num_negatives + + _assert_dataset_on_correct_device(loader, device) + + return loader, expected_num_negatives class LoadSpeed: @@ -135,16 +172,7 @@ def _to_str(val): @_util.parametrize_device -@pytest.mark.parametrize( - "data_name, loader_initfunc", - [ - ("demo-discrete", cebra.data.DiscreteDataLoader), - ("demo-continuous", cebra.data.ContinuousDataLoader), - ("demo-mixed", cebra.data.MixedDataLoader), - ("demo-continuous-multisession", cebra.data.MultiSessionLoader), - ("demo-continuous-unified", cebra.data.UnifiedLoader), - ], -) +@pytest.mark.parametrize("data_name, loader_initfunc", LOADERS) def test_device(data_name, loader_initfunc, device): if not torch.cuda.is_available(): pytest.skip("Test only possible with CUDA.") @@ -158,8 +186,7 @@ def test_device(data_name, loader_initfunc, device): assert loader.dataset == dataset _assert_device(loader.device, device) _assert_device(loader.dataset.device, device) - - _assert_device(loader.get_indices(10).reference.device, device) + _assert_device(loader.get_indices().reference.device, device) @_util.parametrize_device @@ -206,44 +233,34 @@ def _check_attributes(obj, is_list=False): @_util.parametrize_device -@pytest.mark.parametrize( - "data_name, loader_initfunc", - [ - ("demo-discrete", cebra.data.DiscreteDataLoader), - ("demo-continuous", cebra.data.ContinuousDataLoader), - ("demo-mixed", cebra.data.MixedDataLoader), - ], -) -def test_singlesession_loader(data_name, loader_initfunc, device): - data = cebra.datasets.init(data_name) - data.to(device) - loader = loader_initfunc(data, num_steps=10, batch_size=BATCH_SIZE) - _assert_dataset_on_correct_device(loader, device) +@pytest.mark.parametrize("data_name, loader_initfunc", SINGLE_SESSION_LOADERS) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("num_negatives", [None, 1, 32, "do not pass"]) +def test_singlesession_loader(data_name, loader_initfunc, device, batch_size, + num_negatives): - index = loader.get_indices(100) + loader, expected_num_negatives = _setup_functional_loader_test( + data_name, loader_initfunc, device, batch_size, num_negatives) + + index = loader.get_indices() _check_attributes(index) 
for batch in loader: _check_attributes(batch) - assert len(batch.positive) == BATCH_SIZE + assert len(batch.positive) == batch_size + assert len(batch.reference) == batch_size + assert len(batch.negative) == expected_num_negatives @_util.parametrize_device -@pytest.mark.parametrize( - "data_name, loader_initfunc", - [ - ("demo-continuous-multisession", - cebra.data.ContinuousMultiSessionDataLoader), - ("demo-discrete-multisession", - cebra.data.DiscreteMultiSessionDataLoader), - ], -) -def test_multisession_loader(data_name, loader_initfunc, device): - data = cebra.datasets.init(data_name) - data.to(device) - loader = loader_initfunc(data, num_steps=10, batch_size=BATCH_SIZE) +@pytest.mark.parametrize("data_name, loader_initfunc", MULTI_SESSION_LOADERS) +@pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("num_negatives", [None, 1, 32, 33, "do not pass"]) +def test_multisession_loader(data_name, loader_initfunc, device, batch_size, + num_negatives): - _assert_dataset_on_correct_device(loader, device) + loader, expected_num_negatives = _setup_functional_loader_test( + data_name, loader_initfunc, device, batch_size, num_negatives) # Check the sampler assert hasattr(loader, "sampler") @@ -260,7 +277,7 @@ def test_multisession_loader(data_name, loader_initfunc, device): batch = next(iter(loader)) for i, n_neurons in enumerate(NUMS_NEURAL): - assert batch[i].reference.shape == (BATCH_SIZE, n_neurons, 10) + assert batch[i].reference.shape == (batch_size, n_neurons, 10) def _mix(array, idx): shape = array.shape @@ -276,18 +293,18 @@ def _process(batch, feature_dim=1): dim=0).repeat(1, 1, feature_dim) dummy_prediction = _process(batch, feature_dim=6) - assert dummy_prediction.shape == (3, BATCH_SIZE, 6) + assert dummy_prediction.shape == (3, batch_size, 6) _mix(dummy_prediction, batch[0].index) - index = loader.get_indices(100) - #print(index[0]) - #print(type(index)) + index = loader.get_indices() _check_attributes(index, is_list=False) for batch in loader: _check_attributes(batch, is_list=True) for session_batch in batch: - assert len(session_batch.positive) == BATCH_SIZE + assert len(session_batch.positive) == batch_size + assert len(session_batch.reference) == batch_size + assert len(session_batch.negative) == expected_num_negatives @_util.parametrize_device @@ -297,12 +314,14 @@ def _process(batch, feature_dim=1): ("demo-continuous-unified", cebra.data.UnifiedLoader), ], ) -def test_unified_loader(data_name, loader_initfunc, device): - data = cebra.datasets.init(data_name) - data.to(device) - loader = loader_initfunc(data, num_steps=10, batch_size=BATCH_SIZE) +# TODO(stes): unified sampler breaks for batch_size = 1; tested further below +@pytest.mark.parametrize("batch_size", [2, 32, 100]) +@pytest.mark.parametrize("num_negatives", [None, 2, 32, 33, "do not pass"]) +def test_unified_loader_sampler(data_name, loader_initfunc, device, batch_size, + num_negatives): - _assert_dataset_on_correct_device(loader, device) + loader, expected_num_negatives = _setup_functional_loader_test( + data_name, loader_initfunc, device, batch_size, num_negatives) # Check the sampler num_samples = 100 @@ -334,11 +353,38 @@ def test_unified_loader(data_name, loader_initfunc, device): pos_idx = loader.sampler.sample_conditional(all_ref_idx) assert pos_idx.shape == (len(NUMS_NEURAL), num_samples) + +@_util.parametrize_device +@pytest.mark.parametrize( + "data_name, loader_initfunc", + [ + ("demo-continuous-unified", cebra.data.UnifiedLoader), + ], +) +# TODO(stes): unified sampler breaks for 
batch_size = 1 +@pytest.mark.parametrize("batch_size", [1, 32, 100]) +@pytest.mark.parametrize("num_negatives", [None, 1, 32, 33, "do not pass"]) +def test_unified_loader(data_name, loader_initfunc, device, batch_size, + num_negatives): + + if batch_size == 1 or num_negatives == 1: + with pytest.raises(ValueError, + match=r"UnifiedLoader does not support .* < 2"): + _setup_functional_loader_test(data_name, loader_initfunc, device, + batch_size, num_negatives) + pytest.skip( + "UnifiedLoader does not support batch_size < 2 or num_negatives < 2." + ) + + loader, expected_num_negatives = _setup_functional_loader_test( + data_name, loader_initfunc, device, batch_size, num_negatives) + # Check the batch batch = next(iter(loader)) - assert batch.reference.shape == (BATCH_SIZE, sum(NUMS_NEURAL), 10) - assert batch.positive.shape == (BATCH_SIZE, sum(NUMS_NEURAL), 10) - assert batch.negative.shape == (BATCH_SIZE, sum(NUMS_NEURAL), 10) + assert batch.reference.shape == (batch_size, sum(NUMS_NEURAL), 10) + assert batch.positive.shape == (batch_size, sum(NUMS_NEURAL), 10) + assert batch.negative.shape == (expected_num_negatives, sum(NUMS_NEURAL), + 10) - index = loader.get_indices(100) + index = loader.get_indices() _check_attributes(index, is_list=False) diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index c3d2095..dfa09da 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -1544,3 +1544,20 @@ def test_last_incomplete_batch_smaller_than_offset(): model.fit(train.neural, train.continuous) _ = model.transform(train.neural, batch_size=300) + + +@pytest.mark.parametrize("batch_size,num_negatives", [ + (None, None), + (100, None), + (100, 100), +]) +def test_num_negatives(batch_size, num_negatives): + train = cebra.data.TensorDataset(neural=np.random.rand(20111, 100), + continuous=np.random.rand(20111, 2)) + + model = cebra.CEBRA(max_iterations=2, + batch_size=batch_size, + num_negatives=num_negatives, + device="cpu") + model.fit(train.neural, train.continuous) + _ = model.transform(train.neural) diff --git a/tests/test_sklearn_metrics.py b/tests/test_sklearn_metrics.py index 10c6245..bb71b1f 100644 --- a/tests/test_sklearn_metrics.py +++ b/tests/test_sklearn_metrics.py @@ -482,14 +482,22 @@ def _fit_and_get_history(X, y): @pytest.mark.parametrize("seed", [42, 24, 10]) -def test_infonce_to_goodness_of_fit(seed): +@pytest.mark.parametrize("batch_size", [100, 200]) +@pytest.mark.parametrize("num_negatives", [None, 100, 200]) +def test_infonce_to_goodness_of_fit(seed, batch_size, num_negatives): """Test the conversion from InfoNCE loss to goodness of fit metric.""" + nats_to_bits = np.log2(np.e) + # Test with model cebra_model = cebra_sklearn_cebra.CEBRA( model_architecture="offset10-model", max_iterations=5, - batch_size=128, + batch_size=batch_size, + num_negatives=num_negatives, ) + if num_negatives is None: + num_negatives = batch_size + generator = torch.Generator().manual_seed(seed) X = torch.rand(1000, 50, dtype=torch.float32, generator=generator) cebra_model.fit(X) @@ -498,6 +506,7 @@ def test_infonce_to_goodness_of_fit(seed): gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, model=cebra_model) assert isinstance(gof, float) + assert np.isclose(gof, (np.log(num_negatives) - 1.0) * nats_to_bits) # Test array of values infonce_values = np.array([1.0, 2.0, 3.0]) @@ -505,12 +514,14 @@ def test_infonce_to_goodness_of_fit(seed): infonce_values, model=cebra_model) assert isinstance(gof_array, np.ndarray) assert gof_array.shape == infonce_values.shape + assert 
np.allclose(gof_array, + (np.log(num_negatives) - infonce_values) * nats_to_bits) # Test with explicit batch_size and num_sessions - gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, - batch_size=128, - num_sessions=1) + gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit( + 1.0, batch_size=batch_size, num_sessions=1) assert isinstance(gof, float) + assert np.isclose(gof, (np.log(batch_size) - 1.0) * nats_to_bits) # Test error cases with pytest.raises(ValueError, match="batch_size.*should not be provided"):
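
Usage sketch (not part of the patch above): a minimal example of the new ``num_negatives`` argument, mirroring the tests in this diff. The demo dataset name, the loader signature, and the estimator keyword are taken from the changes above and are assumed to land exactly as shown there.

    import numpy as np

    import cebra
    import cebra.data
    import cebra.datasets

    # Loader level: draw a different number of negatives than the batch size.
    dataset = cebra.datasets.init("demo-continuous")
    loader = cebra.data.ContinuousDataLoader(dataset,
                                             num_steps=10,
                                             batch_size=32,
                                             num_negatives=64)
    batch = next(iter(loader))
    assert len(batch.reference) == 32  # reference/positive follow batch_size
    assert len(batch.negative) == 64   # negatives follow num_negatives

    # Estimator level: the keyword is forwarded to the loader in
    # CEBRA._prepare_loader(); if omitted, num_negatives_ falls back to batch_size.
    model = cebra.CEBRA(max_iterations=2,
                        batch_size=100,
                        num_negatives=200,
                        device="cpu")
    model.fit(np.random.rand(1000, 50), np.random.rand(1000, 2))

With this change, ``infonce_loss(..., correct_by_batchsize=True)`` and ``infonce_to_goodness_of_fit`` correct by ``model.num_negatives_`` (which defaults to the batch size) rather than by ``model.batch_size``.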