
Commit abff91c

Allow for >1 batch size in splatfacto
1 parent 189328e commit abff91c

3 files changed: +86 -49 lines


nerfstudio/cameras/camera_optimizers.py

Lines changed: 1 addition & 1 deletion
@@ -152,7 +152,7 @@ def apply_to_raybundle(self, raybundle: RayBundle) -> None:
         raybundle.origins = raybundle.origins + correction_matrices[:, :3, 3]
         raybundle.directions = torch.bmm(correction_matrices[:, :3, :3], raybundle.directions[..., None]).squeeze()
 
-    def apply_to_camera(self, camera: Cameras) -> torch.Tensor:
+    def apply_to_camera(self, camera: Cameras) -> Float[Tensor, "b 3 4"]:
         """Apply the pose correction to the world-to-camera matrix in a Camera object"""
         if self.config.mode == "off":
             return camera.camera_to_worlds
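The new Float[Tensor, "b 3 4"] annotation makes the batched contract explicit: the optimizer hands back one corrected pose per camera. Below is a minimal sketch of composing per-camera corrections with a batch of [b, 3, 4] poses, mirroring the apply_to_raybundle pattern shown in the context above; the tensors are illustrative placeholders, not the nerfstudio implementation.

import torch

b = 4
# Illustrative [b, 3, 4] camera-to-world poses and per-camera corrections (identity here).
camera_to_worlds = torch.eye(4).expand(b, 4, 4).clone()[:, :3, :]
correction_matrices = torch.eye(4).expand(b, 4, 4).clone()[:, :3, :]

# Rotations compose via bmm, translations add, as apply_to_raybundle does for rays.
corrected_rot = torch.bmm(correction_matrices[:, :3, :3], camera_to_worlds[:, :3, :3])
corrected_trans = camera_to_worlds[:, :3, 3] + correction_matrices[:, :3, 3]
corrected = torch.cat([corrected_rot, corrected_trans[..., None]], dim=-1)
assert corrected.shape == (b, 3, 4)  # matches the Float[Tensor, "b 3 4"] annotation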

nerfstudio/data/datamanagers/full_images_datamanager.py

Lines changed: 39 additions & 17 deletions
@@ -79,6 +79,8 @@ class FullImageDatamanagerConfig(DataManagerConfig):
     fps_reset_every: int = 100
     """The number of iterations before one resets fps sampler repeatly, which is essentially drawing fps_reset_every
     samples from the pool of all training cameras without replacement before a new round of sampling starts."""
+    batch_size: int = 1
+    """The batch size for the dataloader."""
 
 
 class FullImageDatamanager(DataManager, Generic[TDataset]):
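For context, a minimal sketch of setting the new field from Python; everything except batch_size itself is standard config plumbing rather than part of this commit.

from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig

# Ask the datamanager for 4 full images (and their cameras) per training step.
datamanager_config = FullImageDatamanagerConfig(batch_size=4)

Through nerfstudio's tyro-based CLI this should also surface as a --pipeline.datamanager.batch-size flag, but that flag name is an inference from the config plumbing, not something shown in this diff.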
@@ -336,31 +338,51 @@ def get_train_rays_per_batch(self):
         if len(self.cached_train) != 0:
             h = self.cached_train[0]["image"].shape[0]
             w = self.cached_train[0]["image"].shape[1]
-            return h * w
+            return h * w * self.config.batch_size
         else:
             return 800 * 800
 
     def next_train(self, step: int) -> Tuple[Cameras, Dict]:
         """Returns the next training batch
 
         Returns a Camera instead of raybundle"""
-        image_idx = self.train_unseen_cameras.pop(0)
-        # Make sure to re-populate the unseen cameras list if we have exhausted it
-        if len(self.train_unseen_cameras) == 0:
-            self.train_unseen_cameras = self.sample_train_cameras()
 
-        data = self.cached_train[image_idx]
-        # We're going to copy to make sure we don't mutate the cached dictionary.
-        # This can cause a memory leak: https://github.com/nerfstudio-project/nerfstudio/issues/3335
-        data = data.copy()
-        data["image"] = data["image"].to(self.device)
-
-        assert len(self.train_cameras.shape) == 1, "Assumes single batch dimension"
-        camera = self.train_cameras[image_idx : image_idx + 1].to(self.device)
-        if camera.metadata is None:
-            camera.metadata = {}
-        camera.metadata["cam_idx"] = image_idx
-        return camera, data
+        image_indices = []
+        for _ in range(self.config.batch_size):
+            # Make sure to re-populate the unseen cameras list if we have exhausted it
+            if len(self.train_unseen_cameras) == 0:
+                self.train_unseen_cameras = self.sample_train_cameras()
+            image_indices.append(self.train_unseen_cameras.pop(0))
+
+        all_keys = self.cached_train[0].keys()
+
+        data = {}
+        for key in all_keys:
+            if key == "image":
+                data[key] = torch.stack([self.cached_train[i][key] for i in image_indices]).to(self.device)
+            else:
+                data[key] = [self.cached_train[i][key] for i in image_indices]
+
+        cameras = Cameras(
+            camera_to_worlds=self.train_cameras.camera_to_worlds[image_indices],
+            fx=self.train_cameras.fx[image_indices],
+            fy=self.train_cameras.fy[image_indices],
+            cx=self.train_cameras.cx[image_indices],
+            cy=self.train_cameras.cy[image_indices],
+            width=self.train_cameras.width[image_indices],
+            height=self.train_cameras.height[image_indices],
+            camera_type=self.train_cameras.camera_type[image_indices],
+        ).to(self.device)
+
+        if self.train_cameras.distortion_params is not None:
+            cameras.distortion_params = self.train_cameras.distortion_params[image_indices]
+
+        if cameras.metadata is None:
+            cameras.metadata = {}
+
+        cameras.metadata["cam_idx"] = image_indices
+
+        return cameras, data
 
     def next_eval(self, step: int) -> Tuple[Cameras, Dict]:
         """Returns the next evaluation batch

nerfstudio/models/splatfacto.py

Lines changed: 46 additions & 31 deletions
@@ -46,20 +46,28 @@
 from nerfstudio.utils.spherical_harmonics import RGB2SH, SH2RGB, num_sh_bases
 
 
-def resize_image(image: torch.Tensor, d: int):
+def resize_image(image: torch.Tensor, d: int) -> torch.Tensor:
     """
     Downscale images using the same 'area' method in opencv
 
-    :param image shape [H, W, C]
+    :param image shape [B, H, W, C]
     :param d downscale factor (must be 2, 4, 8, etc.)
 
-    return downscaled image in shape [H//d, W//d, C]
+    return downscaled image in shape [B, H//d, W//d, C]
     """
     import torch.nn.functional as tf
 
-    image = image.to(torch.float32)
     weight = (1.0 / (d * d)) * torch.ones((1, 1, d, d), dtype=torch.float32, device=image.device)
-    return tf.conv2d(image.permute(2, 0, 1)[:, None, ...], weight, stride=d).squeeze(1).permute(1, 2, 0)
+
+    B, H, W, C = image.shape
+    image = image.permute(0, 3, 1, 2)  # [B, C, H, W]
+    image = image.reshape(B * C, 1, H, W)  # Combine batch and channel dimensions for Conv2D
+
+    downscaled = tf.conv2d(image, weight, stride=d)
+    downscaled = downscaled.reshape(B, C, downscaled.shape[-2], downscaled.shape[-1])
+    downscaled = downscaled.permute(0, 2, 3, 1)  # [B, H//d, W//d, C]
+
+    return downscaled
 
 
 @torch_compile()
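Because a uniform d x d kernel with stride d is plain area averaging, the batched path can be sanity-checked against avg_pool2d. The standalone sketch below reimplements the same reshuffle (it is not an import of the patched function, and float inputs are assumed):

import torch
import torch.nn.functional as F

def batched_area_downscale(image: torch.Tensor, d: int) -> torch.Tensor:
    # Same idea as the patched resize_image: fold batch and channels together,
    # run one single-channel conv with a uniform d x d kernel, then restore shape.
    B, H, W, C = image.shape
    weight = torch.full((1, 1, d, d), 1.0 / (d * d), dtype=image.dtype, device=image.device)
    x = image.permute(0, 3, 1, 2).reshape(B * C, 1, H, W)
    y = F.conv2d(x, weight, stride=d)
    return y.reshape(B, C, H // d, W // d).permute(0, 2, 3, 1)

imgs = torch.rand(4, 64, 96, 3)
ref = F.avg_pool2d(imgs.permute(0, 3, 1, 2), kernel_size=2).permute(0, 2, 3, 1)
assert torch.allclose(batched_area_downscale(imgs, 2), ref, atol=1e-5)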
@@ -482,32 +490,31 @@ def _apply_bilateral_grid(self, rgb: torch.Tensor, cam_idx: int, H: int, W: int)
         )
         return out["rgb"]
 
-    def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
-        """Takes in a camera and returns a dictionary of outputs.
+    def get_outputs(self, cameras: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
+        """Takes in cameras and returns a dictionary of outputs.
 
         Args:
-            camera: The camera(s) for which output images are rendered. It should have
+            cameras: The camera(s) for which output images are rendered. It should have
             all the needed information to compute the outputs.
 
         Returns:
             Outputs of model. (ie. rendered colors)
         """
-        if not isinstance(camera, Cameras):
+        if not isinstance(cameras, Cameras):
             print("Called get_outputs with not a camera")
             return {}
 
         if self.training:
-            assert camera.shape[0] == 1, "Only one camera at a time"
-            optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera)
+            optimized_camera_to_world = self.camera_optimizer.apply_to_camera(cameras)
         else:
-            optimized_camera_to_world = camera.camera_to_worlds
+            optimized_camera_to_world = cameras.camera_to_worlds
 
         # cropping
         if self.crop_box is not None and not self.training:
             crop_ids = self.crop_box.within(self.means).squeeze()
             if crop_ids.sum() == 0:
                 return self.get_empty_outputs(
-                    int(camera.width.item()), int(camera.height.item()), self.background_color
+                    int(cameras.width.item()), int(cameras.height.item()), self.background_color
                 )
         else:
             crop_ids = None
@@ -530,12 +537,16 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
         colors_crop = torch.cat((features_dc_crop[:, None, :], features_rest_crop), dim=1)
 
         camera_scale_fac = self._get_downscale_factor()
-        camera.rescale_output_resolution(1 / camera_scale_fac)
-        viewmat = get_viewmat(optimized_camera_to_world)
-        K = camera.get_intrinsics_matrices().cuda()
-        W, H = int(camera.width.item()), int(camera.height.item())
+        cameras.rescale_output_resolution(1 / camera_scale_fac)
+        viewmats = get_viewmat(optimized_camera_to_world)
+        Ks = cameras.get_intrinsics_matrices().cuda()
+
+        W, H = (
+            int(cameras.width[0]),
+            int(cameras.height[0]),
+        )  # assume all cameras have the same resolution
         self.last_size = (H, W)
-        camera.rescale_output_resolution(camera_scale_fac)  # type: ignore
+        cameras.rescale_output_resolution(camera_scale_fac)  # type: ignore
 
         # apply the compensation of screen space blurring to gaussians
         if self.config.rasterize_mode not in ["antialiased", "classic"]:
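get_viewmat now receives the whole [b, 3, 4] batch and must return one [4, 4] world-to-camera matrix per camera. A generic sketch of that batched inversion follows; any axis-convention flips nerfstudio applies for gsplat are deliberately omitted, so this illustrates the shapes rather than the project's exact implementation.

import torch

def c2w_to_w2c(camera_to_worlds: torch.Tensor) -> torch.Tensor:
    # Invert each [3, 4] camera-to-world pose via the R^T / -R^T t identity.
    R = camera_to_worlds[:, :3, :3]   # [b, 3, 3]
    t = camera_to_worlds[:, :3, 3:4]  # [b, 3, 1]
    R_inv = R.transpose(1, 2)         # rotation inverse is its transpose
    t_inv = -torch.bmm(R_inv, t)
    viewmats = torch.zeros(camera_to_worlds.shape[0], 4, 4, dtype=R.dtype, device=R.device)
    viewmats[:, :3, :3] = R_inv
    viewmats[:, :3, 3:4] = t_inv
    viewmats[:, 3, 3] = 1.0
    return viewmats

poses = torch.eye(4)[:3].expand(4, 3, 4).clone()  # [b, 3, 4] identity poses
assert c2w_to_w2c(poses).shape == (4, 4, 4)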
@@ -558,8 +569,8 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
             scales=torch.exp(scales_crop),
             opacities=torch.sigmoid(opacities_crop).squeeze(-1),
             colors=colors_crop,
-            viewmats=viewmat,  # [1, 4, 4]
-            Ks=K,  # [1, 3, 3]
+            viewmats=viewmats,  # [1, 4, 4]
+            Ks=Ks,  # [1, 3, 3]
             width=W,
             height=H,
             packed=False,
@@ -585,24 +596,28 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
 
         # apply bilateral grid
         if self.config.use_bilateral_grid and self.training:
-            if camera.metadata is not None and "cam_idx" in camera.metadata:
-                rgb = self._apply_bilateral_grid(rgb, camera.metadata["cam_idx"], H, W)
+            if cameras.metadata is not None and "cam_idx" in cameras.metadata:
+                rgb = self._apply_bilateral_grid(rgb, cameras.metadata["cam_idx"], H, W)
 
         if render_mode == "RGB+ED":
             depth_im = render[:, ..., 3:4]
-            depth_im = torch.where(alpha > 0, depth_im, depth_im.detach().max()).squeeze(0)
+            depth_im = torch.where(alpha > 0, depth_im, depth_im.detach().max())
         else:
             depth_im = None
 
         if background.shape[0] == 3 and not self.training:
             background = background.expand(H, W, 3)
 
-        return {
-            "rgb": rgb.squeeze(0),  # type: ignore
-            "depth": depth_im,  # type: ignore
-            "accumulation": alpha.squeeze(0),  # type: ignore
-            "background": background,  # type: ignore
-        }  # type: ignore
+        outputs = {
+            "rgb": rgb,
+            "depth": depth_im,
+            "accumulation": alpha,
+            "background": background,
+        }
+
+        if self.training:
+            return outputs
+        return {k: v.squeeze(0) if k != "background" else v for k, v in outputs.items()}
 
     def get_gt_img(self, image: torch.Tensor):
         """Compute groundtruth image with iteration dependent downscale factor for evaluation purpose
@@ -622,7 +637,7 @@ def composite_with_background(self, image, background) -> torch.Tensor:
             image: the image to composite
             background: the background color
         """
-        if image.shape[2] == 4:
+        if image.shape[-1] == 4:
             alpha = image[..., -1].unsqueeze(-1).repeat((1, 1, 3))
             return alpha * image[..., :3] + (1 - alpha) * background
         else:
@@ -671,7 +686,7 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
             pred_img = pred_img * mask
 
         Ll1 = torch.abs(gt_img - pred_img).mean()
-        simloss = 1 - self.ssim(gt_img.permute(2, 0, 1)[None, ...], pred_img.permute(2, 0, 1)[None, ...])
+        simloss = 1 - self.ssim(gt_img.permute(0, 3, 1, 2), pred_img.permute(0, 3, 1, 2))
         if self.config.use_scale_regularization and self.step % 10 == 0:
             scale_exp = torch.exp(self.scales)
             scale_reg = (
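With gt_img and pred_img now laid out as [B, H, W, C], the SSIM term permutes to channels-first and runs the whole batch in one call instead of adding a [None, ...] batch dim per image. A sketch assuming a torchmetrics-style SSIM module, which is what self.ssim is presumed to be here; sizes are illustrative:

import torch
from torchmetrics.image import StructuralSimilarityIndexMeasure

ssim = StructuralSimilarityIndexMeasure(data_range=1.0)

# Batch of ground-truth and predicted renders in [B, H, W, C] layout.
gt_img = torch.rand(4, 32, 48, 3)
pred_img = torch.rand(4, 32, 48, 3)

# Channels-first for the metric; no extra batch dim needed since one already exists.
Ll1 = torch.abs(gt_img - pred_img).mean()
simloss = 1 - ssim(gt_img.permute(0, 3, 1, 2), pred_img.permute(0, 3, 1, 2))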
