|
45 | 45 | from nerfstudio.data.datasets.base_dataset import InputDataset
|
46 | 46 | from nerfstudio.data.utils.data_utils import identity_collate
|
47 | 47 | from nerfstudio.data.utils.dataloaders import ImageBatchStream, _undistort_image
|
| 48 | +from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate |
48 | 49 | from nerfstudio.utils.misc import get_orig_class
|
49 | 50 | from nerfstudio.utils.rich_utils import CONSOLE
|
50 | 51 |
|
@@ -150,7 +151,7 @@ def __init__(
|
150 | 151 | assert len(self.train_unseen_cameras) > 0, "No data found in dataset"
|
151 | 152 | super().__init__()
|
152 | 153 |
|
153 |
| - def sample_train_cameras(self): |
| 154 | + def sample_train_cameras(self) -> List[int]: |
154 | 155 | """Return a list of camera indices sampled using the strategy specified by
|
155 | 156 | self.config.train_cameras_sampling_strategy"""
|
156 | 157 | num_train_cameras = len(self.train_dataset)
|
@@ -326,7 +327,7 @@ def setup_train(self):
|
326 | 327 | self.train_imagebatch_stream,
|
327 | 328 | batch_size=self.config.batch_size,
|
328 | 329 | num_workers=self.config.dataloader_num_workers,
|
329 |
| - collate_fn=identity_collate, |
| 330 | + collate_fn=nerfstudio_collate, |
330 | 331 | )
|
331 | 332 | self.iter_train_image_dataloader = iter(self.train_image_dataloader)
|
332 | 333 |
|
@@ -382,53 +383,30 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:
|
382 | 383 | def get_train_rays_per_batch(self) -> int:
|
383 | 384 | """Returns resolution of the image returned from datamanager."""
|
384 | 385 | camera = self.train_dataset.cameras[0].reshape(())
|
385 |
| - return int(camera.width[0].item() * camera.height[0].item()) |
| 386 | + return int(camera.width[0].item() * camera.height[0].item()) * self.config.batch_size |
386 | 387 |
|
387 | 388 | def next_train(self, step: int) -> Tuple[Cameras, Dict]:
|
388 | 389 | """Returns the next training batch
|
389 | 390 | Returns a Camera instead of raybundle"""
|
390 | 391 |
|
391 | 392 | self.train_count += 1
|
392 | 393 | if self.config.cache_images == "disk":
|
393 |
| - output = next(self.iter_train_image_dataloader) |
394 |
| - print("Alex", output) |
395 |
| - camera, data = output[0] |
396 |
| - return camera, data |
| 394 | + cameras, data = next(self.iter_train_image_dataloader) |
| 395 | + return cameras, data |
397 | 396 |
|
398 |
| - image_indices = [] |
| 397 | + camera_indices = [] |
399 | 398 | for _ in range(self.config.batch_size):
|
400 | 399 | # Make sure to re-populate the unseen cameras list if we have exhausted it
|
401 | 400 | if len(self.train_unseen_cameras) == 0:
|
402 | 401 | self.train_unseen_cameras = self.sample_train_cameras()
|
403 |
| - image_indices.append(self.train_unseen_cameras.pop(0)) |
404 |
| - |
405 |
| - all_keys = self.cached_train[0].keys() |
406 |
| - |
407 |
| - data = {} |
408 |
| - for key in all_keys: |
409 |
| - if key == "image": |
410 |
| - data[key] = torch.stack([self.cached_train[i][key] for i in image_indices]).to(self.device) |
411 |
| - else: |
412 |
| - data[key] = [self.cached_train[i][key] for i in image_indices] |
413 |
| - |
414 |
| - cameras = Cameras( |
415 |
| - camera_to_worlds=self.train_cameras.camera_to_worlds[image_indices], |
416 |
| - fx=self.train_cameras.fx[image_indices], |
417 |
| - fy=self.train_cameras.fy[image_indices], |
418 |
| - cx=self.train_cameras.cx[image_indices], |
419 |
| - cy=self.train_cameras.cy[image_indices], |
420 |
| - width=self.train_cameras.width[image_indices], |
421 |
| - height=self.train_cameras.height[image_indices], |
422 |
| - camera_type=self.train_cameras.camera_type[image_indices], |
423 |
| - ).to(self.device) |
424 |
| - |
425 |
| - if self.train_cameras.distortion_params is not None: |
426 |
| - cameras.distortion_params = self.train_cameras.distortion_params[image_indices] |
427 |
| - |
428 |
| - if cameras.metadata is None: |
429 |
| - cameras.metadata = {} |
430 |
| - |
431 |
| - cameras.metadata["cam_idx"] = image_indices |
| 402 | + camera_indices.append(self.train_unseen_cameras.pop(0)) |
| 403 | + |
| 404 | + # NOTE: We're going to copy the data to make sure we don't mutate the cached dictionary. |
| 405 | + # This can cause a memory leak: https://github.com/nerfstudio-project/nerfstudio/issues/3335 |
| 406 | + data = nerfstudio_collate( |
| 407 | + [self.cached_train[i].copy() for i in camera_indices] |
| 408 | + ) # Note that this must happen before indexing cameras, as it can modify the cameras in the dataset during undistortion |
| 409 | + cameras = nerfstudio_collate([self.train_dataset.cameras[i : i + 1].to(self.device) for i in camera_indices]) |
432 | 410 |
|
433 | 411 | return cameras, data
|
434 | 412 |
|
|
0 commit comments