 from nerfstudio.utils.spherical_harmonics import RGB2SH, SH2RGB, num_sh_bases


-def resize_image(image: torch.Tensor, d: int):
+def resize_image(image: torch.Tensor, d: int) -> torch.Tensor:
     """
     Downscale images using the same 'area' method in opencv

-    :param image shape [H, W, C]
+    :param image shape [B, H, W, C]
     :param d downscale factor (must be 2, 4, 8, etc.)

-    return downscaled image in shape [H//d, W//d, C]
+    return downscaled image in shape [B, H//d, W//d, C]
     """
     import torch.nn.functional as tf

-    image = image.to(torch.float32)
     weight = (1.0 / (d * d)) * torch.ones((1, 1, d, d), dtype=torch.float32, device=image.device)
-    return tf.conv2d(image.permute(2, 0, 1)[:, None, ...], weight, stride=d).squeeze(1).permute(1, 2, 0)
+
+    B, H, W, C = image.shape
+    image = image.permute(0, 3, 1, 2)  # [B, C, H, W]
+    image = image.reshape(B * C, 1, H, W)  # Combine batch and channel dimensions for Conv2D
+
+    downscaled = tf.conv2d(image, weight, stride=d)
+    downscaled = downscaled.reshape(B, C, downscaled.shape[-2], downscaled.shape[-1])
+    downscaled = downscaled.permute(0, 2, 3, 1)  # [B, H//d, W//d, C]
+
+    return downscaled


 @torch_compile()
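A quick, hypothetical sanity check for the batched `resize_image` (shapes follow the new `[B, H, W, C]` convention; the constant-image assertion is just an illustration, not part of the diff):

```python
import torch

# Two 8x8 RGB images, downscaled by d=2 with the 'area' (block-average) kernel.
images = torch.ones((2, 8, 8, 3), dtype=torch.float32)  # [B, H, W, C]
small = resize_image(images, d=2)                        # [B, H//2, W//2, C]
assert small.shape == (2, 4, 4, 3)
# Averaging 2x2 blocks of a constant image leaves the values unchanged.
assert torch.allclose(small, torch.ones_like(small))
```

Note that the explicit cast to float32 was dropped from the function, so callers are expected to pass floating-point images.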
@@ -482,32 +490,31 @@ def _apply_bilateral_grid(self, rgb: torch.Tensor, cam_idx: int, H: int, W: int)
         )
         return out["rgb"]

-    def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
-        """Takes in a camera and returns a dictionary of outputs.
+    def get_outputs(self, cameras: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
+        """Takes in cameras and returns a dictionary of outputs.

         Args:
-            camera: The camera(s) for which output images are rendered. It should have
+            cameras: The camera(s) for which output images are rendered. It should have
             all the needed information to compute the outputs.

         Returns:
             Outputs of model. (ie. rendered colors)
         """
-        if not isinstance(camera, Cameras):
+        if not isinstance(cameras, Cameras):
             print("Called get_outputs with not a camera")
             return {}

         if self.training:
-            assert camera.shape[0] == 1, "Only one camera at a time"
-            optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera)
+            optimized_camera_to_world = self.camera_optimizer.apply_to_camera(cameras)
         else:
-            optimized_camera_to_world = camera.camera_to_worlds
+            optimized_camera_to_world = cameras.camera_to_worlds

         # cropping
         if self.crop_box is not None and not self.training:
             crop_ids = self.crop_box.within(self.means).squeeze()
             if crop_ids.sum() == 0:
                 return self.get_empty_outputs(
-                    int(camera.width.item()), int(camera.height.item()), self.background_color
+                    int(cameras.width.item()), int(cameras.height.item()), self.background_color
                 )
         else:
             crop_ids = None
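With the single-camera assert gone, `get_outputs` can be handed a batched `Cameras` object. A rough sketch of what such a batch might look like (the intrinsics, resolution, and the `model` handle below are illustrative assumptions, not taken from the diff):

```python
import torch
from nerfstudio.cameras.cameras import Cameras, CameraType

# Two hypothetical cameras sharing the same intrinsics and resolution.
c2w = torch.eye(4)[:3].unsqueeze(0).repeat(2, 1, 1)  # [B, 3, 4] camera-to-world matrices
cameras = Cameras(
    camera_to_worlds=c2w,
    fx=500.0,
    fy=500.0,
    cx=320.0,
    cy=240.0,
    width=640,
    height=480,
    camera_type=CameraType.PERSPECTIVE,
)
# outputs = model.get_outputs(cameras)  # `model`: an existing SplatfactoModel instance (assumed)
```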
@@ -530,12 +537,16 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
         colors_crop = torch.cat((features_dc_crop[:, None, :], features_rest_crop), dim=1)

         camera_scale_fac = self._get_downscale_factor()
-        camera.rescale_output_resolution(1 / camera_scale_fac)
-        viewmat = get_viewmat(optimized_camera_to_world)
-        K = camera.get_intrinsics_matrices().cuda()
-        W, H = int(camera.width.item()), int(camera.height.item())
+        cameras.rescale_output_resolution(1 / camera_scale_fac)
+        viewmats = get_viewmat(optimized_camera_to_world)
+        Ks = cameras.get_intrinsics_matrices().cuda()
+
+        W, H = (
+            int(cameras.width[0]),
+            int(cameras.height[0]),
+        )  # assume all cameras have the same resolution
         self.last_size = (H, W)
-        camera.rescale_output_resolution(camera_scale_fac)  # type: ignore
+        cameras.rescale_output_resolution(camera_scale_fac)  # type: ignore

         # apply the compensation of screen space blurring to gaussians
         if self.config.rasterize_mode not in ["antialiased", "classic"]:
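For orientation, a small sketch of the batched shapes produced at this step; the random tensors are stand-ins for the camera-optimizer output and `cameras.get_intrinsics_matrices()`, and `get_viewmat` is the helper defined earlier in this module:

```python
import torch

B = 4  # hypothetical number of cameras rendered in one call
optimized_camera_to_world = torch.rand(B, 3, 4)    # stand-in for the camera optimizer output
viewmats = get_viewmat(optimized_camera_to_world)  # [B, 4, 4] world-to-camera matrices
Ks = torch.rand(B, 3, 3)                           # stand-in for cameras.get_intrinsics_matrices()
assert viewmats.shape == (B, 4, 4) and Ks.shape == (B, 3, 3)
```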
@@ -558,8 +569,8 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
             scales=torch.exp(scales_crop),
             opacities=torch.sigmoid(opacities_crop).squeeze(-1),
             colors=colors_crop,
-            viewmats=viewmat,  # [1, 4, 4]
-            Ks=K,  # [1, 3, 3]
+            viewmats=viewmats,  # [1, 4, 4]
+            Ks=Ks,  # [1, 3, 3]
             width=W,
             height=H,
             packed=False,
@@ -585,24 +596,28 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
         # apply bilateral grid
         if self.config.use_bilateral_grid and self.training:
-            if camera.metadata is not None and "cam_idx" in camera.metadata:
-                rgb = self._apply_bilateral_grid(rgb, camera.metadata["cam_idx"], H, W)
+            if cameras.metadata is not None and "cam_idx" in cameras.metadata:
+                rgb = self._apply_bilateral_grid(rgb, cameras.metadata["cam_idx"], H, W)

         if render_mode == "RGB+ED":
             depth_im = render[:, ..., 3:4]
-            depth_im = torch.where(alpha > 0, depth_im, depth_im.detach().max()).squeeze(0)
+            depth_im = torch.where(alpha > 0, depth_im, depth_im.detach().max())
         else:
             depth_im = None

         if background.shape[0] == 3 and not self.training:
             background = background.expand(H, W, 3)

-        return {
-            "rgb": rgb.squeeze(0),  # type: ignore
-            "depth": depth_im,  # type: ignore
-            "accumulation": alpha.squeeze(0),  # type: ignore
-            "background": background,  # type: ignore
-        }  # type: ignore
+        outputs = {
+            "rgb": rgb,
+            "depth": depth_im,
+            "accumulation": alpha,
+            "background": background,
+        }
+
+        if self.training:
+            return outputs
+        return {k: v.squeeze(0) if k != "background" else v for k, v in outputs.items()}

     def get_gt_img(self, image: torch.Tensor):
         """Compute groundtruth image with iteration dependent downscale factor for evaluation purpose
@@ -622,7 +637,7 @@ def composite_with_background(self, image, background) -> torch.Tensor:
             image: the image to composite
             background: the background color
         """
-        if image.shape[2] == 4:
+        if image.shape[-1] == 4:
             alpha = image[..., -1].unsqueeze(-1).repeat((1, 1, 3))
             return alpha * image[..., :3] + (1 - alpha) * background
         else:
@@ -671,7 +686,7 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
             pred_img = pred_img * mask

         Ll1 = torch.abs(gt_img - pred_img).mean()
-        simloss = 1 - self.ssim(gt_img.permute(2, 0, 1)[None, ...], pred_img.permute(2, 0, 1)[None, ...])
+        simloss = 1 - self.ssim(gt_img.permute(0, 3, 1, 2), pred_img.permute(0, 3, 1, 2))
         if self.config.use_scale_regularization and self.step % 10 == 0:
             scale_exp = torch.exp(self.scales)
             scale_reg = (
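Assuming `self.ssim` is the torchmetrics SSIM metric nerfstudio typically uses, a minimal sketch of the batched call, with the `[B, H, W, C]` inputs permuted to the channel-first layout the metric expects:

```python
import torch
from torchmetrics.image import StructuralSimilarityIndexMeasure

ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
gt_img = torch.rand(2, 64, 64, 3)    # hypothetical [B, H, W, C] ground-truth batch
pred_img = torch.rand(2, 64, 64, 3)  # hypothetical [B, H, W, C] rendered batch
# Permute to [B, C, H, W] before computing SSIM; the loss term is 1 - SSIM.
simloss = 1 - ssim(gt_img.permute(0, 3, 1, 2), pred_img.permute(0, 3, 1, 2))
```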