diff --git a/nerfstudio/cameras/camera_optimizers.py b/nerfstudio/cameras/camera_optimizers.py
index 7b5f4ccbc8..682283a156 100644
--- a/nerfstudio/cameras/camera_optimizers.py
+++ b/nerfstudio/cameras/camera_optimizers.py
@@ -152,7 +152,7 @@ def apply_to_raybundle(self, raybundle: RayBundle) -> None:
         raybundle.origins = raybundle.origins + correction_matrices[:, :3, 3]
         raybundle.directions = torch.bmm(correction_matrices[:, :3, :3], raybundle.directions[..., None]).squeeze()
 
-    def apply_to_camera(self, camera: Cameras) -> torch.Tensor:
+    def apply_to_camera(self, camera: Cameras) -> Float[Tensor, "b 3 4"]:
         """Apply the pose correction to the world-to-camera matrix in a Camera object"""
         if self.config.mode == "off":
             return camera.camera_to_worlds
diff --git a/nerfstudio/data/datamanagers/full_images_datamanager.py b/nerfstudio/data/datamanagers/full_images_datamanager.py
index 3ec06120cf..67c2c6b3b2 100644
--- a/nerfstudio/data/datamanagers/full_images_datamanager.py
+++ b/nerfstudio/data/datamanagers/full_images_datamanager.py
@@ -45,6 +45,7 @@
 from nerfstudio.data.datasets.base_dataset import InputDataset
 from nerfstudio.data.utils.data_utils import identity_collate
 from nerfstudio.data.utils.dataloaders import ImageBatchStream, _undistort_image
+from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate
 from nerfstudio.utils.misc import get_orig_class
 from nerfstudio.utils.rich_utils import CONSOLE
 
@@ -89,6 +90,8 @@ class FullImageDatamanagerConfig(DataManagerConfig):
     More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"""
     cache_compressed_images: bool = False
     """If True, cache raw image files as byte strings to RAM."""
+    batch_size: int = 1
+    """The batch size for the dataloader."""
 
 
 class FullImageDatamanager(DataManager, Generic[TDataset]):
@@ -148,7 +151,7 @@ def __init__(
         assert len(self.train_unseen_cameras) > 0, "No data found in dataset"
         super().__init__()
 
-    def sample_train_cameras(self):
+    def sample_train_cameras(self) -> List[int]:
         """Return a list of camera indices sampled using the strategy specified by
         self.config.train_cameras_sampling_strategy"""
         num_train_cameras = len(self.train_dataset)
@@ -322,9 +325,9 @@ def setup_train(self):
         )
         self.train_image_dataloader = DataLoader(
             self.train_imagebatch_stream,
-            batch_size=1,
+            batch_size=self.config.batch_size,
             num_workers=self.config.dataloader_num_workers,
-            collate_fn=identity_collate,
+            collate_fn=nerfstudio_collate,
         )
         self.iter_train_image_dataloader = iter(self.train_image_dataloader)
 
@@ -380,33 +383,35 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:
     def get_train_rays_per_batch(self) -> int:
         """Returns resolution of the image returned from datamanager."""
         camera = self.train_dataset.cameras[0].reshape(())
-        return int(camera.width[0].item() * camera.height[0].item())
+        return int(camera.width[0].item() * camera.height[0].item()) * self.config.batch_size
 
     def next_train(self, step: int) -> Tuple[Cameras, Dict]:
         """Returns the next training batch
 
         Returns a Camera instead of raybundle"""
+        self.train_count += 1
         if self.config.cache_images == "disk":
-            camera, data = next(self.iter_train_image_dataloader)[0]
-            return camera, data
+            cameras, data = next(self.iter_train_image_dataloader)
+            return cameras, data
 
-        image_idx = self.train_unseen_cameras.pop(0)
-        # Make sure to re-populate the unseen cameras list if we have exhausted it
-        if len(self.train_unseen_cameras) == 0:
-            self.train_unseen_cameras = self.sample_train_cameras()
+        camera_indices = []
+        for _ in range(self.config.batch_size):
+            # Make sure to re-populate the unseen cameras list if we have exhausted it
+            if len(self.train_unseen_cameras) == 0:
+                self.train_unseen_cameras = self.sample_train_cameras()
+            camera_indices.append(self.train_unseen_cameras.pop(0))
 
-        data = self.cached_train[image_idx]
-        # We're going to copy to make sure we don't mutate the cached dictionary.
+        # NOTE: We're going to copy the data to make sure we don't mutate the cached dictionary.
         # This can cause a memory leak: https://github.com/nerfstudio-project/nerfstudio/issues/3335
-        data = data.copy()
-        data["image"] = data["image"].to(self.device)
-
-        assert len(self.train_cameras.shape) == 1, "Assumes single batch dimension"
-        camera = self.train_cameras[image_idx : image_idx + 1].to(self.device)
-        if camera.metadata is None:
-            camera.metadata = {}
-        camera.metadata["cam_idx"] = image_idx
-        return camera, data
+        data = nerfstudio_collate(
+            [self.cached_train[i].copy() for i in camera_indices]
+        )  # Note that this must happen before indexing cameras, as it can modify the cameras in the dataset during undistortion
+        cameras = nerfstudio_collate([self.train_dataset.cameras[i : i + 1].to(self.device) for i in camera_indices])
+
+        if cameras.metadata is None:
+            cameras.metadata = {}
+        cameras.metadata["cam_idx"] = torch.tensor(camera_indices, device=self.device, dtype=torch.long)
+        return cameras, data
 
     def next_eval(self, step: int) -> Tuple[Cameras, Dict]:
         """Returns the next evaluation batch
diff --git a/nerfstudio/data/utils/nerfstudio_collate.py b/nerfstudio/data/utils/nerfstudio_collate.py
index c2cca9a742..9c57b12537 100644
--- a/nerfstudio/data/utils/nerfstudio_collate.py
+++ b/nerfstudio/data/utils/nerfstudio_collate.py
@@ -98,7 +98,7 @@ def nerfstudio_collate(batch: Any, extra_mappings: Union[Dict[type, Callable], N
             # If we're in a background process, concatenate directly into a
             # shared memory tensor to avoid an extra copy
             numel = sum(x.numel() for x in batch)
-            storage = elem.storage()._new_shared(numel, device=elem.device)
+            storage = elem.untyped_storage()._new_shared(numel, device=str(elem.device))
             out = elem.new(storage).resize_(len(batch), *list(elem.size()))
         return torch.stack(batch, 0, out=out)
     elif elem_type.__module__ == "numpy" and elem_type.__name__ not in ("str_", "string_"):
@@ -179,7 +179,9 @@ def nerfstudio_collate(batch: Any, extra_mappings: Union[Dict[type, Callable], N
 
     # Create metadata dictionary
     if batch[0].metadata is not None:
-        metadata = {key: op([cam.metadata[key] for cam in batch], dim=0) for key in batch[0].metadata.keys()}
+        metadata = {
+            key: op([torch.tensor([cam.metadata[key]]) for cam in batch], dim=0) for key in batch[0].metadata.keys()
+        }
     else:
         metadata = None
 
diff --git a/nerfstudio/models/splatfacto.py b/nerfstudio/models/splatfacto.py
index 136a3168e6..0282b06898 100644
--- a/nerfstudio/models/splatfacto.py
+++ b/nerfstudio/models/splatfacto.py
@@ -46,20 +46,28 @@
 from nerfstudio.utils.spherical_harmonics import RGB2SH, SH2RGB, num_sh_bases
 
 
-def resize_image(image: torch.Tensor, d: int):
+def resize_image(image: torch.Tensor, d: int) -> torch.Tensor:
     """
     Downscale images using the same 'area' method in opencv
 
-    :param image shape [H, W, C]
+    :param image shape [B, H, W, C]
     :param d downscale factor (must be 2, 4, 8, etc.)
 
-    return downscaled image in shape [H//d, W//d, C]
+    return downscaled image in shape [B, H//d, W//d, C]
     """
     import torch.nn.functional as tf
 
-    image = image.to(torch.float32)
     weight = (1.0 / (d * d)) * torch.ones((1, 1, d, d), dtype=torch.float32, device=image.device)
-    return tf.conv2d(image.permute(2, 0, 1)[:, None, ...], weight, stride=d).squeeze(1).permute(1, 2, 0)
+
+    B, H, W, C = image.shape
+    image = image.permute(0, 3, 1, 2)  # [B, C, H, W]
+    image = image.reshape(B * C, 1, H, W)  # Combine batch and channel dimensions for Conv2D
+
+    downscaled = tf.conv2d(image, weight, stride=d)
+    downscaled = downscaled.reshape(B, C, downscaled.shape[-2], downscaled.shape[-1])
+    downscaled = downscaled.permute(0, 2, 3, 1)  # [B, H//d, W//d, C]
+
+    return downscaled
 
 
 @torch_compile()
@@ -465,7 +473,11 @@ def _get_background_color(self):
             raise ValueError(f"Unknown background color {self.config.background_color}")
         return background
 
-    def _apply_bilateral_grid(self, rgb: torch.Tensor, cam_idx: int, H: int, W: int) -> torch.Tensor:
+    def _apply_bilateral_grid(self, rgb: torch.Tensor, cam_idxs: torch.Tensor, H: int, W: int) -> torch.Tensor:
+        """
+        rgb: [B, H, W, 3]
+        cam_idxs: [B]
+        """
         # make xy grid
         grid_y, grid_x = torch.meshgrid(
             torch.linspace(0, 1.0, H, device=self.device),
@@ -473,41 +485,39 @@ def _apply_bilateral_grid(self, rgb: torch.Tensor, cam_idx: int, H: int, W: int)
             indexing="ij",
         )
         grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)
-
         out = slice(
             bil_grids=self.bil_grids,
-            rgb=rgb,
-            xy=grid_xy,
-            grid_idx=torch.tensor(cam_idx, device=self.device, dtype=torch.long),
+            rgb=rgb,  # Process the entire batch in parallel
+            xy=grid_xy.expand(rgb.shape[0], -1, -1, -1),  # Expand grid_xy to match batch size
+            grid_idx=cam_idxs.unsqueeze(-1),
         )
-        return out["rgb"]
+        return out["rgb"]  # Return the processed RGB directly
 
-    def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
-        """Takes in a camera and returns a dictionary of outputs.
+    def get_outputs(self, cameras: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
+        """Takes in cameras and returns a dictionary of outputs.
 
         Args:
-            camera: The camera(s) for which output images are rendered. It should have
+            cameras: The camera(s) for which output images are rendered. It should have
            all the needed information to compute the outputs.
 
         Returns:
             Outputs of model. (ie. rendered colors)
         """
-        if not isinstance(camera, Cameras):
+        if not isinstance(cameras, Cameras):
             print("Called get_outputs with not a camera")
             return {}
 
         if self.training:
-            assert camera.shape[0] == 1, "Only one camera at a time"
-            optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera)
+            optimized_camera_to_world = self.camera_optimizer.apply_to_camera(cameras)
         else:
-            optimized_camera_to_world = camera.camera_to_worlds
+            optimized_camera_to_world = cameras.camera_to_worlds
 
         # cropping
         if self.crop_box is not None and not self.training:
             crop_ids = self.crop_box.within(self.means).squeeze()
             if crop_ids.sum() == 0:
                 return self.get_empty_outputs(
-                    int(camera.width.item()), int(camera.height.item()), self.background_color
+                    int(cameras.width.item()), int(cameras.height.item()), self.background_color
                 )
         else:
             crop_ids = None
@@ -530,12 +540,16 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
             colors_crop = torch.cat((features_dc_crop[:, None, :], features_rest_crop), dim=1)
 
         camera_scale_fac = self._get_downscale_factor()
-        camera.rescale_output_resolution(1 / camera_scale_fac)
-        viewmat = get_viewmat(optimized_camera_to_world)
-        K = camera.get_intrinsics_matrices().cuda()
-        W, H = int(camera.width.item()), int(camera.height.item())
+        cameras.rescale_output_resolution(1 / camera_scale_fac)
+        viewmats = get_viewmat(optimized_camera_to_world)
+        Ks = cameras.get_intrinsics_matrices().cuda()
+
+        W, H = (
+            int(cameras.width[0]),
+            int(cameras.height[0]),
+        )  # assume all cameras have the same resolution
         self.last_size = (H, W)
-        camera.rescale_output_resolution(camera_scale_fac)  # type: ignore
+        cameras.rescale_output_resolution(camera_scale_fac)  # type: ignore
 
         # apply the compensation of screen space blurring to gaussians
         if self.config.rasterize_mode not in ["antialiased", "classic"]:
@@ -558,8 +572,8 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
             scales=torch.exp(scales_crop),
             opacities=torch.sigmoid(opacities_crop).squeeze(-1),
             colors=colors_crop,
-            viewmats=viewmat,  # [1, 4, 4]
-            Ks=K,  # [1, 3, 3]
+            viewmats=viewmats,  # [B, 4, 4]
+            Ks=Ks,  # [B, 3, 3]
             width=W,
             height=H,
             packed=False,
@@ -585,24 +599,30 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
 
         # apply bilateral grid
         if self.config.use_bilateral_grid and self.training:
-            if camera.metadata is not None and "cam_idx" in camera.metadata:
-                rgb = self._apply_bilateral_grid(rgb, camera.metadata["cam_idx"], H, W)
+            if cameras.metadata is not None and "cam_idx" in cameras.metadata:
+                rgb = self._apply_bilateral_grid(rgb, cameras.metadata["cam_idx"], H, W)
+            else:
+                raise ValueError("Camera index not found in metadata, bilateral grid cannot be applied.")
 
         if render_mode == "RGB+ED":
             depth_im = render[:, ..., 3:4]
-            depth_im = torch.where(alpha > 0, depth_im, depth_im.detach().max()).squeeze(0)
+            depth_im = torch.where(alpha > 0, depth_im, depth_im.detach().max())
         else:
             depth_im = None
 
         if background.shape[0] == 3 and not self.training:
             background = background.expand(H, W, 3)
 
-        return {
-            "rgb": rgb.squeeze(0),  # type: ignore
-            "depth": depth_im,  # type: ignore
-            "accumulation": alpha.squeeze(0),  # type: ignore
-            "background": background,  # type: ignore
-        }  # type: ignore
+        outputs = {
+            "rgb": rgb,
+            "depth": depth_im,
+            "accumulation": alpha,
+            "background": background,
+        }
+
+        if self.training:
+            return outputs
+        return {k: v.squeeze(0) if k != "background" else v for k, v in outputs.items()}
 
     def get_gt_img(self, image: torch.Tensor):
         """Compute groundtruth image with iteration dependent downscale factor for evaluation purpose
@@ -622,8 +642,8 @@ def composite_with_background(self, image, background) -> torch.Tensor:
             image: the image to composite
             background: the background color
         """
-        if image.shape[2] == 4:
-            alpha = image[..., -1].unsqueeze(-1).repeat((1, 1, 3))
+        if image.shape[-1] == 4:
+            alpha = image[..., -1].unsqueeze(-1).repeat((1, 1, 1, 3))
             return alpha * image[..., :3] + (1 - alpha) * background
         else:
             return image
@@ -671,7 +691,7 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
             pred_img = pred_img * mask
 
         Ll1 = torch.abs(gt_img - pred_img).mean()
-        simloss = 1 - self.ssim(gt_img.permute(2, 0, 1)[None, ...], pred_img.permute(2, 0, 1)[None, ...])
+        simloss = 1 - self.ssim(gt_img.permute(0, 3, 1, 2), pred_img.permute(0, 3, 1, 2))
         if self.config.use_scale_regularization and self.step % 10 == 0:
             scale_exp = torch.exp(self.scales)
             scale_reg = (
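
Standalone sanity check for the batched resize_image rewrite above (not part of the patch; the helper and variable names below are illustrative only). Folding the batch and channel dimensions into a single conv2d call with a (1/(d*d)) box kernel should reproduce per-image 'area' downscaling exactly, which avg_pool2d gives us as a reference:

import torch
import torch.nn.functional as F

def resize_image_batched(image: torch.Tensor, d: int) -> torch.Tensor:
    # Mirrors the patched resize_image: [B, H, W, C] -> [B, H//d, W//d, C] via an area (box) filter.
    weight = (1.0 / (d * d)) * torch.ones((1, 1, d, d), dtype=torch.float32, device=image.device)
    b, h, w, c = image.shape
    x = image.permute(0, 3, 1, 2).reshape(b * c, 1, h, w)  # fold batch and channels into one conv2d call
    out = F.conv2d(x, weight, stride=d)
    return out.reshape(b, c, out.shape[-2], out.shape[-1]).permute(0, 2, 3, 1)

batch = torch.rand(4, 64, 96, 3)
down = resize_image_batched(batch, d=2)
assert down.shape == (4, 32, 48, 3)
# avg_pool2d is the same 'area' operation applied per image, so batching must not change the result.
reference = F.avg_pool2d(batch.permute(0, 3, 1, 2), kernel_size=2, stride=2).permute(0, 2, 3, 1)
assert torch.allclose(down, reference, atol=1e-6)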
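
The nerfstudio_collate change wraps each scalar metadata value in a one-element tensor before applying op; when the cameras arrive with a leading batch dimension of 1 (as next_train slices them), op resolves to torch.cat, so the batched Cameras object carries cam_idx as a [B] tensor rather than a Python int. A minimal illustration with hypothetical index values, outside the Cameras class itself:

import torch

# Hypothetical per-camera metadata dicts, as attached upstream before collation.
per_camera_metadata = [{"cam_idx": 3}, {"cam_idx": 7}]

# Each scalar is wrapped in a 1-element tensor and concatenated along dim 0,
# so downstream code (e.g. _apply_bilateral_grid) receives cam_idx as a [B] tensor.
cam_idx = torch.cat([torch.tensor([m["cam_idx"]]) for m in per_camera_metadata], dim=0)
assert cam_idx.shape == (2,) and cam_idx.tolist() == [3, 7]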
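
A hedged end-to-end sketch of how the new batch_size option is meant to be exercised; the trainer/pipeline wiring is omitted, the setup and next_train calls below are assumptions rather than part of this diff, and only the fields shown are ones this patch touches:

from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig

# batch_size is the field added by this patch; cache_images="disk" routes through the
# DataLoader + nerfstudio_collate path changed in setup_train.
config = FullImageDatamanagerConfig(batch_size=4, cache_images="disk")

# datamanager = config.setup(device="cuda")       # assumed: standard config setup with dataparser defaults
# cameras, data = datamanager.next_train(step=0)  # expected: cameras batched with shape (4,), data["image"] of shape [4, H, W, 3]
# outputs = model.get_outputs(cameras)            # splatfacto then rasterizes all four views in one gsplat call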