-
Notifications
You must be signed in to change notification settings - Fork 89
Decouple batch size and number of negatives #263
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
723bcfb
dbabb6e
540b006
07212f2
6c2d559
0dba5fc
e259e45
f2af3b6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ | |
"""Base classes for datasets and loaders.""" | ||
|
||
import abc | ||
from typing import Iterator | ||
|
||
import literate_dataclasses as dataclasses | ||
import torch | ||
|
@@ -239,6 +240,12 @@ class Loader(abc.ABC, cebra.io.HasDevice): | |
batch_size: int = dataclasses.field(default=None, | ||
doc="""The total batch size.""") | ||
|
||
num_negatives: int = dataclasses.field( | ||
default=None, | ||
doc=("The number of negative samples to draw for each reference. " | ||
"If not specified, the batch size is used."), | ||
) | ||
|
||
def __post_init__(self): | ||
if self.num_steps is None or self.num_steps <= 0: | ||
raise ValueError( | ||
|
@@ -248,28 +255,41 @@ def __post_init__(self): | |
raise ValueError( | ||
f"Batch size has to be None, or a positive value. Got {self.batch_size}." | ||
) | ||
if self.num_negatives is not None and self.num_negatives <= 0: | ||
raise ValueError( | ||
f"Number of negatives has to be None, or a positive value. Got {self.num_negatives}." | ||
) | ||
|
||
if self.num_negatives is None: | ||
self.num_negatives = self.batch_size | ||
|
||
def __len__(self): | ||
"""The number of batches returned when calling as an iterator.""" | ||
return self.num_steps | ||
|
||
def __iter__(self) -> Batch: | ||
def __iter__(self) -> Iterator[Batch]: | ||
for _ in range(len(self)): | ||
index = self.get_indices(num_samples=self.batch_size) | ||
index = self.get_indices() | ||
yield self.dataset.load_batch(index) | ||
|
||
@abc.abstractmethod | ||
def get_indices(self, num_samples: int): | ||
def get_indices(self, num_samples: int = None): | ||
"""Sample and return the specified number of indices. | ||
|
||
The elements of the returned `BatchIndex` will be used to index the | ||
`dataset` of this data loader. | ||
|
||
Args: | ||
num_samples: The size of each of the reference, positive and | ||
negative samples. | ||
num_samples: Deprecated. Use ``batch_size`` on the instance level | ||
instead. | ||
|
||
Returns: | ||
batch indices for the reference, positive and negative sample. | ||
|
||
Note: | ||
From version 0.7.0 onwards, specifying ``num_samples`` | ||
directly is deprecated and will be removed in version 0.8.0. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The method signature still accepts 'num_samples' parameter even though it's deprecated. Consider removing this parameter entirely or making it keyword-only to prevent accidental usage. Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||
Please set ``batch_size`` and ``num_negatives`` on the instance | ||
level instead. | ||
""" | ||
raise NotImplementedError() |
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -27,6 +27,7 @@ | |||||||||||||||||
|
||||||||||||||||||
import abc | ||||||||||||||||||
import warnings | ||||||||||||||||||
from typing import Iterator | ||||||||||||||||||
|
||||||||||||||||||
import literate_dataclasses as dataclasses | ||||||||||||||||||
import torch | ||||||||||||||||||
|
@@ -138,7 +139,7 @@ def _init_distribution(self): | |||||||||||||||||
f"Invalid choice of prior distribution. Got '{self.prior}', but " | ||||||||||||||||||
f"only accept 'uniform' or 'empirical' as potential values.") | ||||||||||||||||||
|
||||||||||||||||||
def get_indices(self, num_samples: int) -> BatchIndex: | ||||||||||||||||||
def get_indices(self) -> BatchIndex: | ||||||||||||||||||
"""Samples indices for reference, positive and negative examples. | ||||||||||||||||||
|
||||||||||||||||||
The reference samples will be sampled from the empirical or uniform prior | ||||||||||||||||||
|
@@ -154,13 +155,15 @@ def get_indices(self, num_samples: int) -> BatchIndex: | |||||||||||||||||
Args: | ||||||||||||||||||
num_samples: The number of samples (batch size) of the returned | ||||||||||||||||||
:py:class:`cebra.data.datatypes.BatchIndex`. | ||||||||||||||||||
num_negatives: The number of negative samples. If None, defaults to ``batch_size``. | ||||||||||||||||||
|
||||||||||||||||||
Returns: | ||||||||||||||||||
Indices for reference, positive and negatives samples. | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The docstring mentions 'num_samples' but this parameter no longer exists in the function signature. The docstring should be updated to reflect that num_negatives is now an instance attribute.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||||||||||
""" | ||||||||||||||||||
reference_idx = self.distribution.sample_prior(num_samples * 2) | ||||||||||||||||||
negative_idx = reference_idx[num_samples:] | ||||||||||||||||||
reference_idx = reference_idx[:num_samples] | ||||||||||||||||||
reference_idx = self.distribution.sample_prior(self.batch_size + | ||||||||||||||||||
self.num_negatives) | ||||||||||||||||||
negative_idx = reference_idx[self.batch_size:] | ||||||||||||||||||
reference_idx = reference_idx[:self.batch_size] | ||||||||||||||||||
reference = self.index[reference_idx] | ||||||||||||||||||
positive_idx = self.distribution.sample_conditional(reference) | ||||||||||||||||||
return BatchIndex(reference=reference_idx, | ||||||||||||||||||
|
@@ -246,7 +249,7 @@ def _init_distribution(self): | |||||||||||||||||
else: | ||||||||||||||||||
raise ValueError(self.conditional) | ||||||||||||||||||
|
||||||||||||||||||
def get_indices(self, num_samples: int) -> BatchIndex: | ||||||||||||||||||
def get_indices(self) -> BatchIndex: | ||||||||||||||||||
"""Samples indices for reference, positive and negative examples. | ||||||||||||||||||
|
||||||||||||||||||
The reference and negative samples will be sampled uniformly from | ||||||||||||||||||
|
@@ -262,9 +265,10 @@ def get_indices(self, num_samples: int) -> BatchIndex: | |||||||||||||||||
Returns: | ||||||||||||||||||
Indices for reference, positive and negatives samples. | ||||||||||||||||||
""" | ||||||||||||||||||
reference_idx = self.distribution.sample_prior(num_samples * 2) | ||||||||||||||||||
negative_idx = reference_idx[num_samples:] | ||||||||||||||||||
reference_idx = reference_idx[:num_samples] | ||||||||||||||||||
reference_idx = self.distribution.sample_prior(self.batch_size + | ||||||||||||||||||
self.num_negatives) | ||||||||||||||||||
negative_idx = reference_idx[self.batch_size:] | ||||||||||||||||||
reference_idx = reference_idx[:self.batch_size] | ||||||||||||||||||
positive_idx = self.distribution.sample_conditional(reference_idx) | ||||||||||||||||||
return BatchIndex(reference=reference_idx, | ||||||||||||||||||
positive=positive_idx, | ||||||||||||||||||
|
@@ -305,7 +309,7 @@ def __post_init__(self): | |||||||||||||||||
continuous=self.cindex, | ||||||||||||||||||
time_delta=self.time_offset) | ||||||||||||||||||
|
||||||||||||||||||
def get_indices(self, num_samples: int) -> BatchIndex: | ||||||||||||||||||
def get_indices(self) -> BatchIndex: | ||||||||||||||||||
"""Samples indices for reference, positive and negative examples. | ||||||||||||||||||
|
||||||||||||||||||
The reference and negative samples will be sampled uniformly from | ||||||||||||||||||
|
@@ -319,6 +323,7 @@ def get_indices(self, num_samples: int) -> BatchIndex: | |||||||||||||||||
Args: | ||||||||||||||||||
num_samples: The number of samples (batch size) of the returned | ||||||||||||||||||
:py:class:`cebra.data.datatypes.BatchIndex`. | ||||||||||||||||||
num_negatives: The number of negative samples. If None, defaults to ``batch_size``. | ||||||||||||||||||
|
||||||||||||||||||
Returns: | ||||||||||||||||||
Indices for reference, positive and negatives samples. | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the previous comment, the docstring references 'num_samples' which no longer exists. This should be updated to reflect the current implementation.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||||||||||
|
@@ -328,10 +333,13 @@ def get_indices(self, num_samples: int) -> BatchIndex: | |||||||||||||||||
class. | ||||||||||||||||||
- Sample the negatives with matching discrete variable | ||||||||||||||||||
""" | ||||||||||||||||||
reference_idx = self.distribution.sample_prior(num_samples) | ||||||||||||||||||
reference_idx = self.distribution.sample_prior(self.batch_size + | ||||||||||||||||||
self.num_negatives) | ||||||||||||||||||
negative_idx = reference_idx[self.batch_size:] | ||||||||||||||||||
reference_idx = reference_idx[:self.batch_size] | ||||||||||||||||||
return BatchIndex( | ||||||||||||||||||
reference=reference_idx, | ||||||||||||||||||
negative=self.distribution.sample_prior(num_samples), | ||||||||||||||||||
negative=negative_idx, | ||||||||||||||||||
positive=self.distribution.sample_conditional(reference_idx), | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
|
@@ -421,11 +429,11 @@ def _init_time_distribution(self): | |||||||||||||||||
else: | ||||||||||||||||||
raise ValueError | ||||||||||||||||||
|
||||||||||||||||||
def get_indices(self, num_samples: int) -> BatchIndex: | ||||||||||||||||||
def get_indices(self) -> BatchIndex: | ||||||||||||||||||
"""Samples indices for reference, positive and negative examples. | ||||||||||||||||||
|
||||||||||||||||||
The reference and negative samples will be sampled uniformly from | ||||||||||||||||||
all available time steps, and a total of ``2*num_samples`` will be | ||||||||||||||||||
all available time steps, and a total of ``batch_size + num_negatives`` will be | ||||||||||||||||||
returned for both. | ||||||||||||||||||
|
||||||||||||||||||
For the positive samples, ``batch_size`` are sampled according to the | ||||||||||||||||||
|
@@ -436,6 +444,7 @@ def get_indices(self, num_samples: int) -> BatchIndex: | |||||||||||||||||
Args: | ||||||||||||||||||
num_samples: The number of samples (batch size) of the returned | ||||||||||||||||||
:py:class:`cebra.data.datatypes.BatchIndex`. | ||||||||||||||||||
num_negatives: The number of negative samples. If None, defaults to num_samples. | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another instance where the docstring incorrectly references 'num_samples'. This should be updated to reflect that num_negatives is an instance attribute. Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||||||||||
|
||||||||||||||||||
Returns: | ||||||||||||||||||
Indices for reference, positive and negatives samples. | ||||||||||||||||||
|
@@ -444,9 +453,10 @@ def get_indices(self, num_samples: int) -> BatchIndex: | |||||||||||||||||
Add the ``empirical`` vs. ``discrete`` sampling modes to this | ||||||||||||||||||
class. | ||||||||||||||||||
""" | ||||||||||||||||||
reference_idx = self.time_distribution.sample_prior(num_samples * 2) | ||||||||||||||||||
negative_idx = reference_idx[num_samples:] | ||||||||||||||||||
reference_idx = reference_idx[:num_samples] | ||||||||||||||||||
reference_idx = self.time_distribution.sample_prior(self.batch_size + | ||||||||||||||||||
self.num_negatives) | ||||||||||||||||||
negative_idx = reference_idx[self.batch_size:] | ||||||||||||||||||
reference_idx = reference_idx[:self.batch_size] | ||||||||||||||||||
behavior_positive_idx = self.behavior_distribution.sample_conditional( | ||||||||||||||||||
reference_idx) | ||||||||||||||||||
time_positive_idx = self.time_distribution.sample_conditional( | ||||||||||||||||||
|
@@ -464,13 +474,18 @@ class FullDataLoader(ContinuousDataLoader): | |||||||||||||||||
|
||||||||||||||||||
def __post_init__(self): | ||||||||||||||||||
super().__post_init__() | ||||||||||||||||||
self.batch_size = None | ||||||||||||||||||
|
||||||||||||||||||
if self.batch_size is not None: | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The condition check for Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||||||||||
raise ValueError("Batch size cannot be set for FullDataLoader.") | ||||||||||||||||||
if self.num_negatives is not None: | ||||||||||||||||||
raise ValueError( | ||||||||||||||||||
"Number of negatives cannot be set for FullDataLoader.") | ||||||||||||||||||
|
||||||||||||||||||
@property | ||||||||||||||||||
def offset(self): | ||||||||||||||||||
return self.dataset.offset | ||||||||||||||||||
|
||||||||||||||||||
def get_indices(self, num_samples=None) -> BatchIndex: | ||||||||||||||||||
def get_indices(self) -> BatchIndex: | ||||||||||||||||||
"""Samples indices for reference, positive and negative examples. | ||||||||||||||||||
|
||||||||||||||||||
The reference indices are all available (valid, according to the | ||||||||||||||||||
|
@@ -490,7 +505,6 @@ def get_indices(self, num_samples=None) -> BatchIndex: | |||||||||||||||||
Add the ``empirical`` vs. ``discrete`` sampling modes to this | ||||||||||||||||||
class. | ||||||||||||||||||
""" | ||||||||||||||||||
assert num_samples is None | ||||||||||||||||||
|
||||||||||||||||||
reference_idx = torch.arange( | ||||||||||||||||||
self.offset.left, | ||||||||||||||||||
|
@@ -504,7 +518,6 @@ def get_indices(self, num_samples=None) -> BatchIndex: | |||||||||||||||||
positive=positive_idx, | ||||||||||||||||||
negative=negative_idx) | ||||||||||||||||||
|
||||||||||||||||||
def __iter__(self): | ||||||||||||||||||
def __iter__(self) -> Iterator[BatchIndex]: | ||||||||||||||||||
for _ in range(len(self)): | ||||||||||||||||||
index = self.get_indices(num_samples=self.batch_size) | ||||||||||||||||||
yield index | ||||||||||||||||||
yield self.get_indices() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `get_indices` method signature change introduces a potential breaking change by making `num_samples` optional with a default of `None`. While the deprecation is documented, the method should handle the case where `num_samples` is passed but shouldn't be used, potentially issuing a deprecation warning to guide users toward the new API. Copilot uses AI. Check for mistakes.