Skip to content

Commit 7534956

Browse files
cleanup + make work with dataloader refactor
1 parent 1aa3d09 commit 7534956

File tree

2 files changed

+19
-39
lines changed

2 files changed

+19
-39
lines changed

nerfstudio/data/datamanagers/full_images_datamanager.py

Lines changed: 15 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@
4545
from nerfstudio.data.datasets.base_dataset import InputDataset
4646
from nerfstudio.data.utils.data_utils import identity_collate
4747
from nerfstudio.data.utils.dataloaders import ImageBatchStream, _undistort_image
48+
from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate
4849
from nerfstudio.utils.misc import get_orig_class
4950
from nerfstudio.utils.rich_utils import CONSOLE
5051

@@ -150,7 +151,7 @@ def __init__(
150151
assert len(self.train_unseen_cameras) > 0, "No data found in dataset"
151152
super().__init__()
152153

153-
def sample_train_cameras(self):
154+
def sample_train_cameras(self) -> List[int]:
154155
"""Return a list of camera indices sampled using the strategy specified by
155156
self.config.train_cameras_sampling_strategy"""
156157
num_train_cameras = len(self.train_dataset)
@@ -326,7 +327,7 @@ def setup_train(self):
326327
self.train_imagebatch_stream,
327328
batch_size=self.config.batch_size,
328329
num_workers=self.config.dataloader_num_workers,
329-
collate_fn=identity_collate,
330+
collate_fn=nerfstudio_collate,
330331
)
331332
self.iter_train_image_dataloader = iter(self.train_image_dataloader)
332333

@@ -382,53 +383,30 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:
382383
def get_train_rays_per_batch(self) -> int:
383384
"""Returns resolution of the image returned from datamanager."""
384385
camera = self.train_dataset.cameras[0].reshape(())
385-
return int(camera.width[0].item() * camera.height[0].item())
386+
return int(camera.width[0].item() * camera.height[0].item()) * self.config.batch_size
386387

387388
def next_train(self, step: int) -> Tuple[Cameras, Dict]:
388389
"""Returns the next training batch
389390
Returns a Camera instead of raybundle"""
390391

391392
self.train_count += 1
392393
if self.config.cache_images == "disk":
393-
output = next(self.iter_train_image_dataloader)
394-
print("Alex", output)
395-
camera, data = output[0]
396-
return camera, data
394+
cameras, data = next(self.iter_train_image_dataloader)
395+
return cameras, data
397396

398-
image_indices = []
397+
camera_indices = []
399398
for _ in range(self.config.batch_size):
400399
# Make sure to re-populate the unseen cameras list if we have exhausted it
401400
if len(self.train_unseen_cameras) == 0:
402401
self.train_unseen_cameras = self.sample_train_cameras()
403-
image_indices.append(self.train_unseen_cameras.pop(0))
404-
405-
all_keys = self.cached_train[0].keys()
406-
407-
data = {}
408-
for key in all_keys:
409-
if key == "image":
410-
data[key] = torch.stack([self.cached_train[i][key] for i in image_indices]).to(self.device)
411-
else:
412-
data[key] = [self.cached_train[i][key] for i in image_indices]
413-
414-
cameras = Cameras(
415-
camera_to_worlds=self.train_cameras.camera_to_worlds[image_indices],
416-
fx=self.train_cameras.fx[image_indices],
417-
fy=self.train_cameras.fy[image_indices],
418-
cx=self.train_cameras.cx[image_indices],
419-
cy=self.train_cameras.cy[image_indices],
420-
width=self.train_cameras.width[image_indices],
421-
height=self.train_cameras.height[image_indices],
422-
camera_type=self.train_cameras.camera_type[image_indices],
423-
).to(self.device)
424-
425-
if self.train_cameras.distortion_params is not None:
426-
cameras.distortion_params = self.train_cameras.distortion_params[image_indices]
427-
428-
if cameras.metadata is None:
429-
cameras.metadata = {}
430-
431-
cameras.metadata["cam_idx"] = image_indices
402+
camera_indices.append(self.train_unseen_cameras.pop(0))
403+
404+
# NOTE: We're going to copy the data to make sure we don't mutate the cached dictionary.
405+
# This can cause a memory leak: https://github.com/nerfstudio-project/nerfstudio/issues/3335
406+
data = nerfstudio_collate(
407+
[self.cached_train[i].copy() for i in camera_indices]
408+
) # Note that this must happen before indexing cameras, as it can modify the cameras in the dataset during undistortion
409+
cameras = nerfstudio_collate([self.train_dataset.cameras[i : i + 1].to(self.device) for i in camera_indices])
432410

433411
return cameras, data
434412

nerfstudio/data/utils/nerfstudio_collate.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -98,7 +98,7 @@ def nerfstudio_collate(batch: Any, extra_mappings: Union[Dict[type, Callable], N
9898
# If we're in a background process, concatenate directly into a
9999
# shared memory tensor to avoid an extra copy
100100
numel = sum(x.numel() for x in batch)
101-
storage = elem.storage()._new_shared(numel, device=elem.device)
101+
storage = elem.untyped_storage()._new_shared(numel, device=elem.device)
102102
out = elem.new(storage).resize_(len(batch), *list(elem.size()))
103103
return torch.stack(batch, 0, out=out)
104104
elif elem_type.__module__ == "numpy" and elem_type.__name__ not in ("str_", "string_"):
@@ -179,7 +179,9 @@ def nerfstudio_collate(batch: Any, extra_mappings: Union[Dict[type, Callable], N
179179

180180
# Create metadata dictionary
181181
if batch[0].metadata is not None:
182-
metadata = {key: op([cam.metadata[key] for cam in batch], dim=0) for key in batch[0].metadata.keys()}
182+
metadata = {
183+
key: op([torch.tensor([cam.metadata[key]]) for cam in batch], dim=0) for key in batch[0].metadata.keys()
184+
}
183185
else:
184186
metadata = None
185187

0 commit comments

Comments
 (0)