From 248b96249137c93637ee6be8421295380b4c6620 Mon Sep 17 00:00:00 2001 From: timonmerk Date: Sun, 29 Oct 2023 17:27:30 +0100 Subject: [PATCH 1/6] add positive sampling options for MixedDataLoader --- cebra/data/single_session.py | 59 +++++++++++++++++++++++++++--------- 1 file changed, 45 insertions(+), 14 deletions(-) diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index 7802b787..eb093449 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -268,27 +268,47 @@ class MixedDataLoader(cebra_data.Loader): 1. Positive pairs always share their discrete variable. 2. Positive pairs are drawn only based on their conditional, not discrete variable. + + Args: + conditional (str): The conditional variable for sampling positive pairs. :py:attr:`cebra.CEBRA.conditional` + time_offset (int): :py:attr:`cebra.CEBRA.time_offsets` + positive_sampling (str): either "discrete_variable" (default) or "conditional" + discrete_sampling_prior (str): either "empirical" or "uniform" (default) """ conditional: str = dataclasses.field(default="time_delta") time_offset: int = dataclasses.field(default=10) + positive_sampling: str = dataclasses.field(default="discrete_variable") + discrete_sampling_prior: str = dataclasses.field(default="uniform") @property - def dindex(self): - # TODO(stes) rename to discrete_index + def discrete_index(self): return self.dataset.discrete_index @property - def cindex(self): - # TODO(stes) rename to continuous_index + def continuous_index(self): return self.dataset.continuous_index def __post_init__(self): super().__post_init__() - self.distribution = cebra.distributions.MixedTimeDeltaDistribution( - discrete=self.dindex, - continuous=self.cindex, - time_delta=self.time_offset) + if self.positive_sampling == "conditional": + self.distribution = cebra.distributions.MixedTimeDeltaDistribution( + discrete=self.discrete_index, + continuous=self.continuous_index, + time_delta=self.time_offset) + elif 
self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior == "empirical": + self.distribution = cebra.distributions.DiscreteEmpirical(self.discrete_index) + elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior == "uniform": + self.distribution = cebra.distributions.DiscreteUniform(self.discrete_index) + elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior not in ["empirical", "uniform"]: + raise ValueError( + f"Invalid choice of prior distribution. Got '{self.discrete_sampling_prior}', but " + f"only accept 'uniform' or 'empirical' as potential values.") + else: + raise ValueError( + f"Invalid positive sampling mode: " + f"{self.positive_sampling} valid options are " + f"'conditional' or 'discrete_variable'.") def get_indices(self, num_samples: int) -> BatchIndex: """Samples indices for reference, positive and negative examples. @@ -313,12 +333,23 @@ def get_indices(self, num_samples: int) -> BatchIndex: class. 
- Sample the negatives with matching discrete variable """ - reference_idx = self.distribution.sample_prior(num_samples) - return BatchIndex( - reference=reference_idx, - negative=self.distribution.sample_prior(num_samples), - positive=self.distribution.sample_conditional(reference_idx), - ) + if self.positive_sampling == "conditional": + reference_idx = self.distribution.sample_prior(num_samples) + return BatchIndex( + reference=reference_idx, + negative=self.distribution.sample_prior(num_samples), + positive=self.distribution.sample_conditional(reference_idx), + ) + else: + # taken from the DiscreteDataLoader get_indices function + reference_idx = self.distribution.sample_prior(num_samples * 2) + negative_idx = reference_idx[num_samples:] + reference_idx = reference_idx[:num_samples] + reference = self.discrete_index[reference_idx] + positive_idx = self.distribution.sample_conditional(reference) + return BatchIndex(reference=reference_idx, + positive=positive_idx, + negative=negative_idx) @dataclasses.dataclass From dc1c77cf577ed61af135ce2ccd44e6901a2d04a8 Mon Sep 17 00:00:00 2001 From: timonmerk Date: Wed, 13 Dec 2023 17:27:21 +0100 Subject: [PATCH 2/6] add deprecation warning for cindex and dindex --- cebra/data/single_session.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index eb093449..1fddc743 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -281,10 +281,22 @@ class MixedDataLoader(cebra_data.Loader): positive_sampling: str = dataclasses.field(default="discrete_variable") discrete_sampling_prior: str = dataclasses.field(default="uniform") + @property + def dindex(self): + warnings.warn("dindex is deprecated. Use discrete_index instead.", + DeprecationWarning) + return self.dataset.discrete_index + @property def discrete_index(self): return self.dataset.discrete_index + @property + def cindex(self): + warnings.warn("cindex is deprecated. 
Use continuous_index instead.", + DeprecationWarning) + return self.dataset.continuous_index + @property def continuous_index(self): return self.dataset.continuous_index From 6763dc1b57cd747eb7f1c826de81a91a70319470 Mon Sep 17 00:00:00 2001 From: timonmerk Date: Wed, 13 Dec 2023 17:31:48 +0100 Subject: [PATCH 3/6] add test for MixedDataLoader including additional keywords --- tests/test_loader.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/test_loader.py b/tests/test_loader.py index 562f64a7..e2d819cf 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -186,6 +186,37 @@ def test_continuous(conditional, device, benchmark): benchmark(load_speed) +@parametrize_device +@pytest.mark.parametrize( + "conditional, positive_sampling, discrete_sampling_prior", + [ + ("time", "discrete_variable", "empirical"), + ("time", "conditional", "empirical"), + ("time", "discrete_variable", "uniform"), + ("time", "conditional", "uniform"), + ("time_delta", "discrete_variable", "empirical"), + ("time_delta", "conditional", "empirical"), + ("time_delta", "discrete_variable", "uniform"), + ("time_delta", "conditional", "uniform"), + ], +) +def test_mixed( + conditional, positive_sampling, discrete_sampling_prior, device, benchmark +): + dataset = RandomDataset(N=100, d=5, device=device) + loader = cebra.data.MixedDataLoader( + dataset=dataset, + num_steps=10, + batch_size=8, + conditional=conditional, + positive_sampling=positive_sampling, + discrete_sampling_prior=discrete_sampling_prior, + ) + _assert_dataset_on_correct_device(loader, device) + load_speed = LoadSpeed(loader) + benchmark(load_speed) + + def _check_attributes(obj, is_list=False): if is_list: for obj_ in obj: From 9835d45d6236beff86c9c5e4e43fbc50cb31d9b4 Mon Sep 17 00:00:00 2001 From: timonmerk Date: Wed, 13 Dec 2023 18:06:07 +0100 Subject: [PATCH 4/6] add improved docstring description --- cebra/data/single_session.py | 11 +++++++++-- 1 file changed, 9 
insertions(+), 2 deletions(-) diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index 1fddc743..61cebb7e 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -265,9 +265,16 @@ class MixedDataLoader(cebra_data.Loader): Sampling can be configured in different modes: - 1. Positive pairs always share their discrete variable. + 1. Positive pairs always share their discrete variable (positive_sampling = "discrete_variable"). 2. Positive pairs are drawn only based on their conditional, - not discrete variable. + not discrete variable (positive_sampling = "conditional"). + + When using the discrete variable, the prior distribution can either be uniform + (discrete_sampling_prior = "uniform") or empirical (discrete_sampling_prior = "empirical"). + + Based on the selection of those parameters, the :py:class:`cebra.distributions.MixedTimeDeltaDistribution`, + :py:class:`cebra.distributions.DiscreteEmpirical`, or :py:class:`cebra.distributions.DiscreteUniform` + distributions are used for sampling. Args: conditional (str): The conditional variable for sampling positive pairs. :py:attr:`cebra.CEBRA.conditional` From 8dee8a05528939008c780516bc9e19f5713a8fb6 Mon Sep 17 00:00:00 2001 From: timonmerk Date: Wed, 13 Dec 2023 18:21:55 +0100 Subject: [PATCH 5/6] fix docstring sphinx link --- cebra/data/single_session.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index 61cebb7e..993078fa 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -272,8 +272,8 @@ class MixedDataLoader(cebra_data.Loader): When using the discrete variable, the prior distribution can either be uniform (discrete_sampling_prior = "uniform") or empirical (discrete_sampling_prior = "empirical"). 
- Based on the selection of those parameters, the :py:class:`cebra.distributions.MixedTimeDeltaDistribution`, - :py:class:`cebra.distributions.DiscreteEmpirical`, or :py:class:`cebra.distributions.DiscreteUniform` + Based on the selection of those parameters, the :py:class:`cebra.distributions.mixed.MixedTimeDeltaDistribution`, + :py:class:`cebra.distributions.discrete.DiscreteEmpirical`, or :py:class:`cebra.distributions.discrete.DiscreteUniform` distributions are used for sampling. Args: From 0326fb9db66eb009d7b53f0423a002ef91a652b2 Mon Sep 17 00:00:00 2001 From: Mackenzie Mathis Date: Tue, 18 Feb 2025 10:53:11 +0100 Subject: [PATCH 6/6] Update cebra/data/single_session.py Co-authored-by: Steffen Schneider --- cebra/data/single_session.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index 993078fa..a04cd2f4 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -360,15 +360,7 @@ def get_indices(self, num_samples: int) -> BatchIndex: positive=self.distribution.sample_conditional(reference_idx), ) else: - # taken from the DiscreteDataLoader get_indices function - reference_idx = self.distribution.sample_prior(num_samples * 2) - negative_idx = reference_idx[num_samples:] - reference_idx = reference_idx[:num_samples] - reference = self.discrete_index[reference_idx] - positive_idx = self.distribution.sample_conditional(reference) - return BatchIndex(reference=reference_idx, - positive=positive_idx, - negative=negative_idx) + return self.distribution.get_indices(num_samples) @dataclasses.dataclass