Commit 296d545

Author: gegejun (committed)

fix batchsize when using BatchSampler

1 parent: 8e805f9

File tree

2 files changed: +53 −3 lines

src/lightning/pytorch/utilities/data.py

Lines changed: 14 additions & 3 deletions
@@ -169,9 +169,9 @@ def _get_dataloader_init_args_and_kwargs(
     if was_wrapped:
         # if the dataloader was wrapped in a hook, only take arguments with default values
         # and assume user passes their kwargs correctly
-        params.update({
-            k: v for k, v in inspect.signature(DataLoader.__init__).parameters.items() if v.default is not v.empty
-        })
+        params.update(
+            {k: v for k, v in inspect.signature(DataLoader.__init__).parameters.items() if v.default is not v.empty}
+        )
     else:
         params.update(inspect.signature(DataLoader.__init__).parameters)
     params.pop("self", None)
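
For reference, a minimal standalone sketch (not part of the commit) of what the wrapped-dataloader branch collects: only the `DataLoader.__init__` parameters that declare a default value.

import inspect

from torch.utils.data import DataLoader

# Parameters without defaults, such as `self` and `dataset`, are filtered out,
# since the user is assumed to pass those explicitly.
defaulted = {
    k: v
    for k, v in inspect.signature(DataLoader.__init__).parameters.items()
    if v.default is not v.empty
}
print(sorted(defaulted))  # includes e.g. 'batch_size', 'drop_last', 'sampler'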
@@ -332,6 +332,17 @@ def _dataloader_init_kwargs_resolve_sampler(
             "batch_size": 1,
             "drop_last": False,
         }
+    if batch_sampler is not None and batch_sampler_cls is BatchSampler:
+        # this is a PyTorch `BatchSampler`, but it may have been created by the user, so `batch_size` and `drop_last` should be preserved
+        batch_size = batch_sampler.batch_size
+        drop_last = batch_sampler.drop_last if not is_predicting else False
+        return {
+            "sampler": sampler,
+            "shuffle": False,
+            "batch_sampler": None,
+            "batch_size": batch_size,
+            "drop_last": drop_last,
+        }
 
     return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
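
The new branch relies on a plain `BatchSampler` exposing its `batch_size` and `drop_last` as attributes, which is what gets copied onto the re-created DataLoader. A minimal sketch of that contract, independent of Lightning internals:

from torch.utils.data import BatchSampler, SequentialSampler

sampler = SequentialSampler(range(10))
batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)

# these are the two attributes the fix reads back and preserves
assert batch_sampler.batch_size == 4
assert batch_sampler.drop_last is False

print(list(batch_sampler))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] -- short last batch kept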

Lines changed: 39 additions & 0 deletions (new file)
@@ -0,0 +1,39 @@
+from torch.utils.data import RandomSampler, BatchSampler
+from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from lightning.pytorch import Callback, Trainer, seed_everything
+from tests_pytorch.helpers.runif import RunIf
+from lightning.pytorch.demos.boring_classes import (
+    BoringModel,
+    RandomDataset,
+)
+
+
+class DistribBatchSamplerCallback(Callback):
+    def __init__(self, expected_batch_size=4):
+        self.expected_batch_size = expected_batch_size
+
+    def on_train_start(self, trainer, pl_module):
+        assert isinstance(trainer.train_dataloader.batch_sampler, DistributedSampler)
+        assert trainer.train_dataloader.batch_size == self.expected_batch_size
+
+
+@RunIf(min_cuda_gpus=2, skip_windows=True)
+def test_dataloader_distributed_batch_sampler(tmp_path):
+    """Test DistributedSampler and its arguments for the DDP backend."""
+    seed_everything(123)
+    dataset = RandomDataset(32, 64)
+    sampler = RandomSampler(dataset)
+    batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)
+    dataloader = DataLoader(dataset, batch_sampler=batch_sampler)
+    model = BoringModel()
+    trainer = Trainer(
+        accelerator="gpu",
+        devices=[0, 1],
+        num_nodes=1,
+        strategy="ddp",
+        default_root_dir=tmp_path,
+        max_steps=1,
+        callbacks=[DistribBatchSamplerCallback(expected_batch_size=4)],
+    )
+    trainer.fit(model, train_dataloaders=dataloader)
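
As a quick sanity check outside the test above, a single-process sketch (an illustration only; it does not exercise DDP or the Trainer) of the batch size a user-supplied `BatchSampler` carries through a plain DataLoader, which is the value the fix now preserves:

from torch.utils.data import BatchSampler, DataLoader, RandomSampler

from lightning.pytorch.demos.boring_classes import RandomDataset

dataset = RandomDataset(32, 64)
batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=4, drop_last=False)
dataloader = DataLoader(dataset, batch_sampler=batch_sampler)

# each yielded batch has the user-requested size of 4 (feature dim is 32)
batch = next(iter(dataloader))
assert batch.shape == (4, 32)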
