diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..7d626da
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,34 @@
+FROM ghcr.io/nerfstudio-project/nerfstudio
+
+# Install system dependencies
+RUN apt-get update && \
+    apt-get install -y git gcc-10 g++-10 nvidia-cuda-toolkit ninja-build python3-pip wget && \
+    rm -rf /var/lib/apt/lists/*
+
+# Set environment variables to use gcc-10
+ENV CC=gcc-10
+ENV CXX=g++-10
+ENV CUDA_HOME="/usr/local/cuda"
+ENV TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6+PTX"
+
+# Clone and install QED-Splatter (CMAKE_PREFIX_PATH is resolved at build time; ENV does not expand $(...))
+RUN git clone https://github.com/leggedrobotics/qed-splatter.git /apps/qed-splatter
+RUN CMAKE_PREFIX_PATH="$(python -c 'import torch; print(torch.utils.cmake_prefix_path)')" \
+    pip install /apps/qed-splatter
+
+# Register QED-Splatter with nerfstudio
+RUN ns-install-cli
+RUN mkdir -p /usr/local/cuda/bin && ln -s /usr/bin/nvcc /usr/local/cuda/bin/nvcc
+
+# Install additional Python dependencies
+RUN pip install git+https://github.com/rmbrualla/pycolmap@cc7ea4b7301720ac29287dbe450952511b32125e
+RUN pip install git+https://github.com/rahul-goel/fused-ssim@1272e21a282342e89537159e4bad508b19b34157
+RUN pip install nerfview pyntcloud
+
+# Pre-download AlexNet pretrained weights to prevent runtime downloading
+RUN mkdir -p /root/.cache/torch/hub/checkpoints && \
+    wget https://download.pytorch.org/models/alexnet-owt-7be5be79.pth \
+    -O /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
+
+# Set working directory
+WORKDIR /workspace
diff --git a/README.md b/README.md
index d72cafe..b862c65 100644
--- a/README.md
+++ b/README.md
@@ -32,3 +32,121 @@ To train the new method, use the following command:
 ```
 ns-train qed-splatter --data [PATH]
 ```
+
+## Pruning Extension
+
+The pruning extension provides tools to reduce the number of Gaussians in order to improve rendering speed. There are two main types of pruners available:
+
+- **Soft pruners** gradually reduce the number of Gaussians during training.
+- **Hard pruners** are post-processing tools applied after training is complete.
+
+Each pruner computes a *pruning score* to evaluate the importance of individual Gaussians. The least important Gaussians are then removed. (A minimal sketch of this step is included at the end of this README.)
+
+Currently, two hard pruning scripts are available: `rgb_hard_pruner` and `depth_hard_pruner`.
+
+### RGB_hard_pruner
+This pruner uses the RGB loss to compute the pruning score for hard pruning.
+```
+python3 RGB_hard_pruner.py default --data-dir datasets/park --ckpt results/park/step-000029999.ckpt --pruning-ratio 0.1 --result-dir output
+
+--eval-only (only evaluates, no saving, no pruning)
+--pruning-ratio 0.0 (no pruning, saved in new format)
+--output-format (ply (default), ckpt (nerfstudio), pt (gsplat))
+```
+
+## πŸ“₯ Required Arguments
+
+| Argument          | Description |
+|-------------------|-----------------------------------------------------------------------------|
+| `default`         | Specifies the run configuration. |
+| `--data-dir`      | Path to the directory containing `transforms.json` (camera poses and intrinsics) and the images (RGB, depth). |
+| `--ckpt`          | Path to the pretrained model checkpoint (e.g., `results/park/step-XXXXX.ckpt`). |
+| `--pruning-ratio` | Float between `0.0` and `1.0`. Proportion of the model to prune. Example: `0.1` = keep 90%. |
+| `--result-dir`    | Directory where the output (pruned model) will be saved. |
+
+
+## Input Format
+
+The scripts accept checkpoints in several formats. The format is detected automatically from the file extension of `--ckpt`, and each format expects a particular transforms layout in `--data-dir`:
+- `ply` : expects a Nerfstudio format for the transforms.
+- `ckpt` : expects a Nerfstudio format for the transforms.
+- `pt` : expects a gsplat format for the transforms.
+
+### GSPlat Dataset Format
+
+For the gsplat (`pt`) format, the dataset is expected to follow a COLMAP-like structure, which includes camera parameters, image poses, and optionally 3D points. This layout is commonly used for 3D reconstruction and novel view synthesis tasks.
+
+#### πŸ“ Folder Structure
+
+Your dataset should be organized like this:
+```
+data_dir/
+β”œβ”€β”€ images/              # All input images
+β”‚   β”œβ”€β”€ img1.jpg
+β”‚   β”œβ”€β”€ img2.png
+β”‚   └── ...
+β”œβ”€β”€ sparse/              # Sparse reconstruction data (from COLMAP)
+β”‚   β”œβ”€β”€ cameras.bin      # Camera intrinsics
+β”‚   β”œβ”€β”€ images.bin       # Image poses (extrinsics) and filenames
+β”‚   └── points3D.bin     # Optional: 3D point cloud
+```
+
+### Nerfstudio Dataset Format
+
+#### πŸ”§ `transforms.json` File
+
+This file must include the following:
+
+- Intrinsic camera parameters:
+  - `"fl_x"`, `"fl_y"`: focal lengths
+  - `"cx"`, `"cy"`: principal point
+  - `"w"`, `"h"`: image dimensions
+
+- A list of frames, each containing:
+  - `file_path`: path to the RGB image (relative to the dataset root)
+  - `depth_file_path`: path to the depth map (relative to the dataset root)
+  - `transform_matrix`: 4x4 camera-to-world matrix
+
+**Example:**
+```json
+{
+  "w": 1920,
+  "h": 1080,
+  "fl_x": 2198.997802734375,
+  "fl_y": 2198.997802734375,
+  "cx": 960.0,
+  "cy": 540.0,
+  "k1": 0,
+  "k2": 0,
+  "p1": 0,
+  "p2": 0,
+  "frames": [
+    {
+      "file_path": "images/frame_0000.png",
+      "depth_file_path": "depths/frame_0000.png",
+      "transform_matrix": [[1,0,0,0], [0,1,0,0], [0,0,1,0], [0,0,0,1]]
+    },
+    {
+      "file_path": "images/frame_0001.png",
+      "depth_file_path": "depths/frame_0001.png",
+      "transform_matrix": [[1,0,0,1], [0,1,0,0], [0,0,1,2], [0,0,0,1]]
+    }
+  ]
+}
+```
+
+A minimal validation sketch for this file is included at the end of this README.
+
+### Depth_hard_pruner
+This pruner uses the depth loss to compute the pruning score for hard pruning. It works analogously to the RGB hard pruner, but not all features are available (for example, there is no `--eval-only` mode).
+```
+python3 depth_hard_pruner.py default --data-dir datasets/park --ckpt results/park/step-000029999.ckpt --pruning-ratio 0.1 --result-dir output
+
+--pruning-ratio 0.0 (no pruning, saved in new format)
+--output-format (ply (default), ckpt (nerfstudio), pt (gsplat))
+```
+
+#### Known Issues
+For the Park scene, the model tends to generate black Gaussians to cover the sky; the entire scene ends up encased in these Gaussians.
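+
+### How Hard Pruning Works (Sketch)
+
+The snippet below is a minimal sketch of the hard-pruning step, assuming one importance score per Gaussian has already been accumulated: it builds a keep-mask that drops the lowest-scoring fraction. In the actual scripts the scores come from loss gradients accumulated over the training views, and the inverse of this keep-mask is handed to gsplat's `remove()`, which also updates the optimizers. The helper name `hard_prune` is purely illustrative.
+
+```python
+import torch
+
+def hard_prune(scores: torch.Tensor, pruning_ratio: float) -> torch.Tensor:
+    """Return a boolean keep-mask over N Gaussians, dropping the lowest-scoring ones."""
+    num_prune = int(len(scores) * pruning_ratio)
+    keep = torch.ones_like(scores, dtype=torch.bool)
+    if num_prune > 0:
+        # Indices of the `num_prune` least important Gaussians
+        _, drop_idx = torch.topk(scores, k=num_prune, largest=False)
+        keep[drop_idx] = False
+    return keep
+
+# Toy example: prune 10% of 1000 Gaussians with random importance scores
+scores = torch.rand(1000)
+keep = hard_prune(scores, pruning_ratio=0.1)
+print(int(keep.sum()))  # 900 Gaussians survive
+```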
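+
+### Checking `transforms.json` (Sketch)
+
+The helper below is a small, hypothetical sanity check (it is not part of the repository): it verifies that a Nerfstudio-style `transforms.json` contains the intrinsics and per-frame keys described above before a pruning run. `depth_file_path` is only consumed when depth supervision is enabled (e.g., by the depth hard pruner), but the check follows the format described above.
+
+```python
+import json
+from pathlib import Path
+
+REQUIRED_INTRINSICS = ("fl_x", "fl_y", "cx", "cy", "w", "h")
+REQUIRED_FRAME_KEYS = ("file_path", "depth_file_path", "transform_matrix")
+
+def check_transforms(data_dir: str) -> None:
+    """Fail early if transforms.json is missing fields the pruners rely on."""
+    meta = json.loads((Path(data_dir) / "transforms.json").read_text())
+    missing = [k for k in REQUIRED_INTRINSICS if k not in meta]
+    if missing:
+        raise ValueError(f"transforms.json is missing intrinsics: {missing}")
+    for i, frame in enumerate(meta.get("frames", [])):
+        for key in REQUIRED_FRAME_KEYS:
+            if key not in frame:
+                raise ValueError(f"frame {i} is missing '{key}'")
+        if len(frame["transform_matrix"]) != 4:
+            raise ValueError(f"frame {i}: transform_matrix must be 4x4")
+
+check_transforms("datasets/park")  # same --data-dir as in the usage examples above
+```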
diff --git a/RGB_hard_pruner.py b/RGB_hard_pruner.py new file mode 100644 index 0000000..e5efafb --- /dev/null +++ b/RGB_hard_pruner.py @@ -0,0 +1,887 @@ +""" +RGB hard pruning Script + +Part of this code is based on gsplat’s `simple_trainer.py`: +https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py + +""" +import math +import os +import time +from dataclasses import dataclass, field +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import imageio +import nerfview +import numpy as np +import torch +import torch.nn.functional as F +import tqdm +import tyro +import viser +from pruning_utils.nerf import NerfDataset , NerfParser +from pruning_utils.colmap import Dataset, Parser +from pruning_utils.traj import ( + generate_interpolated_path, + generate_ellipse_path_z, + generate_spiral_path, +) +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure +from fused_ssim import fused_ssim +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity +from typing_extensions import Literal, assert_never +from pruning_utils.utils import AppearanceOptModule, CameraOptModule, set_random_seed +from pruning_utils.lib_bilagrid import ( + BilateralGrid, + color_correct, +) +from pruning_utils.open_ply_pipeline import load_splats, save_splats + + +from gsplat.compression import PngCompression +from gsplat.distributed import cli +from gsplat.rendering import rasterization +from gsplat.strategy import DefaultStrategy, MCMCStrategy +from gsplat.strategy.ops import remove + + + +@dataclass +class Config: + # Disable viewer + disable_viewer: bool = False + # Path to the .pt files. If provide, it will skip training and run evaluation only. + ckpt: Optional[List[str]] = None + # Name of compression strategy to use + compression: Optional[Literal["png"]] = None + # Render trajectory path + render_traj_path: str = "interp" + + # Pruning Ratio + pruning_ratio: float = 0.0 + # Output data format converted + output_format: str = "ply" + # Path to the Mip-NeRF 360 dataset + data_dir: str = "data" + # Downsample factor for the dataset + data_factor: int = 4 + # Directory to save results + result_dir: str = "./results" + # Every N images there is a test image + test_every: int = 8 + # Random crop size for training (experimental) + patch_size: Optional[int] = None + # A global scaler that applies to the scene size related parameters + global_scale: float = 1.0 + # Normalize the world space + normalize_world_space: bool = True + # Camera model + camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole" + + # Port for the viewer server + port: int = 8080 + + # Batch size for training. 
Learning rates are scaled automatically + batch_size: int = 1 + # A global factor to scale the number of training steps + steps_scaler: float = 1.0 + + # Number of training steps + max_steps: int = 10 + # Steps to evaluate the model + eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + # Steps to save the model + save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + + # Degree of spherical harmonics + sh_degree: int = 3 + # Turn on another SH degree every this steps + sh_degree_interval: int = 1000 + # Initial opacity of GS + init_opa: float = 0.1 + # Initial scale of GS + init_scale: float = 1.0 + # Weight for SSIM loss + ssim_lambda: float = 0.2 + + # Near plane clipping distance + near_plane: float = 0.01 + # Far plane clipping distance + far_plane: float = 1e10 + + # Only run evalutation + eval_only: bool = False + + # Strategy for GS densification + strategy: Union[DefaultStrategy, MCMCStrategy] = field( + default_factory=DefaultStrategy + ) + # Use packed mode for rasterization, this leads to less memory usage but slightly slower. + packed: bool = False + # Use sparse gradients for optimization. (experimental) + sparse_grad: bool = False + # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics. + antialiased: bool = False + + + # Opacity regularization + opacity_reg: float = 0.0 + # Scale regularization + scale_reg: float = 0.0 + + # Enable camera optimization. + pose_opt: bool = False + # Learning rate for camera optimization + pose_opt_lr: float = 1e-5 + # Regularization for camera optimization as weight decay + pose_opt_reg: float = 1e-6 + # Add noise to camera extrinsics. This is only to test the camera pose optimization. + pose_noise: float = 0.0 + + # Enable appearance optimization. (experimental) + app_opt: bool = False + # Appearance embedding dimension + app_embed_dim: int = 16 + # Learning rate for appearance optimization + app_opt_lr: float = 1e-3 + # Regularization for appearance optimization as weight decay + app_opt_reg: float = 1e-6 + + # Enable bilateral grid. (experimental) + use_bilateral_grid: bool = False + # Shape of the bilateral grid (X, Y, W) + bilateral_grid_shape: Tuple[int, int, int] = (16, 16, 8) + + # Enable depth loss. 
(experimental) + depth_loss: bool = False + + lpips_net: Literal["vgg", "alex"] = "alex" + + def adjust_steps(self, factor: float): + self.eval_steps = [int(i * factor) for i in self.eval_steps] + self.save_steps = [int(i * factor) for i in self.save_steps] + self.max_steps = int(self.max_steps * factor) + self.sh_degree_interval = int(self.sh_degree_interval * factor) + + strategy = self.strategy + if isinstance(strategy, DefaultStrategy): + strategy.refine_start_iter = int(strategy.refine_start_iter * factor) + strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) + strategy.reset_every = int(strategy.reset_every * factor) + strategy.refine_every = int(strategy.refine_every * factor) + elif isinstance(strategy, MCMCStrategy): + strategy.refine_start_iter = int(strategy.refine_start_iter * factor) + strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) + strategy.refine_every = int(strategy.refine_every * factor) + else: + assert_never(strategy) + + + + +class Runner: + """Engine for training and testing.""" + + def __init__( + self, local_rank: int, world_rank: int, world_size: int, cfg: Config + ) -> None: + set_random_seed(42 + local_rank) + + self.cfg = cfg + self.world_rank = world_rank + self.local_rank = local_rank + self.world_size = world_size + self.device = f"cuda:{local_rank}" + + # Where to dump results. + os.makedirs(cfg.result_dir, exist_ok=True) + + # Setup output directories. + self.ckpt_dir = f"{cfg.result_dir}" + os.makedirs(self.ckpt_dir, exist_ok=True) + self.stats_dir = f"{cfg.result_dir}/stats" + os.makedirs(self.stats_dir, exist_ok=True) + self.render_dir = f"{cfg.result_dir}/renders" + os.makedirs(self.render_dir, exist_ok=True) + + # Tensorboard + self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb") + + # Load data: Training data should contain initial points and colors. + ext = os.path.splitext(cfg.ckpt[0])[1].lower() + if ext == ".ckpt" or ext == ".ply" : + print("this is a ckpt or ply file ") + self.parser = NerfParser( + data_dir=cfg.data_dir, + factor=cfg.data_factor, + normalize=cfg.normalize_world_space, + test_every=cfg.test_every, + ) + self.trainset = NerfDataset( + self.parser, + split="train", + patch_size=cfg.patch_size, + load_depths=cfg.depth_loss, + ) + elif ext == ".pt": + print("this is a pt file") + self.parser = Parser( + data_dir=cfg.data_dir, + factor=cfg.data_factor, + normalize=cfg.normalize_world_space, + test_every=cfg.test_every, + ) + self.trainset = Dataset( + self.parser, + split="train", + patch_size=cfg.patch_size, + load_depths=cfg.depth_loss, + ) + else: + msg ="Invalid Model type. Use .pt, .ckpt or .ply files." + raise TypeError(msg) + + self.valset = Dataset(self.parser, split="val") + self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale + + # Model + feature_dim = 32 if cfg.app_opt else None + self.splats, self.optimizers = self.create_splats_with_optimizers( + scene_scale=self.scene_scale, + sparse_grad=cfg.sparse_grad, + batch_size=cfg.batch_size, + device=self.device, + world_size=world_size, + cfg=self.cfg, + ) + + print("Model initialized. 
Number of GS:", len(self.splats["means"])) + + # Densification Strategy + self.cfg.strategy.check_sanity(self.splats, self.optimizers) + + if isinstance(self.cfg.strategy, DefaultStrategy): + self.strategy_state = self.cfg.strategy.initialize_state( + scene_scale=self.scene_scale + ) + elif isinstance(self.cfg.strategy, MCMCStrategy): + self.strategy_state = self.cfg.strategy.initialize_state() + else: + assert_never(self.cfg.strategy) + + # Compression Strategy + self.compression_method = None + if cfg.compression is not None: + if cfg.compression == "png": + self.compression_method = PngCompression() + else: + raise ValueError(f"Unknown compression strategy: {cfg.compression}") + + self.pose_optimizers = [] + if cfg.pose_opt: + self.pose_adjust = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_adjust.zero_init() + self.pose_optimizers = [ + torch.optim.Adam( + self.pose_adjust.parameters(), + lr=cfg.pose_opt_lr * math.sqrt(cfg.batch_size), + weight_decay=cfg.pose_opt_reg, + ) + ] + if world_size > 1: + self.pose_adjust = DDP(self.pose_adjust) + + if cfg.pose_noise > 0.0: + self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_perturb.random_init(cfg.pose_noise) + if world_size > 1: + self.pose_perturb = DDP(self.pose_perturb) + + self.app_optimizers = [] + if cfg.app_opt: + assert feature_dim is not None + self.app_module = AppearanceOptModule( + len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree + ).to(self.device) + # initialize the last layer to be zero so that the initial output is zero. + torch.nn.init.zeros_(self.app_module.color_head[-1].weight) + torch.nn.init.zeros_(self.app_module.color_head[-1].bias) + self.app_optimizers = [ + torch.optim.Adam( + self.app_module.embeds.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0, + weight_decay=cfg.app_opt_reg, + ), + torch.optim.Adam( + self.app_module.color_head.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size), + ), + ] + if world_size > 1: + self.app_module = DDP(self.app_module) + + self.bil_grid_optimizers = [] + if cfg.use_bilateral_grid: + self.bil_grids = BilateralGrid( + len(self.trainset), + grid_X=cfg.bilateral_grid_shape[0], + grid_Y=cfg.bilateral_grid_shape[1], + grid_W=cfg.bilateral_grid_shape[2], + ).to(self.device) + self.bil_grid_optimizers = [ + torch.optim.Adam( + self.bil_grids.parameters(), + lr=2e-3 * math.sqrt(cfg.batch_size), + eps=1e-15, + ), + ] + + # Losses & Metrics. 
+ self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device) + self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device) + + if cfg.lpips_net == "alex": + self.lpips = LearnedPerceptualImagePatchSimilarity( + net_type="alex", normalize=True + ).to(self.device) + elif cfg.lpips_net == "vgg": + # The 3DGS official repo uses lpips vgg, which is equivalent with the following: + self.lpips = LearnedPerceptualImagePatchSimilarity( + net_type="vgg", normalize=False + ).to(self.device) + else: + raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}") + + # Viewer + if not self.cfg.disable_viewer: + self.server = viser.ViserServer(port=cfg.port, verbose=False) + self.viewer = nerfview.Viewer( + server=self.server, + render_fn=self._viewer_render_fn, + mode="training", + ) + + def create_splats_with_optimizers( + self, + scene_scale: float = 1.0, + sparse_grad: bool = False, + batch_size: int = 1, + device: str = "cuda", + world_size: int = 1, + cfg: Optional[List[str]] = None, + ) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]: + + + + self.steps, splats = load_splats(cfg.ckpt[0], device) + + # Convert to ParameterDict on the correct device + splats = torch.nn.ParameterDict(splats).to(device) + + + # Learning rates: you need to define them since they’re not stored in the ckpt + # Use default values from above + default_lrs = { + "means": 1.6e-4 * scene_scale, + "scales": 5e-3, + "quats": 1e-3, + "opacities": 5e-2, + "sh0": 2.5e-3, + "shN": 2.5e-3 / 20, + "features": 2.5e-3, + "colors": 2.5e-3, + } + + BS = batch_size * world_size + optimizers = { + name: (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)( + [ + { + "params": splats[name], + "lr": default_lrs[name] * math.sqrt(BS), + "name": name, + } + ], + eps=1e-15 / math.sqrt(BS), + betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)), + ) + for name in splats.keys() + if name in default_lrs + } + + return splats, optimizers + + + def reinit_optimizers(self): + """Reinitialize optimizers after pruning Gaussians.""" + cfg = self.cfg + BS = cfg.batch_size * self.world_size + + # Recreate the optimizer dictionary with new parameters + new_optimizers = {} + for name, param in self.splats.items(): + lr = { + "means": 1.6e-4 * self.scene_scale, + "scales": 5e-3, + "quats": 1e-3, + "opacities": 5e-2, + "sh0": 2.5e-3, + "shN": 2.5e-3 / 20, + "features": 2.5e-3, + "colors": 2.5e-3, + }[name] + + betas = ( + 1 - BS * (1 - 0.9), + 1 - BS * (1 - 0.999), + ) + + new_optimizers[name] = torch.optim.Adam( + [{"params": param, "lr": lr * math.sqrt(BS), "name": name}], + eps=1e-15 / math.sqrt(BS), + betas=betas, + ) + + # Replace old optimizers with new ones + self.optimizers = new_optimizers + + + def prune_gaussians(self, prune_ratio: float, scores: torch.Tensor): + """Prune Gaussians based on score thresholding.""" + num_prune = int(len(scores) * prune_ratio) + _, idx = torch.topk(scores.squeeze(), k=num_prune, largest=False) + mask = torch.ones_like(scores, dtype=torch.bool) + mask[idx] = False + + remove( + params=self.splats, + optimizers= self.optimizers, + state= self.strategy_state, + mask = ~mask, + ) + + + + + @torch.enable_grad() + def score_func( + self, + viewpoint_cam: Dict[str, torch.Tensor], + scores: torch.Tensor, + mask_views: torch.Tensor + ) -> None: + + # Get camera matrices without extra dimensions + camtoworld = viewpoint_cam["camtoworld"].to(self.device) # shape: [4, 4] + K = viewpoint_cam["K"].to(self.device) # shape: [3, 3] + height, width = viewpoint_cam["image"].shape[1:3] + + # 
Forward pass + colors, _, _ = self.rasterize_splats( + camtoworlds=camtoworld, # Add batch dim here only + Ks=K, # Add batch dim here only + width=width, + height=height, + sh_degree=self.cfg.sh_degree, + image_ids=viewpoint_cam["image_id"].to(self.device)[None], + ) + + # Compute loss + gt_image = viewpoint_cam["image"].to(self.device) / 255.0 + l1loss = F.l1_loss(colors, gt_image) + ssimloss = 1.0 - fused_ssim( + colors.permute(0, 3, 1, 2), + gt_image.permute(0, 3, 1, 2), + padding="valid" + ) + loss = l1loss * (1.0 - self.cfg.ssim_lambda) + ssimloss * self.cfg.ssim_lambda + + # Backward pass + loss.backward() + + + + # Opacity gradient + opacity_grad = self.splats["opacities"].grad.abs().squeeze() # [N] + + # Means gradient - reduce across channel dim (assumes shape [N, 3]) + means_grad = self.splats["means"].grad.abs().mean(dim=1).squeeze() # [N] + + # Scales gradient + scales_grad = self.splats["scales"].grad.abs().mean(dim=1).squeeze() # [N] + + # SH0 gradient - make sure you are reducing the right dimension + sh0_grad = self.splats["sh0"].grad.abs().view(self.splats["sh0"].grad.shape[0], -1).mean(dim=1).squeeze() + + # SHN gradient - often [N, K, 3], so reduce last two dims + shN_grad = self.splats["shN"].grad.abs().view(self.splats["shN"].grad.shape[0], -1).mean(dim=1).squeeze() # [N] + + + # Combine all scores + combined = opacity_grad + means_grad + scales_grad + sh0_grad + shN_grad + + # Thresholding + mask_views += combined > 50 * 1e-8 + + # Accumulate for scoring + with torch.no_grad(): + scores += combined + + + @torch.no_grad() + def prune(self, prune_ratio: float): + print("Running pruning...") + scores = torch.zeros_like(self.splats["opacities"]) + mask_views = torch.zeros_like(self.splats["opacities"]) + + trainloader = torch.utils.data.DataLoader( + self.trainset, + batch_size=1, + shuffle=False, + num_workers=1, + pin_memory=True, + ) + + pbar = tqdm.tqdm(trainloader, desc="Computing pruning scores") + for data in pbar: + self.score_func(data, scores, mask_views) + pbar.update(1) + + np.savetxt('mask_views_full.txt', mask_views.cpu().numpy(), fmt='%.18e') + np.savetxt('scores_txt.txt', scores.cpu().numpy(), fmt='%.18e') + + # Prune Gaussians + self.prune_gaussians(prune_ratio, scores) + + + + def rasterize_splats( + self, + camtoworlds: Tensor, + Ks: Tensor, + width: int, + height: int, + masks: Optional[Tensor] = None, + **kwargs, + ) -> Tuple[Tensor, Tensor, Dict]: + means = self.splats["means"] # [N, 3] + quats = self.splats["quats"] # [N, 4] + scales = torch.exp(self.splats["scales"]) # [N, 3] + opacities = torch.sigmoid(self.splats["opacities"]) # [N,] + + image_ids = kwargs.pop("image_ids", None) + if self.cfg.app_opt: + colors = self.app_module( + features=self.splats["features"], + embed_ids=image_ids, + dirs=means[None, :, :] - camtoworlds[:, None, :3, 3], + sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree), + ) + colors = colors + self.splats["colors"] + colors = torch.sigmoid(colors) + else: + colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] + + rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" + render_colors, render_alphas, info = rasterization( + means=means, + quats=quats, + scales=scales, + opacities=opacities, + colors=colors, + viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] + Ks=Ks, # [C, 3, 3] + width=width, + height=height, + packed=self.cfg.packed, + absgrad=( + self.cfg.strategy.absgrad + if isinstance(self.cfg.strategy, DefaultStrategy) + else False + ), + sparse_grad=self.cfg.sparse_grad, + 
rasterize_mode=rasterize_mode, + distributed=self.world_size > 1, + camera_model=self.cfg.camera_model, + **kwargs, + ) + if masks is not None: + render_colors[~masks] = 0 + return render_colors, render_alphas, info + + + @torch.no_grad() + def eval(self, step: int, stage: str = "val"): + """Entry for evaluation.""" + print("Running evaluation...") + cfg = self.cfg + device = self.device + world_rank = self.world_rank + + valloader = torch.utils.data.DataLoader( + self.valset, batch_size=1, shuffle=False, num_workers=1 + ) + ellipse_time = 0 + metrics = defaultdict(list) + for i, data in enumerate(valloader): + camtoworlds = data["camtoworld"].to(device) + Ks = data["K"].to(device) + pixels = data["image"].to(device) / 255.0 + masks = data["mask"].to(device) if "mask" in data else None + height, width = pixels.shape[1:3] + + torch.cuda.synchronize() + tic = time.time() + colors, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + masks=masks, + ) # [1, H, W, 3] + torch.cuda.synchronize() + ellipse_time += time.time() - tic + + colors = torch.clamp(colors, 0.0, 1.0) + canvas_list = [pixels, colors] + + if world_rank == 0: + # write images + canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() + canvas = (canvas * 255).astype(np.uint8) + imageio.imwrite( + f"{self.render_dir}/{stage}_step{step}_{i:04d}.png", + canvas, + ) + + pixels_p = pixels.permute(0, 3, 1, 2) # [1, 3, H, W] + colors_p = colors.permute(0, 3, 1, 2) # [1, 3, H, W] + metrics["psnr"].append(self.psnr(colors_p, pixels_p)) + metrics["ssim"].append(self.ssim(colors_p, pixels_p)) + metrics["lpips"].append(self.lpips(colors_p, pixels_p)) + if cfg.use_bilateral_grid: + cc_colors = color_correct(colors, pixels) + cc_colors_p = cc_colors.permute(0, 3, 1, 2) # [1, 3, H, W] + metrics["cc_psnr"].append(self.psnr(cc_colors_p, pixels_p)) + + if world_rank == 0: + ellipse_time /= len(valloader) + + stats = {k: torch.stack(v).mean().item() for k, v in metrics.items()} + stats.update( + { + "ellipse_time": ellipse_time, + "num_GS": len(self.splats["means"]), + } + ) + print( + f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} " + f"Time: {stats['ellipse_time']:.3f}s/image " + f"Number of GS: {stats['num_GS']}" + ) + + + + @torch.no_grad() + def render_traj(self, step: int): + """Entry for trajectory rendering.""" + print("Running trajectory rendering...") + cfg = self.cfg + device = self.device + + camtoworlds_all = self.parser.camtoworlds[5:-5] + if cfg.render_traj_path == "interp": + camtoworlds_all = generate_interpolated_path( + camtoworlds_all, 1 + ) # [N, 3, 4] + elif cfg.render_traj_path == "ellipse": + height = camtoworlds_all[:, 2, 3].mean() + camtoworlds_all = generate_ellipse_path_z( + camtoworlds_all, height=height + ) # [N, 3, 4] + elif cfg.render_traj_path == "spiral": + camtoworlds_all = generate_spiral_path( + camtoworlds_all, + bounds=self.parser.bounds * self.scene_scale, + spiral_scale_r=self.parser.extconf["spiral_radius_scale"], + ) + else: + raise ValueError( + f"Render trajectory type not supported: {cfg.render_traj_path}" + ) + + camtoworlds_all = np.concatenate( + [ + camtoworlds_all, + np.repeat( + np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds_all), axis=0 + ), + ], + axis=1, + ) # [N, 4, 4] + + camtoworlds_all = torch.from_numpy(camtoworlds_all).float().to(device) + K = 
torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device) + width, height = list(self.parser.imsize_dict.values())[0] + + # save to video + video_dir = f"{cfg.result_dir}/videos" + os.makedirs(video_dir, exist_ok=True) + writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30) + for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"): + camtoworlds = camtoworlds_all[i : i + 1] + Ks = K[None] + + renders, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + render_mode="RGB+ED", + ) # [1, H, W, 4] + colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] + depths = renders[..., 3:4] # [1, H, W, 1] + depths = (depths - depths.min()) / (depths.max() - depths.min()) + canvas_list = [colors, depths.repeat(1, 1, 1, 3)] + + # write images + canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() + canvas = (canvas * 255).astype(np.uint8) + writer.append_data(canvas) + writer.close() + print(f"Video saved to {video_dir}/traj_{step}.mp4") + + @torch.no_grad() + def run_compression(self, step: int): + """Entry for running compression.""" + print("Running compression...") + world_rank = self.world_rank + + compress_dir = f"{cfg.result_dir}/compression/rank{world_rank}" + os.makedirs(compress_dir, exist_ok=True) + + self.compression_method.compress(compress_dir, self.splats) + + # evaluate compression + splats_c = self.compression_method.decompress(compress_dir) + for k in splats_c.keys(): + self.splats[k].data = splats_c[k].to(self.device) + self.eval(step=step, stage="compress") + + @torch.no_grad() + def _viewer_render_fn( + self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int] + ): + """Callable function for the viewer.""" + W, H = img_wh + c2w = camera_state.c2w + K = camera_state.get_K(img_wh) + c2w = torch.from_numpy(c2w).float().to(self.device) + K = torch.from_numpy(K).float().to(self.device) + + render_colors, _, _ = self.rasterize_splats( + camtoworlds=c2w[None], + Ks=K[None], + width=W, + height=H, + sh_degree=self.cfg.sh_degree, # active all SH degrees + radius_clip=3.0, # skip GSs that have small image radius (in pixels) + ) # [1, H, W, 3] + return render_colors[0].cpu().numpy() + + +def main(local_rank: int, world_rank, world_size: int, cfg: Config): + if world_size > 1 and not cfg.disable_viewer: + cfg.disable_viewer = True + if world_rank == 0: + print("Viewer is disabled in distributed training.") + + if cfg.ckpt is not None: + # run eval only + runner = Runner(local_rank, world_rank, world_size, cfg) + + step = runner.steps + if cfg.eval_only: + runner.eval(step=step) + else: + print("Hard pruning in progress...") + if cfg.pruning_ratio != 0: + runner.prune(prune_ratio=cfg.pruning_ratio) + print("The size of gaussian is:", runner.splats["means"].shape) + + # save checkpoint after hard pruning + name = os.path.splitext(os.path.basename(cfg.ckpt[0]))[0] + + data = {"step": step, "splats": runner.splats.state_dict()} + if cfg.pose_opt: + if world_size > 1: + data["pose_adjust"] = runner.pose_adjust.module.state_dict() + else: + data["pose_adjust"] = runner.pose_adjust.state_dict() + if cfg.app_opt: + if world_size > 1: + data["app_module"] = runner.app_module.module.state_dict() + else: + data["app_module"] = runner.app_module.state_dict() + suffix = str(cfg.pruning_ratio).replace("0.", "") + + print("output format", cfg.output_format) + + + 
save_splats(f"{runner.ckpt_dir}/{name}_pruned_test{suffix}", data, cfg.output_format) + + else: + raise ValueError("ckpt cant be None") + + if not cfg.disable_viewer: + print("Viewer running... Ctrl+C to exit.") + time.sleep(1000000) + + +if __name__ == "__main__": + # Config objects we can choose between. + # Each is a tuple of (CLI description, config object). + configs = { + "default": ( + "Gaussian splatting training using densification heuristics from the original paper.", + Config( + strategy=DefaultStrategy(verbose=True), + ), + ), + "mcmc": ( + "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.", + Config( + init_opa=0.5, + init_scale=0.1, + opacity_reg=0.01, + scale_reg=0.01, + strategy=MCMCStrategy(verbose=True), + ), + ), + } + cfg = tyro.extras.overridable_config_cli(configs) + cfg.adjust_steps(cfg.steps_scaler) + + # try import extra dependencies + if cfg.compression == "png": + try: + import plas + import torchpq + except: + raise ImportError( + "To use PNG compression, you need to install " + "torchpq (instruction at https://github.com/DeMoriarty/TorchPQ?tab=readme-ov-file#install) " + "and plas (via 'pip install git+https://github.com/fraunhoferhhi/PLAS.git') " + ) + + cli(main, cfg, verbose=True) \ No newline at end of file diff --git a/depth_hard_pruner.py b/depth_hard_pruner.py new file mode 100644 index 0000000..bd3cb14 --- /dev/null +++ b/depth_hard_pruner.py @@ -0,0 +1,853 @@ +""" +Depth hard pruning Script + +Part of this code is based on gsplat’s `simple_trainer.py`: +https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py + +""" + + +import math +import os +import time +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Union + +import imageio +import nerfview +import numpy as np +import torch +import torch.nn.functional as F +import tqdm +import tyro +import viser + +from pruning_utils.nerf import NerfDataset , NerfParser +from pruning_utils.colmap import Dataset, Parser +from pruning_utils.traj import ( + generate_interpolated_path, + generate_ellipse_path_z, + generate_spiral_path, +) +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure +from fused_ssim import fused_ssim +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity +from typing_extensions import Literal, assert_never +from pruning_utils.utils import AppearanceOptModule, CameraOptModule, set_random_seed +from pruning_utils.lib_bilagrid import ( + BilateralGrid, +) +from pruning_utils.open_ply_pipeline import load_splats, save_splats + + +from gsplat.compression import PngCompression +from gsplat.distributed import cli +from gsplat.rendering import rasterization +from gsplat.strategy import DefaultStrategy, MCMCStrategy +from gsplat.strategy.ops import remove + +from PIL import Image + + +@dataclass +class Config: + # Disable viewer + disable_viewer: bool = False + # Path to the .pt files. If provide, it will skip training and run evaluation only. 
+ ckpt: Optional[List[str]] = None + # Name of compression strategy to use + compression: Optional[Literal["png"]] = None + # Render trajectory path + render_traj_path: str = "interp" + + # Pruning Ratio + pruning_ratio: float = 0.0 + # Output data format converted + output_format: str = "ply" + # Path to the Mip-NeRF 360 dataset + data_dir: str = "data" + # Downsample factor for the dataset + data_factor: int = 4 + # Directory to save results + result_dir: str = "./results" + # Every N images there is a test image + test_every: int = 8 + # Random crop size for training (experimental) + patch_size: Optional[int] = None + # A global scaler that applies to the scene size related parameters + global_scale: float = 1.0 + # Normalize the world space + normalize_world_space: bool = False + # Camera model + camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole" + + # Port for the viewer server + port: int = 8080 + + # Batch size for training. Learning rates are scaled automatically + batch_size: int = 1 + # A global factor to scale the number of training steps + steps_scaler: float = 1.0 + + # Number of training steps + max_steps: int = 10 + # Steps to evaluate the model + eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + # Steps to save the model + save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + + # Initialization strategy + init_type: str = "sfm" + # Initial number of GSs. Ignored if using sfm + init_num_pts: int = 100_000 + # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm + init_extent: float = 3.0 + # Degree of spherical harmonics + sh_degree: int = 3 + # Turn on another SH degree every this steps + sh_degree_interval: int = 1000 + # Initial opacity of GS + init_opa: float = 0.1 + # Initial scale of GS + init_scale: float = 1.0 + # Weight for SSIM loss + ssim_lambda: float = 0.2 + + # Near plane clipping distance + near_plane: float = 0.01 + # Far plane clipping distance + far_plane: float = 1e10 + + # Strategy for GS densification + strategy: Union[DefaultStrategy, MCMCStrategy] = field( + default_factory=DefaultStrategy + ) + # Use packed mode for rasterization, this leads to less memory usage but slightly slower. + packed: bool = False + # Use sparse gradients for optimization. (experimental) + sparse_grad: bool = False + # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics. + antialiased: bool = False + + # Opacity regularization + opacity_reg: float = 0.0 + # Scale regularization + scale_reg: float = 0.0 + + # Enable camera optimization. + pose_opt: bool = False + # Learning rate for camera optimization + pose_opt_lr: float = 1e-5 + # Regularization for camera optimization as weight decay + pose_opt_reg: float = 1e-6 + # Add noise to camera extrinsics. This is only to test the camera pose optimization. + pose_noise: float = 0.0 + + # Enable appearance optimization. (experimental) + app_opt: bool = False + # Appearance embedding dimension + app_embed_dim: int = 16 + # Learning rate for appearance optimization + app_opt_lr: float = 1e-3 + # Regularization for appearance optimization as weight decay + app_opt_reg: float = 1e-6 + + # Enable bilateral grid. (experimental) + use_bilateral_grid: bool = False + # Shape of the bilateral grid (X, Y, W) + bilateral_grid_shape: Tuple[int, int, int] = (16, 16, 8) + + # Enable depth loss. 
(experimental) + depth_loss: bool = True + + lpips_net: Literal["vgg", "alex"] = "alex" + + def adjust_steps(self, factor: float): + self.eval_steps = [int(i * factor) for i in self.eval_steps] + self.save_steps = [int(i * factor) for i in self.save_steps] + self.max_steps = int(self.max_steps * factor) + self.sh_degree_interval = int(self.sh_degree_interval * factor) + + strategy = self.strategy + if isinstance(strategy, DefaultStrategy): + strategy.refine_start_iter = int(strategy.refine_start_iter * factor) + strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) + strategy.reset_every = int(strategy.reset_every * factor) + strategy.refine_every = int(strategy.refine_every * factor) + elif isinstance(strategy, MCMCStrategy): + strategy.refine_start_iter = int(strategy.refine_start_iter * factor) + strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) + strategy.refine_every = int(strategy.refine_every * factor) + else: + assert_never(strategy) + + + + +class Runner: + """Engine for training and testing.""" + + def __init__( + self, local_rank: int, world_rank: int, world_size: int, cfg: Config + ) -> None: + set_random_seed(42 + local_rank) + + self.cfg = cfg + self.world_rank = world_rank + self.local_rank = local_rank + self.world_size = world_size + self.device = f"cuda:{local_rank}" + + # Where to dump results. + os.makedirs(cfg.result_dir, exist_ok=True) + + # Setup output directories. + self.ckpt_dir = f"{cfg.result_dir}" + os.makedirs(self.ckpt_dir, exist_ok=True) + self.stats_dir = f"{cfg.result_dir}/stats" + os.makedirs(self.stats_dir, exist_ok=True) + self.render_dir = f"{cfg.result_dir}/renders" + os.makedirs(self.render_dir, exist_ok=True) + + # Tensorboard + self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb") + + # Load data: Training data should contain initial points and colors. + ext = os.path.splitext(cfg.ckpt[0])[1].lower() + if ext == ".ckpt" or ext == ".ply" : + print("this is a ckpt or ply file ") + self.parser = NerfParser( + data_dir=cfg.data_dir, + factor=cfg.data_factor, + normalize=cfg.normalize_world_space, + test_every=cfg.test_every, + ) + self.trainset = NerfDataset( + self.parser, + split="train", + patch_size=cfg.patch_size, + load_depths=cfg.depth_loss, + ) + elif ext == ".pt": + print("this is a pt file") + self.parser = Parser( + data_dir=cfg.data_dir, + factor=cfg.data_factor, + normalize=cfg.normalize_world_space, + test_every=cfg.test_every, + ) + self.trainset = Dataset( + self.parser, + split="train", + patch_size=cfg.patch_size, + load_depths=cfg.depth_loss, + ) + else: + msg ="Invalid Model type. Use .pt, .ckpt or .ply files." + raise TypeError(msg) + + self.valset = Dataset(self.parser, split="val") + self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale + + # Model + feature_dim = 32 if cfg.app_opt else None + self.splats, self.optimizers = self.create_splats_with_optimizers( + scene_scale=self.scene_scale, + sparse_grad=cfg.sparse_grad, + batch_size=cfg.batch_size, + device=self.device, + world_size=world_size, + cfg=self.cfg, + ) + print("Model initialized. 
Number of GS:", len(self.splats["means"])) + + # Densification Strategy + self.cfg.strategy.check_sanity(self.splats, self.optimizers) + + if isinstance(self.cfg.strategy, DefaultStrategy): + self.strategy_state = self.cfg.strategy.initialize_state( + scene_scale=self.scene_scale + ) + elif isinstance(self.cfg.strategy, MCMCStrategy): + self.strategy_state = self.cfg.strategy.initialize_state() + else: + assert_never(self.cfg.strategy) + + # Compression Strategy + self.compression_method = None + if cfg.compression is not None: + if cfg.compression == "png": + self.compression_method = PngCompression() + else: + raise ValueError(f"Unknown compression strategy: {cfg.compression}") + + self.pose_optimizers = [] + if cfg.pose_opt: + self.pose_adjust = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_adjust.zero_init() + self.pose_optimizers = [ + torch.optim.Adam( + self.pose_adjust.parameters(), + lr=cfg.pose_opt_lr * math.sqrt(cfg.batch_size), + weight_decay=cfg.pose_opt_reg, + ) + ] + if world_size > 1: + self.pose_adjust = DDP(self.pose_adjust) + + if cfg.pose_noise > 0.0: + self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_perturb.random_init(cfg.pose_noise) + if world_size > 1: + self.pose_perturb = DDP(self.pose_perturb) + + self.app_optimizers = [] + if cfg.app_opt: + assert feature_dim is not None + self.app_module = AppearanceOptModule( + len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree + ).to(self.device) + # initialize the last layer to be zero so that the initial output is zero. + torch.nn.init.zeros_(self.app_module.color_head[-1].weight) + torch.nn.init.zeros_(self.app_module.color_head[-1].bias) + self.app_optimizers = [ + torch.optim.Adam( + self.app_module.embeds.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0, + weight_decay=cfg.app_opt_reg, + ), + torch.optim.Adam( + self.app_module.color_head.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size), + ), + ] + if world_size > 1: + self.app_module = DDP(self.app_module) + + self.bil_grid_optimizers = [] + if cfg.use_bilateral_grid: + self.bil_grids = BilateralGrid( + len(self.trainset), + grid_X=cfg.bilateral_grid_shape[0], + grid_Y=cfg.bilateral_grid_shape[1], + grid_W=cfg.bilateral_grid_shape[2], + ).to(self.device) + self.bil_grid_optimizers = [ + torch.optim.Adam( + self.bil_grids.parameters(), + lr=2e-3 * math.sqrt(cfg.batch_size), + eps=1e-15, + ), + ] + + # Losses & Metrics. 
+ self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device) + self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device) + + if cfg.lpips_net == "alex": + self.lpips = LearnedPerceptualImagePatchSimilarity( + net_type="alex", normalize=True + ).to(self.device) + elif cfg.lpips_net == "vgg": + # The 3DGS official repo uses lpips vgg, which is equivalent with the following: + self.lpips = LearnedPerceptualImagePatchSimilarity( + net_type="vgg", normalize=False + ).to(self.device) + else: + raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}") + + # Viewer + if not self.cfg.disable_viewer: + self.server = viser.ViserServer(port=cfg.port, verbose=False) + self.viewer = nerfview.Viewer( + server=self.server, + render_fn=self._viewer_render_fn, + mode="training", + ) + + def create_splats_with_optimizers( + self, + scene_scale: float = 1.0, + sparse_grad: bool = False, + batch_size: int = 1, + device: str = "cuda", + world_size: int = 1, + cfg: Optional[List[str]] = None, + ) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]: + + + self.steps, splats = load_splats(cfg.ckpt[0], device) + + # Convert to ParameterDict on the correct device + splats = torch.nn.ParameterDict(splats).to(device) + + + # Learning rates: you need to define them since they’re not stored in the ckpt + # Use default values from above + default_lrs = { + "means": 1.6e-4 * scene_scale, + "scales": 5e-3, + "quats": 1e-3, + "opacities": 5e-2, + "sh0": 2.5e-3, + "shN": 2.5e-3 / 20, + "features": 2.5e-3, + "colors": 2.5e-3, + } + + BS = batch_size * world_size + optimizers = { + name: (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)( + [ + { + "params": splats[name], + "lr": default_lrs[name] * math.sqrt(BS), + "name": name, + } + ], + eps=1e-15 / math.sqrt(BS), + betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)), + ) + for name in splats.keys() + if name in default_lrs + } + + return splats, optimizers + + + def reinit_optimizers(self): + """Reinitialize optimizers after pruning Gaussians.""" + cfg = self.cfg + BS = cfg.batch_size * self.world_size + + # Recreate the optimizer dictionary with new parameters + new_optimizers = {} + for name, param in self.splats.items(): + lr = { + "means": 1.6e-4 * self.scene_scale, + "scales": 5e-3, + "quats": 1e-3, + "opacities": 5e-2, + "sh0": 2.5e-3, + "shN": 2.5e-3 / 20, + "features": 2.5e-3, + "colors": 2.5e-3, + }[name] + + betas = ( + 1 - BS * (1 - 0.9), + 1 - BS * (1 - 0.999), + ) + + new_optimizers[name] = torch.optim.Adam( + [{"params": param, "lr": lr * math.sqrt(BS), "name": name}], + eps=1e-15 / math.sqrt(BS), + betas=betas, + ) + + # Replace old optimizers with new ones + self.optimizers = new_optimizers + + + def prune_gaussians(self, prune_ratio: float, scores: torch.Tensor): + """Prune Gaussians based on score thresholding.""" + reduced_scores = scores.mean(dim=1) + + num_prune = int(len(reduced_scores) * prune_ratio) + _, idx = torch.topk(reduced_scores, k=num_prune, largest=False) + mask = torch.ones_like(reduced_scores, dtype=torch.bool) + mask[idx] = False + + remove( + params=self.splats, + optimizers= self.optimizers, + state= self.strategy_state, + mask = ~mask, + ) + + + + def save_tensors_side_by_side(self, rendered_tensor, target_tensor, filename="depth_debug/depth_comparison.png"): + """ + Saves two PyTorch tensors (e.g., depth maps) as side-by-side PNG images. + If the file exists, appends a number to avoid overwriting (e.g., _1, _2, ...). 
+ + Args: + rendered_tensor (torch.Tensor): Predicted tensor (e.g., rendered depth). + target_tensor (torch.Tensor): Ground truth tensor (e.g., image depth). + filename (str): Base output filename (will be modified if it exists). + """ + def tensor_to_image(tensor): + # Ensure tensor is on CPU and detach gradients + tensor = tensor.cpu().detach() + + # Remove batch dimension if present (e.g., [1, H, W, 1] -> [H, W, 1]) + while tensor.dim() > 3 and tensor.shape[0] == 1: + tensor = tensor.squeeze(0) + + # At this point, should be [H, W] or [H, W, 1] + if tensor.dim() == 3 and tensor.shape[-1] == 1: + tensor = tensor.squeeze(-1) # Remove channel dim + + # Now it should be [H, W] + if tensor.dim() != 2: + raise ValueError(f"Unexpected tensor shape after squeezing: {tensor.shape}") + + # Normalize to [0, 1] + tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min()) + + # Convert to uint8 [0, 255] + image = (tensor.numpy() * 255).astype(np.uint8) + + return Image.fromarray(image, mode='L'), image # L = grayscale + + # Convert both tensors to PIL images + rendered_img, rendered_image = tensor_to_image(rendered_tensor) + target_img, target_image = tensor_to_image(target_tensor) + + + # Create a new image that combines them horizontally + total_width = rendered_img.width + target_img.width + max_height = max(rendered_img.height, target_img.height) + + combined_img = Image.new('L', (total_width, max_height)) # 'L' for grayscale + combined_img.paste(rendered_img, (0, 0)) + combined_img.paste(target_img, (rendered_img.width, 0)) + + # Save to disk + combined_img.save(filename) + + return rendered_image, target_image + + + @torch.enable_grad() + def score_func( + self, + viewpoint_cam: Dict[str, torch.Tensor], + scores: torch.Tensor, + mask_views: torch.Tensor + ) -> None: + + # Get camera matrices without extra dimensions + camtoworld = viewpoint_cam["camtoworld"].to(self.device) # shape: [4, 4] + K = viewpoint_cam["K"].to(self.device) # shape: [3, 3] + height, width = viewpoint_cam["image"].shape[1:3] + + # Forward pass + render, _, info = self.rasterize_splats( + camtoworlds=camtoworld, # Add batch dim here only + Ks=K, # Add batch dim here only + width=width, + height=height, + sh_degree=self.cfg.sh_degree, + image_ids=viewpoint_cam["image_id"].to(self.device)[None], + render_mode="RGB+ED" + ) + + # Compute loss + rendered_depth = render[..., 3:4] + + image_depth = self.trainset[viewpoint_cam["image_id"].item()]["depth"].to(self.device)[None].unsqueeze(-1) + + # rendered_image, target_image = self.save_tensors_side_by_side(rendered_np, image_np, f"depth_debug/00{self.trainset.indices[viewpoint_cam['image_id'].item()]}.png") + + l1loss = F.l1_loss(rendered_depth, image_depth) + ssimloss = 1.0 - fused_ssim( + rendered_depth.permute(0, 3, 1, 2), + image_depth.permute(0, 3, 1, 2), + padding="valid" + ) + loss = l1loss * (1.0 - self.cfg.ssim_lambda) + ssimloss * self.cfg.ssim_lambda + + # Backward pass + loss.backward() + + mask_views += self.splats["means"].grad.abs().squeeze() > 0.001 + + # Accumulate gradient magnitude as score (e.g., opacities) + with torch.no_grad(): + scores += self.splats["means"].grad.abs().squeeze() + + + @torch.no_grad() + def prune(self, prune_ratio: float): + print("Running pruning...") + scores = torch.zeros_like(self.splats["means"]) + mask_views = torch.zeros_like(self.splats["means"]) + + trainloader = torch.utils.data.DataLoader( + self.trainset, + batch_size=1, + shuffle=False, + num_workers=1, + pin_memory=True, + ) + + i = 0 + pbar = tqdm.tqdm(trainloader, 
desc="Computing pruning scores") + for data in pbar: + self.score_func(data, scores, mask_views) + pbar.update(1) + i += 1 + + scores = scores / (mask_views + 1e-8) + + self.prune_gaussians(prune_ratio, scores) + + + + def rasterize_splats( + self, + camtoworlds: Tensor, + Ks: Tensor, + width: int, + height: int, + masks: Optional[Tensor] = None, + render_mode: Optional[str] = "RGB", + **kwargs, + ) -> Tuple[Tensor, Tensor, Dict]: + means = self.splats["means"] # [N, 3] + quats = self.splats["quats"] # [N, 4] + scales = torch.exp(self.splats["scales"]) # [N, 3] + opacities = torch.sigmoid(self.splats["opacities"]) # [N,] + + image_ids = kwargs.pop("image_ids", None) + if self.cfg.app_opt: + colors = self.app_module( + features=self.splats["features"], + embed_ids=image_ids, + dirs=means[None, :, :] - camtoworlds[:, None, :3, 3], + sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree), + ) + colors = colors + self.splats["colors"] + colors = torch.sigmoid(colors) + else: + colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] + + rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" + + render_colors, render_alphas, info = rasterization( + means=means, + quats=quats, + scales=scales, + opacities=opacities, + colors=colors, + viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] + Ks=Ks, # [C, 3, 3] + width=width, + height=height, + packed=self.cfg.packed, + absgrad=( + self.cfg.strategy.absgrad + if isinstance(self.cfg.strategy, DefaultStrategy) + else False + ), + sparse_grad=self.cfg.sparse_grad, + rasterize_mode=rasterize_mode, + distributed=self.world_size > 1, + camera_model=self.cfg.camera_model, + render_mode=render_mode, + **kwargs, + ) + if masks is not None: + render_colors[~masks] = 0 + return render_colors, render_alphas, info + + + @torch.no_grad() + def render_traj(self, step: int): + """Entry for trajectory rendering.""" + print("Running trajectory rendering...") + cfg = self.cfg + device = self.device + + camtoworlds_all = self.parser.camtoworlds[5:-5] + if cfg.render_traj_path == "interp": + camtoworlds_all = generate_interpolated_path( + camtoworlds_all, 1 + ) # [N, 3, 4] + elif cfg.render_traj_path == "ellipse": + height = camtoworlds_all[:, 2, 3].mean() + camtoworlds_all = generate_ellipse_path_z( + camtoworlds_all, height=height + ) # [N, 3, 4] + elif cfg.render_traj_path == "spiral": + camtoworlds_all = generate_spiral_path( + camtoworlds_all, + bounds=self.parser.bounds * self.scene_scale, + spiral_scale_r=self.parser.extconf["spiral_radius_scale"], + ) + else: + raise ValueError( + f"Render trajectory type not supported: {cfg.render_traj_path}" + ) + + camtoworlds_all = np.concatenate( + [ + camtoworlds_all, + np.repeat( + np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds_all), axis=0 + ), + ], + axis=1, + ) # [N, 4, 4] + + camtoworlds_all = torch.from_numpy(camtoworlds_all).float().to(device) + K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device) + width, height = list(self.parser.imsize_dict.values())[0] + + # save to video + video_dir = f"{cfg.result_dir}/videos" + os.makedirs(video_dir, exist_ok=True) + writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30) + for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"): + camtoworlds = camtoworlds_all[i : i + 1] + Ks = K[None] + + renders, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + 
render_mode="RGB+D", + ) # [1, H, W, 4] + colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] + depths = renders[..., 3:4] # [1, H, W, 1] + depths = (depths - depths.min()) / (depths.max() - depths.min()) + canvas_list = [colors, depths.repeat(1, 1, 1, 3)] + + # write images + canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() + canvas = (canvas * 255).astype(np.uint8) + writer.append_data(canvas) + writer.close() + print(f"Video saved to {video_dir}/traj_{step}.mp4") + + @torch.no_grad() + def run_compression(self, step: int): + """Entry for running compression.""" + print("Running compression...") + world_rank = self.world_rank + + compress_dir = f"{cfg.result_dir}/compression/rank{world_rank}" + os.makedirs(compress_dir, exist_ok=True) + + self.compression_method.compress(compress_dir, self.splats) + + # evaluate compression + splats_c = self.compression_method.decompress(compress_dir) + for k in splats_c.keys(): + self.splats[k].data = splats_c[k].to(self.device) + self.eval(step=step, stage="compress") + + @torch.no_grad() + def _viewer_render_fn( + self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int] + ): + """Callable function for the viewer.""" + W, H = img_wh + c2w = camera_state.c2w + K = camera_state.get_K(img_wh) + c2w = torch.from_numpy(c2w).float().to(self.device) + K = torch.from_numpy(K).float().to(self.device) + + render_colors, _, _ = self.rasterize_splats( + camtoworlds=c2w[None], + Ks=K[None], + width=W, + height=H, + sh_degree=self.cfg.sh_degree, # active all SH degrees + radius_clip=3.0, # skip GSs that have small image radius (in pixels) + ) # [1, H, W, 3] + return render_colors[0].cpu().numpy() + + +def main(local_rank: int, world_rank, world_size: int, cfg: Config): + if world_size > 1 and not cfg.disable_viewer: + cfg.disable_viewer = True + if world_rank == 0: + print("Viewer is disabled in distributed training.") + + if cfg.ckpt is not None: + # run eval only + runner = Runner(local_rank, world_rank, world_size, cfg) + + + if cfg.pruning_ratio != 0: + print("Hard pruning in progress...") + runner.prune(prune_ratio=cfg.pruning_ratio) + print("The size of gaussian is:", runner.splats["means"].shape) + + # save checkpoint after hard pruning + step = runner.steps + + name = os.path.splitext(os.path.basename(cfg.ckpt[0]))[0] + + data = {"step": step, "splats": runner.splats.state_dict()} + if cfg.pose_opt: + if world_size > 1: + data["pose_adjust"] = runner.pose_adjust.module.state_dict() + else: + data["pose_adjust"] = runner.pose_adjust.state_dict() + if cfg.app_opt: + if world_size > 1: + data["app_module"] = runner.app_module.module.state_dict() + else: + data["app_module"] = runner.app_module.state_dict() + suffix = str(cfg.pruning_ratio).replace("0.", "") + + print("output format", cfg.output_format) + + save_splats(f"{runner.ckpt_dir}/{name}_depth_pruned_{suffix}", data, cfg.output_format) + + else: + raise ValueError("ckpt cant be None") + + if not cfg.disable_viewer: + print("Viewer running... Ctrl+C to exit.") + time.sleep(1000000) + + +if __name__ == "__main__": + # Config objects we can choose between. + # Each is a tuple of (CLI description, config object). 
+ configs = { + "default": ( + "Gaussian splatting training using densification heuristics from the original paper.", + Config( + strategy=DefaultStrategy(verbose=True), + ), + ), + "mcmc": ( + "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.", + Config( + init_opa=0.5, + init_scale=0.1, + opacity_reg=0.01, + scale_reg=0.01, + strategy=MCMCStrategy(verbose=True), + ), + ), + } + cfg = tyro.extras.overridable_config_cli(configs) + cfg.adjust_steps(cfg.steps_scaler) + + # try import extra dependencies + if cfg.compression == "png": + try: + import plas + import torchpq + except: + raise ImportError( + "To use PNG compression, you need to install " + "torchpq (instruction at https://github.com/DeMoriarty/TorchPQ?tab=readme-ov-file#install) " + "and plas (via 'pip install git+https://github.com/fraunhoferhhi/PLAS.git') " + ) + + cli(main, cfg, verbose=True) \ No newline at end of file diff --git a/pruning_utils/colmap.py b/pruning_utils/colmap.py new file mode 100644 index 0000000..9bf0ccb --- /dev/null +++ b/pruning_utils/colmap.py @@ -0,0 +1,425 @@ +""" +=========================================================================== +Unmodified code from gsplat examples +Original source: +https://github.com/nerfstudio-project/gsplat/blob/main/examples/datasets/colmap.py :contentReference[oaicite:0]{index=0} +License: Apache License 2.0 +---------------------------------------------------------------------------- +This file was copied verbatim from the gsplat repository. No modifications were made. +=========================================================================== +""" + +import os +import json +from typing import Any, Dict, List, Optional +from typing_extensions import assert_never + +import cv2 +import imageio.v2 as imageio +import numpy as np +import torch +from pycolmap import SceneManager + +from .normalize import ( + align_principle_axes, + similarity_from_cameras, + transform_cameras, + transform_points, +) + + +def _get_rel_paths(path_dir: str) -> List[str]: + """Recursively get relative paths of files in a directory.""" + paths = [] + for dp, dn, fn in os.walk(path_dir): + for f in fn: + paths.append(os.path.relpath(os.path.join(dp, f), path_dir)) + return paths + + +class Parser: + """COLMAP parser.""" + + def __init__( + self, + data_dir: str, + factor: int = 1, + normalize: bool = False, + test_every: int = 8, + ): + self.data_dir = data_dir + self.factor = factor + self.normalize = normalize + self.test_every = test_every + + colmap_dir = os.path.join(data_dir, "sparse/0/") + if not os.path.exists(colmap_dir): + colmap_dir = os.path.join(data_dir, "sparse") + assert os.path.exists( + colmap_dir + ), f"COLMAP directory {colmap_dir} does not exist." + + manager = SceneManager(colmap_dir) + manager.load_cameras() + manager.load_images() + manager.load_points3D() + + # Extract extrinsic matrices in world-to-camera format. 
+ imdata = manager.images + w2c_mats = [] + camera_ids = [] + Ks_dict = dict() + params_dict = dict() + imsize_dict = dict() # width, height + mask_dict = dict() + bottom = np.array([0, 0, 0, 1]).reshape(1, 4) + for k in imdata: + im = imdata[k] + rot = im.R() + trans = im.tvec.reshape(3, 1) + w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0) + w2c_mats.append(w2c) + + # support different camera intrinsics + camera_id = im.camera_id + camera_ids.append(camera_id) + + # camera intrinsics + cam = manager.cameras[camera_id] + fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy + K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) + K[:2, :] /= factor + Ks_dict[camera_id] = K + + # Get distortion parameters. + type_ = cam.camera_type + if type_ == 0 or type_ == "SIMPLE_PINHOLE": + params = np.empty(0, dtype=np.float32) + camtype = "perspective" + elif type_ == 1 or type_ == "PINHOLE": + params = np.empty(0, dtype=np.float32) + camtype = "perspective" + if type_ == 2 or type_ == "SIMPLE_RADIAL": + params = np.array([cam.k1, 0.0, 0.0, 0.0], dtype=np.float32) + camtype = "perspective" + elif type_ == 3 or type_ == "RADIAL": + params = np.array([cam.k1, cam.k2, 0.0, 0.0], dtype=np.float32) + camtype = "perspective" + elif type_ == 4 or type_ == "OPENCV": + params = np.array([cam.k1, cam.k2, cam.p1, cam.p2], dtype=np.float32) + camtype = "perspective" + elif type_ == 5 or type_ == "OPENCV_FISHEYE": + params = np.array([cam.k1, cam.k2, cam.k3, cam.k4], dtype=np.float32) + camtype = "fisheye" + assert ( + camtype == "perspective" or camtype == "fisheye" + ), f"Only perspective and fisheye cameras are supported, got {type_}" + + params_dict[camera_id] = params + imsize_dict[camera_id] = (cam.width // factor, cam.height // factor) + mask_dict[camera_id] = None + print( + f"[Parser] {len(imdata)} images, taken by {len(set(camera_ids))} cameras." + ) + + if len(imdata) == 0: + raise ValueError("No images found in COLMAP.") + if not (type_ == 0 or type_ == 1): + print("Warning: COLMAP Camera is not PINHOLE. Images have distortion.") + + w2c_mats = np.stack(w2c_mats, axis=0) + + # Convert extrinsics to camera-to-world. + camtoworlds = np.linalg.inv(w2c_mats) + + # Image names from COLMAP. No need for permuting the poses according to + # image names anymore. + image_names = [imdata[k].name for k in imdata] + + # Previous Nerf results were generated with images sorted by filename, + # ensure metrics are reported on the same test set. + inds = np.argsort(image_names) + image_names = [image_names[i] for i in inds] + camtoworlds = camtoworlds[inds] + camera_ids = [camera_ids[i] for i in inds] + + # Load extended metadata. Used by Bilarf dataset. + self.extconf = { + "spiral_radius_scale": 1.0, + "no_factor_suffix": False, + } + extconf_file = os.path.join(data_dir, "ext_metadata.json") + if os.path.exists(extconf_file): + with open(extconf_file) as f: + self.extconf.update(json.load(f)) + + # Load bounds if possible (only used in forward facing scenes). + self.bounds = np.array([0.01, 1.0]) + posefile = os.path.join(data_dir, "poses_bounds.npy") + if os.path.exists(posefile): + self.bounds = np.load(posefile)[:, -2:] + + # Load images. 
+ if factor > 1 and not self.extconf["no_factor_suffix"]: + image_dir_suffix = f"_{factor}" + else: + image_dir_suffix = "" + colmap_image_dir = os.path.join(data_dir, "images") + image_dir = os.path.join(data_dir, "images" + image_dir_suffix) + for d in [image_dir, colmap_image_dir]: + if not os.path.exists(d): + raise ValueError(f"Image folder {d} does not exist.") + + # Downsampled images may have different names vs images used for COLMAP, + # so we need to map between the two sorted lists of files. + colmap_files = sorted(_get_rel_paths(colmap_image_dir)) + image_files = sorted(_get_rel_paths(image_dir)) + colmap_to_image = dict(zip(colmap_files, image_files)) + image_paths = [os.path.join(image_dir, colmap_to_image[f]) for f in image_names] + + # 3D points and {image_name -> [point_idx]} + points = manager.points3D.astype(np.float32) + points_err = manager.point3D_errors.astype(np.float32) + points_rgb = manager.point3D_colors.astype(np.uint8) + point_indices = dict() + + image_id_to_name = {v: k for k, v in manager.name_to_image_id.items()} + for point_id, data in manager.point3D_id_to_images.items(): + for image_id, _ in data: + image_name = image_id_to_name[image_id] + point_idx = manager.point3D_id_to_point3D_idx[point_id] + point_indices.setdefault(image_name, []).append(point_idx) + point_indices = { + k: np.array(v).astype(np.int32) for k, v in point_indices.items() + } + + # Normalize the world space. + if normalize: + T1 = similarity_from_cameras(camtoworlds) + camtoworlds = transform_cameras(T1, camtoworlds) + points = transform_points(T1, points) + + T2 = align_principle_axes(points) + camtoworlds = transform_cameras(T2, camtoworlds) + points = transform_points(T2, points) + + transform = T2 @ T1 + else: + transform = np.eye(4) + + self.image_names = image_names # List[str], (num_images,) + self.image_paths = image_paths # List[str], (num_images,) + self.camtoworlds = camtoworlds # np.ndarray, (num_images, 4, 4) + self.camera_ids = camera_ids # List[int], (num_images,) + self.Ks_dict = Ks_dict # Dict of camera_id -> K + self.params_dict = params_dict # Dict of camera_id -> params + self.imsize_dict = imsize_dict # Dict of camera_id -> (width, height) + self.mask_dict = mask_dict # Dict of camera_id -> mask + self.points = points # np.ndarray, (num_points, 3) + self.points_err = points_err # np.ndarray, (num_points,) + self.points_rgb = points_rgb # np.ndarray, (num_points, 3) + self.point_indices = point_indices # Dict[str, np.ndarray], image_name -> [M,] + self.transform = transform # np.ndarray, (4, 4) + + # load one image to check the size. In the case of tanksandtemples dataset, the + # intrinsics stored in COLMAP corresponds to 2x upsampled images. 
+ actual_image = imageio.imread(self.image_paths[0])[..., :3] + actual_height, actual_width = actual_image.shape[:2] + colmap_width, colmap_height = self.imsize_dict[self.camera_ids[0]] + s_height, s_width = actual_height / colmap_height, actual_width / colmap_width + for camera_id, K in self.Ks_dict.items(): + K[0, :] *= s_width + K[1, :] *= s_height + self.Ks_dict[camera_id] = K + width, height = self.imsize_dict[camera_id] + self.imsize_dict[camera_id] = (int(width * s_width), int(height * s_height)) + + # undistortion + self.mapx_dict = dict() + self.mapy_dict = dict() + self.roi_undist_dict = dict() + for camera_id in self.params_dict.keys(): + params = self.params_dict[camera_id] + if len(params) == 0: + continue # no distortion + assert camera_id in self.Ks_dict, f"Missing K for camera {camera_id}" + assert ( + camera_id in self.params_dict + ), f"Missing params for camera {camera_id}" + K = self.Ks_dict[camera_id] + width, height = self.imsize_dict[camera_id] + + if camtype == "perspective": + K_undist, roi_undist = cv2.getOptimalNewCameraMatrix( + K, params, (width, height), 0 + ) + mapx, mapy = cv2.initUndistortRectifyMap( + K, params, None, K_undist, (width, height), cv2.CV_32FC1 + ) + mask = None + elif camtype == "fisheye": + fx = K[0, 0] + fy = K[1, 1] + cx = K[0, 2] + cy = K[1, 2] + grid_x, grid_y = np.meshgrid( + np.arange(width, dtype=np.float32), + np.arange(height, dtype=np.float32), + indexing="xy", + ) + x1 = (grid_x - cx) / fx + y1 = (grid_y - cy) / fy + theta = np.sqrt(x1**2 + y1**2) + r = ( + 1.0 + + params[0] * theta**2 + + params[1] * theta**4 + + params[2] * theta**6 + + params[3] * theta**8 + ) + mapx = fx * x1 * r + width // 2 + mapy = fy * y1 * r + height // 2 + + # Use mask to define ROI + mask = np.logical_and( + np.logical_and(mapx > 0, mapy > 0), + np.logical_and(mapx < width - 1, mapy < height - 1), + ) + y_indices, x_indices = np.nonzero(mask) + y_min, y_max = y_indices.min(), y_indices.max() + 1 + x_min, x_max = x_indices.min(), x_indices.max() + 1 + mask = mask[y_min:y_max, x_min:x_max] + K_undist = K.copy() + K_undist[0, 2] -= x_min + K_undist[1, 2] -= y_min + roi_undist = [x_min, y_min, x_max - x_min, y_max - y_min] + else: + assert_never(camtype) + + self.mapx_dict[camera_id] = mapx + self.mapy_dict[camera_id] = mapy + self.Ks_dict[camera_id] = K_undist + self.roi_undist_dict[camera_id] = roi_undist + self.imsize_dict[camera_id] = (roi_undist[2], roi_undist[3]) + self.mask_dict[camera_id] = mask + + # size of the scene measured by cameras + camera_locations = camtoworlds[:, :3, 3] + scene_center = np.mean(camera_locations, axis=0) + dists = np.linalg.norm(camera_locations - scene_center, axis=1) + self.scene_scale = np.max(dists) + + +class Dataset: + """A simple dataset class.""" + + def __init__( + self, + parser: Parser, + split: str = "train", + patch_size: Optional[int] = None, + load_depths: bool = False, + ): + self.parser = parser + self.split = split + self.patch_size = patch_size + self.load_depths = load_depths + indices = np.arange(len(self.parser.image_names)) + if split == "train": + self.indices = indices[indices % self.parser.test_every != 0] + else: + self.indices = indices[indices % self.parser.test_every == 0] + + def __len__(self): + return len(self.indices) + + def __getitem__(self, item: int) -> Dict[str, Any]: + index = self.indices[item] + image = imageio.imread(self.parser.image_paths[index])[..., :3] + camera_id = self.parser.camera_ids[index] + K = self.parser.Ks_dict[camera_id].copy() # undistorted K + params = 
self.parser.params_dict[camera_id] + camtoworlds = self.parser.camtoworlds[index] + mask = self.parser.mask_dict[camera_id] + + if len(params) > 0: + # Images are distorted. Undistort them. + mapx, mapy = ( + self.parser.mapx_dict[camera_id], + self.parser.mapy_dict[camera_id], + ) + image = cv2.remap(image, mapx, mapy, cv2.INTER_LINEAR) + x, y, w, h = self.parser.roi_undist_dict[camera_id] + image = image[y : y + h, x : x + w] + + if self.patch_size is not None: + # Random crop. + h, w = image.shape[:2] + x = np.random.randint(0, max(w - self.patch_size, 1)) + y = np.random.randint(0, max(h - self.patch_size, 1)) + image = image[y : y + self.patch_size, x : x + self.patch_size] + K[0, 2] -= x + K[1, 2] -= y + + data = { + "K": torch.from_numpy(K).float(), + "camtoworld": torch.from_numpy(camtoworlds).float(), + "image": torch.from_numpy(image).float(), + "image_id": item, # the index of the image in the dataset + } + if mask is not None: + data["mask"] = torch.from_numpy(mask).bool() + + if self.load_depths: + # projected points to image plane to get depths + worldtocams = np.linalg.inv(camtoworlds) + image_name = self.parser.image_names[index] + point_indices = self.parser.point_indices[image_name] + points_world = self.parser.points[point_indices] + points_cam = (worldtocams[:3, :3] @ points_world.T + worldtocams[:3, 3:4]).T + points_proj = (K @ points_cam.T).T + points = points_proj[:, :2] / points_proj[:, 2:3] # (M, 2) + depths = points_cam[:, 2] # (M,) + # filter out points outside the image + selector = ( + (points[:, 0] >= 0) + & (points[:, 0] < image.shape[1]) + & (points[:, 1] >= 0) + & (points[:, 1] < image.shape[0]) + & (depths > 0) + ) + points = points[selector] + depths = depths[selector] + data["points"] = torch.from_numpy(points).float() + data["depths"] = torch.from_numpy(depths).float() + + return data + + +if __name__ == "__main__": + import argparse + + import imageio.v2 as imageio + import tqdm + + parser = argparse.ArgumentParser() + parser.add_argument("--data_dir", type=str, default="data/360_v2/garden") + parser.add_argument("--factor", type=int, default=4) + args = parser.parse_args() + + # Parse COLMAP data. + parser = Parser( + data_dir=args.data_dir, factor=args.factor, normalize=True, test_every=8 + ) + dataset = Dataset(parser, split="train", load_depths=True) + print(f"Dataset: {len(dataset)} images.") + + writer = imageio.get_writer("results/points.mp4", fps=30) + for data in tqdm.tqdm(dataset, desc="Plotting points"): + image = data["image"].numpy().astype(np.uint8) + points = data["points"].numpy() + depths = data["depths"].numpy() + for x, y in points: + cv2.circle(image, (int(x), int(y)), 2, (255, 0, 0), -1) + writer.append_data(image) + writer.close() diff --git a/pruning_utils/lib_bilagrid.py b/pruning_utils/lib_bilagrid.py new file mode 100644 index 0000000..1b10a20 --- /dev/null +++ b/pruning_utils/lib_bilagrid.py @@ -0,0 +1,573 @@ +# # Copyright 2024 Yuehao Wang (https://github.com/yuehaowang). This part of code is borrowed form ["Bilateral Guided Radiance Field Processing"](https://bilarfpro.github.io/). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This is a standalone PyTorch implementation of 3D bilateral grid and CP-decomposed 4D bilateral grid. +To use this module, you can download the "lib_bilagrid.py" file and simply put it in your project directory. + +For the details, please check our research project: ["Bilateral Guided Radiance Field Processing"](https://bilarfpro.github.io/). + +#### Dependencies + +In addition to PyTorch and Numpy, please install [tensorly](https://github.com/tensorly/tensorly). +We have tested this module on Python 3.9.18, PyTorch 2.0.1 (CUDA 11), tensorly 0.8.1, and Numpy 1.25.2. + +#### Overview + +- For bilateral guided training, you need to construct a `BilateralGrid` instance, which can hold multiple bilateral grids + for input views. Then, use `slice` function to obtain transformed RGB output and the corresponding affine transformations. + +- For bilateral guided finishing, you need to instantiate a `BilateralGridCP4D` object and use `slice4d`. + +#### Examples + +- Bilateral grid for approximating ISP: + + Open In Colab + +- Low-rank 4D bilateral grid for MR enhancement: + + Open In Colab + + +Below is the API reference. + +""" + +import tensorly as tl +import torch +import torch.nn.functional as F +from torch import nn + +tl.set_backend("pytorch") + + +def color_correct( + img: torch.Tensor, ref: torch.Tensor, num_iters: int = 5, eps: float = 0.5 / 255 +) -> torch.Tensor: + """ + Warp `img` to match the colors in `ref_img` using iterative color matching. + + This function performs color correction by warping the colors of the input image + to match those of a reference image. It uses a least squares method to find a + transformation that maps the input image's colors to the reference image's colors. + + The algorithm iteratively solves a system of linear equations, updating the set of + unsaturated pixels in each iteration. This approach helps handle non-linear color + transformations and reduces the impact of clipping. + + Args: + img (torch.Tensor): Input image to be color corrected. Shape: [..., num_channels] + ref (torch.Tensor): Reference image to match colors. Shape: [..., num_channels] + num_iters (int, optional): Number of iterations for the color matching process. + Default is 5. + eps (float, optional): Small value to determine the range of unclipped pixels. + Default is 0.5 / 255. + + Returns: + torch.Tensor: Color corrected image with the same shape as the input image. + + Note: + - Both input and reference images should be in the range [0, 1]. + - The function works with any number of channels, but typically used with 3 (RGB). + """ + if img.shape[-1] != ref.shape[-1]: + raise ValueError( + f"img's {img.shape[-1]} and ref's {ref.shape[-1]} channels must match" + ) + num_channels = img.shape[-1] + img_mat = img.reshape([-1, num_channels]) + ref_mat = ref.reshape([-1, num_channels]) + + def is_unclipped(z): + return (z >= eps) & (z <= 1 - eps) # z \in [eps, 1-eps]. + + mask0 = is_unclipped(img_mat) + # Because the set of saturated pixels may change after solving for a + # transformation, we repeatedly solve a system `num_iters` times and update + # our estimate of which pixels are saturated. + for _ in range(num_iters): + # Construct the left hand side of a linear system that contains a quadratic + # expansion of each pixel of `img`. + a_mat = [] + for c in range(num_channels): + a_mat.append(img_mat[:, c : (c + 1)] * img_mat[:, c:]) # Quadratic term. 
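+        # img_mat[:, c:(c + 1)] * img_mat[:, c:] gives the products of channel c
+        # with channels c..C-1, i.e. the upper-triangular quadratic terms of the
+        # per-pixel color vector.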
+ a_mat.append(img_mat) # Linear term. + a_mat.append(torch.ones_like(img_mat[:, :1])) # Bias term. + a_mat = torch.cat(a_mat, dim=-1) + warp = [] + for c in range(num_channels): + # Construct the right hand side of a linear system containing each color + # of `ref`. + b = ref_mat[:, c] + # Ignore rows of the linear system that were saturated in the input or are + # saturated in the current corrected color estimate. + mask = mask0[:, c] & is_unclipped(img_mat[:, c]) & is_unclipped(b) + ma_mat = torch.where(mask[:, None], a_mat, torch.zeros_like(a_mat)) + mb = torch.where(mask, b, torch.zeros_like(b)) + w = torch.linalg.lstsq(ma_mat, mb, rcond=-1)[0] + assert torch.all(torch.isfinite(w)) + warp.append(w) + warp = torch.stack(warp, dim=-1) + # Apply the warp to update img_mat. + img_mat = torch.clip(torch.matmul(a_mat, warp), 0, 1) + corrected_img = torch.reshape(img_mat, img.shape) + return corrected_img + + +def bilateral_grid_tv_loss(model, config): + """Computes total variations of bilateral grids.""" + total_loss = 0.0 + + for bil_grids in model.bil_grids: + total_loss += config.bilgrid_tv_loss_mult * total_variation_loss( + bil_grids.grids + ) + + return total_loss + + +def color_affine_transform(affine_mats, rgb): + """Applies color affine transformations. + + Args: + affine_mats (torch.Tensor): Affine transformation matrices. Supported shape: $(..., 3, 4)$. + rgb (torch.Tensor): Input RGB values. Supported shape: $(..., 3)$. + + Returns: + Output transformed colors of shape $(..., 3)$. + """ + return ( + torch.matmul(affine_mats[..., :3], rgb.unsqueeze(-1)).squeeze(-1) + + affine_mats[..., 3] + ) + + +def _num_tensor_elems(t): + return max(torch.prod(torch.tensor(t.size()[1:]).float()).item(), 1.0) + + +def total_variation_loss(x): # noqa: F811 + """Returns total variation on multi-dimensional tensors. + + Args: + x (torch.Tensor): The input tensor with shape $(B, C, ...)$, where $B$ is the batch size and $C$ is the channel size. + """ + batch_size = x.shape[0] + tv = 0 + for i in range(2, len(x.shape)): + n_res = x.shape[i] + idx1 = torch.arange(1, n_res, device=x.device) + idx2 = torch.arange(0, n_res - 1, device=x.device) + x1 = x.index_select(i, idx1) + x2 = x.index_select(i, idx2) + count = _num_tensor_elems(x1) + tv += torch.pow((x1 - x2), 2).sum() / count + return tv / batch_size + + +def slice(bil_grids, xy, rgb, grid_idx): + """Slices a batch of 3D bilateral grids by pixel coordinates `xy` and gray-scale guidances of pixel colors `rgb`. + + Supports 2-D, 3-D, and 4-D input shapes. The first dimension of the input is the batch size + and the last dimension is 2 for `xy`, 3 for `rgb`, and 1 for `grid_idx`. + + The return value is a dictionary containing the affine transformations `affine_mats` sliced from bilateral grids and + the output color `rgb_out` after applying the afffine transformations. + + In the 2-D input case, `xy` is a $(N, 2)$ tensor, `rgb` is a $(N, 3)$ tensor, and `grid_idx` is a $(N, 1)$ tensor. + Then `affine_mats[i]` can be obtained via slicing the bilateral grid indexed at `grid_idx[i]` by `xy[i, :]` and `rgb2gray(rgb[i, :])`. + For 3-D and 4-D input cases, the behavior of indexing bilateral grids and coordinates is the same with the 2-D case. + + .. note:: + This function can be regarded as a wrapper of `color_affine_transform` and `BilateralGrid` with a slight performance improvement. + When `grid_idx` contains a unique index, only a single bilateral grid will used during the slicing. 
In this case, this function will not + perform tensor indexing to avoid data copy and extra memory + (see [this](https://discuss.pytorch.org/t/does-indexing-a-tensor-return-a-copy-of-it/164905)). + + Args: + bil_grids (`BilateralGrid`): An instance of $N$ bilateral grids. + xy (torch.Tensor): The x-y coordinates of shape $(..., 2)$ in the range of $[0,1]$. + rgb (torch.Tensor): The RGB values of shape $(..., 3)$ for computing the guidance coordinates, ranging in $[0,1]$. + grid_idx (torch.Tensor): The indices of bilateral grids for each slicing. Shape: $(..., 1)$. + + Returns: + A dictionary with keys and values as follows: + ``` + { + "rgb": Transformed RGB colors. Shape: (..., 3), + "rgb_affine_mats": The sliced affine transformation matrices from bilateral grids. Shape: (..., 3, 4) + } + ``` + """ + + sh_ = rgb.shape + + grid_idx_unique = torch.unique(grid_idx) + if len(grid_idx_unique) == 1: + # All pixels are from a single view. + grid_idx = grid_idx_unique # (1,) + xy = xy.unsqueeze(0) # (1, ..., 2) + rgb = rgb.unsqueeze(0) # (1, ..., 3) + else: + # Pixels are randomly sampled from different views. + if len(grid_idx.shape) == 4: + grid_idx = grid_idx[:, 0, 0, 0] # (chunk_size,) + elif len(grid_idx.shape) == 3: + grid_idx = grid_idx[:, 0, 0] # (chunk_size,) + elif len(grid_idx.shape) == 2: + grid_idx = grid_idx[:, 0] # (chunk_size,) + else: + raise ValueError( + "The input to bilateral grid slicing is not supported yet." + ) + + affine_mats = bil_grids(xy, rgb, grid_idx) + rgb = color_affine_transform(affine_mats, rgb) + + return { + "rgb": rgb.reshape(*sh_), + "rgb_affine_mats": affine_mats.reshape( + *sh_[:-1], affine_mats.shape[-2], affine_mats.shape[-1] + ), + } + + +class BilateralGrid(nn.Module): + """Class for 3D bilateral grids. + + Holds one or more than one bilateral grids. + """ + + def __init__(self, num, grid_X=16, grid_Y=16, grid_W=8): + """ + Args: + num (int): The number of bilateral grids (i.e., the number of views). + grid_X (int): Defines grid width $W$. + grid_Y (int): Defines grid height $H$. + grid_W (int): Defines grid guidance dimension $L$. + """ + super(BilateralGrid, self).__init__() + + self.grid_width = grid_X + """Grid width. Type: int.""" + self.grid_height = grid_Y + """Grid height. Type: int.""" + self.grid_guidance = grid_W + """Grid guidance dimension. Type: int.""" + + # Initialize grids. + grid = self._init_identity_grid() + self.grids = nn.Parameter(grid.tile(num, 1, 1, 1, 1)) # (N, 12, L, H, W) + """ A 5-D tensor of shape $(N, 12, L, H, W)$.""" + + # Weights of BT601 RGB-to-gray. + self.register_buffer("rgb2gray_weight", torch.Tensor([[0.299, 0.587, 0.114]])) + self.rgb2gray = lambda rgb: (rgb @ self.rgb2gray_weight.T) * 2.0 - 1.0 + """ A function that converts RGB to gray-scale guidance in $[-1, 1]$.""" + + def _init_identity_grid(self): + grid = torch.tensor( + [ + 1.0, + 0, + 0, + 0, + 0, + 1.0, + 0, + 0, + 0, + 0, + 1.0, + 0, + ] + ).float() + grid = grid.repeat( + [self.grid_guidance * self.grid_height * self.grid_width, 1] + ) # (L * H * W, 12) + grid = grid.reshape( + 1, self.grid_guidance, self.grid_height, self.grid_width, -1 + ) # (1, L, H, W, 12) + grid = grid.permute(0, 4, 1, 2, 3) # (1, 12, L, H, W) + return grid + + def tv_loss(self): + """Computes and returns total variation loss on the bilateral grids.""" + return total_variation_loss(self.grids) + + def forward(self, grid_xy, rgb, idx=None): + """Bilateral grid slicing. Supports 2-D, 3-D, 4-D, and 5-D input. + For the 2-D, 3-D, and 4-D cases, please refer to `slice`. 
+ For the 5-D cases, `idx` will be unused and the first dimension of `xy` should be + equal to the number of bilateral grids. Then this function becomes PyTorch's + [`F.grid_sample`](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html). + + Args: + grid_xy (torch.Tensor): The x-y coordinates in the range of $[0,1]$. + rgb (torch.Tensor): The RGB values in the range of $[0,1]$. + idx (torch.Tensor): The bilateral grid indices. + + Returns: + Sliced affine matrices of shape $(..., 3, 4)$. + """ + + grids = self.grids + input_ndims = len(grid_xy.shape) + assert len(rgb.shape) == input_ndims + + if input_ndims > 1 and input_ndims < 5: + # Convert input into 5D + for i in range(5 - input_ndims): + grid_xy = grid_xy.unsqueeze(1) + rgb = rgb.unsqueeze(1) + assert idx is not None + elif input_ndims != 5: + raise ValueError( + "Bilateral grid slicing only takes either 2D, 3D, 4D and 5D inputs" + ) + + grids = self.grids + if idx is not None: + grids = grids[idx] + assert grids.shape[0] == grid_xy.shape[0] + + # Generate slicing coordinates. + grid_xy = (grid_xy - 0.5) * 2 # Rescale to [-1, 1]. + grid_z = self.rgb2gray(rgb) + + # print(grid_xy.shape, grid_z.shape) + # exit() + grid_xyz = torch.cat([grid_xy, grid_z], dim=-1) # (N, m, h, w, 3) + + affine_mats = F.grid_sample( + grids, grid_xyz, mode="bilinear", align_corners=True, padding_mode="border" + ) # (N, 12, m, h, w) + affine_mats = affine_mats.permute(0, 2, 3, 4, 1) # (N, m, h, w, 12) + affine_mats = affine_mats.reshape( + *affine_mats.shape[:-1], 3, 4 + ) # (N, m, h, w, 3, 4) + + for _ in range(5 - input_ndims): + affine_mats = affine_mats.squeeze(1) + + return affine_mats + + +def slice4d(bil_grid4d, xyz, rgb): + """Slices a 4D bilateral grid by point coordinates `xyz` and gray-scale guidances of radiance colors `rgb`. + + Args: + bil_grid4d (`BilateralGridCP4D`): The input 4D bilateral grid. + xyz (torch.Tensor): The xyz coordinates with shape $(..., 3)$. + rgb (torch.Tensor): The RGB values with shape $(..., 3)$. + + Returns: + A dictionary with keys and values as follows: + ``` + { + "rgb": Transformed radiance RGB colors. Shape: (..., 3), + "rgb_affine_mats": The sliced affine transformation matrices from the 4D bilateral grid. Shape: (..., 3, 4) + } + ``` + """ + + affine_mats = bil_grid4d(xyz, rgb) + rgb = color_affine_transform(affine_mats, rgb) + + return {"rgb": rgb, "rgb_affine_mats": affine_mats} + + +class _ScaledTanh(nn.Module): + def __init__(self, s=2.0): + super().__init__() + self.scaler = s + + def forward(self, x): + return torch.tanh(self.scaler * x) + + +class BilateralGridCP4D(nn.Module): + """Class for low-rank 4D bilateral grids.""" + + def __init__( + self, + grid_X=16, + grid_Y=16, + grid_Z=16, + grid_W=8, + rank=5, + learn_gray=True, + gray_mlp_width=8, + gray_mlp_depth=2, + init_noise_scale=1e-6, + bound=2.0, + ): + """ + Args: + grid_X (int): Defines grid width. + grid_Y (int): Defines grid height. + grid_Z (int): Defines grid depth. + grid_W (int): Defines grid guidance dimension. + rank (int): Rank of the 4D bilateral grid. + learn_gray (bool): If True, an MLP will be learned to convert RGB colors to gray-scale guidances. + gray_mlp_width (int): The MLP width for learnable guidance. + gray_mlp_depth (int): The number of MLP layers for learnable guidance. + init_noise_scale (float): The noise scale of the initialized factors. + bound (float): The bound of the xyz coordinates. + """ + super(BilateralGridCP4D, self).__init__() + + self.grid_X = grid_X + """Grid width. 
Type: int.""" + self.grid_Y = grid_Y + """Grid height. Type: int.""" + self.grid_Z = grid_Z + """Grid depth. Type: int.""" + self.grid_W = grid_W + """Grid guidance dimension. Type: int.""" + self.rank = rank + """Rank of the 4D bilateral grid. Type: int.""" + self.learn_gray = learn_gray + """Flags of learnable guidance is used. Type: bool.""" + self.gray_mlp_width = gray_mlp_width + """The MLP width for learnable guidance. Type: int.""" + self.gray_mlp_depth = gray_mlp_depth + """The MLP depth for learnable guidance. Type: int.""" + self.init_noise_scale = init_noise_scale + """The noise scale of the initialized factors. Type: float.""" + self.bound = bound + """The bound of the xyz coordinates. Type: float.""" + + self._init_cp_factors_parafac() + + self.rgb2gray = None + """ A function that converts RGB to gray-scale guidances in $[-1, 1]$. + If `learn_gray` is True, this will be an MLP network.""" + + if self.learn_gray: + + def rgb2gray_mlp_linear(layer): + return nn.Linear( + self.gray_mlp_width, + self.gray_mlp_width if layer < self.gray_mlp_depth - 1 else 1, + ) + + def rgb2gray_mlp_actfn(_): + return nn.ReLU(inplace=True) + + self.rgb2gray = nn.Sequential( + *( + [nn.Linear(3, self.gray_mlp_width)] + + [ + nn_module(layer) + for layer in range(1, self.gray_mlp_depth) + for nn_module in [rgb2gray_mlp_actfn, rgb2gray_mlp_linear] + ] + + [_ScaledTanh(2.0)] + ) + ) + else: + # Weights of BT601/BT470 RGB-to-gray. + self.register_buffer( + "rgb2gray_weight", torch.Tensor([[0.299, 0.587, 0.114]]) + ) + self.rgb2gray = lambda rgb: (rgb @ self.rgb2gray_weight.T) * 2.0 - 1.0 + + def _init_identity_grid(self): + grid = torch.tensor( + [ + 1.0, + 0, + 0, + 0, + 0, + 1.0, + 0, + 0, + 0, + 0, + 1.0, + 0, + ] + ).float() + grid = grid.repeat([self.grid_W * self.grid_Z * self.grid_Y * self.grid_X, 1]) + grid = grid.reshape(self.grid_W, self.grid_Z, self.grid_Y, self.grid_X, -1) + grid = grid.permute(4, 0, 1, 2, 3) # (12, grid_W, grid_Z, grid_Y, grid_X) + return grid + + def _init_cp_factors_parafac(self): + # Initialize identity grids. + init_grids = self._init_identity_grid() + # Random noises are added to avoid singularity. + init_grids = torch.randn_like(init_grids) * self.init_noise_scale + init_grids + from tensorly.decomposition import parafac + + # Initialize grid CP factors + _, facs = parafac(init_grids.clone().detach(), rank=self.rank) + + self.num_facs = len(facs) + + self.fac_0 = nn.Linear(facs[0].shape[0], facs[0].shape[1], bias=False) + self.fac_0.weight = nn.Parameter(facs[0]) # (12, rank) + + for i in range(1, self.num_facs): + fac = facs[i].T # (rank, grid_size) + fac = fac.view(1, fac.shape[0], fac.shape[1], 1) # (1, rank, grid_size, 1) + self.register_buffer(f"fac_{i}_init", fac) + + fac_resid = torch.zeros_like(fac) + self.register_parameter(f"fac_{i}", nn.Parameter(fac_resid)) + + def tv_loss(self): + """Computes and returns total variation loss on the factors of the low-rank 4D bilateral grids.""" + + total_loss = 0 + for i in range(1, self.num_facs): + fac = self.get_parameter(f"fac_{i}") + total_loss += total_variation_loss(fac) + + return total_loss + + def forward(self, xyz, rgb): + """Low-rank 4D bilateral grid slicing. + + Args: + xyz (torch.Tensor): The xyz coordinates with shape $(..., 3)$. + rgb (torch.Tensor): The corresponding RGB values with shape $(..., 3)$. + + Returns: + Sliced affine matrices with shape $(..., 3, 4)$. 
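+
+        Note: `xyz` is divided by `self.bound` before slicing, so the input
+        coordinates are expected to lie roughly within `[-bound, bound]`.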
+ """ + sh_ = xyz.shape + xyz = xyz.reshape(-1, 3) # flatten (N, 3) + rgb = rgb.reshape(-1, 3) # flatten (N, 3) + + xyz = xyz / self.bound + assert self.rgb2gray is not None + gray = self.rgb2gray(rgb) + xyzw = torch.cat([xyz, gray], dim=-1) # (N, 4) + xyzw = xyzw.transpose(0, 1) # (4, N) + coords = torch.stack([torch.zeros_like(xyzw), xyzw], dim=-1) # (4, N, 2) + coords = coords.unsqueeze(1) # (4, 1, N, 2) + + coef = 1.0 + for i in range(1, self.num_facs): + fac = self.get_parameter(f"fac_{i}") + self.get_buffer(f"fac_{i}_init") + coef = coef * F.grid_sample( + fac, coords[[i - 1]], align_corners=True, padding_mode="border" + ) # [1, rank, 1, N] + coef = coef.squeeze([0, 2]).transpose(0, 1) # (N, rank) #type: ignore + mat = self.fac_0(coef) + return mat.reshape(*sh_[:-1], 3, 4) diff --git a/pruning_utils/nerf.py b/pruning_utils/nerf.py new file mode 100644 index 0000000..07e4693 --- /dev/null +++ b/pruning_utils/nerf.py @@ -0,0 +1,229 @@ +import os +import json +import torch +from typing import Any, Dict, Optional +import imageio.v2 as imageio +import numpy as np +import cv2 + +from .normalize import ( + align_principle_axes, + similarity_from_cameras, + transform_cameras, +) + + +def convert_opengl_to_colmap(R): + M_flip = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) + return R @ M_flip + + +class NerfParser: + """NeRF Studio parser (transforms.json).""" + def __init__(self, data_dir: str, factor: int = 1, normalize: bool = False, test_every: int = 8): + self.data_dir = data_dir + self.factor = factor + self.normalize = normalize + self.test_every = test_every + + # Load transforms.json + transform_path = os.path.join(data_dir, "transforms.json") + if not os.path.exists(transform_path): + raise FileNotFoundError(f"Could not find transforms.json at {transform_path}") + with open(transform_path, "r") as f: + meta = json.load(f) + + # Store camera poses and intrinsics + camtoworlds = [] + image_paths = [] + depth_paths = [] + + camera_ids = [] + image_names = [] + + fl_x = meta["fl_x"] + fl_y = meta["fl_y"] + cx = meta["cx"] + cy = meta["cy"] + w = meta["w"] + h = meta["h"] + fx, fy = fl_x / factor, fl_y / factor + cx, cy = cx / factor, cy / factor + width, height = w // factor, h // factor + + K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) + cam_id = 0 # Single intrinsics shared across all images + + for i, frame in enumerate(meta["frames"]): + c2w = np.array(frame["transform_matrix"], dtype=np.float32) + camtoworlds.append(c2w) + + + + rel_image_path = frame["file_path"] + if not rel_image_path.lower().endswith((".png", ".jpg")): + rel_image_path += ".png" # default fallback + + image_path = os.path.join(data_dir, rel_image_path) + if not os.path.exists(image_path): + raise FileNotFoundError(f"Image {image_path} not found.") + image_paths.append(image_path) + image_names.append(os.path.basename(rel_image_path)) + camera_ids.append(cam_id) + + + + rel_depth_path = frame.get("depth_file_path") + if rel_depth_path is None: + raise ValueError(f"No depth_file_path found for frame {i}") + if not rel_depth_path.lower().endswith((".png", ".jpg", ".exr")): + rel_depth_path += ".png" + depth_path = os.path.join(data_dir, rel_depth_path) + if not os.path.exists(depth_path): + raise FileNotFoundError(f"Depth image not found: {depth_path}") + depth_paths.append(depth_path) + + camtoworlds = np.stack(camtoworlds, axis=0) + image_paths = np.array(image_paths) + image_names = np.array(image_names) + depth_paths = np.array(depth_paths) + + + + + # Normalize world space if requested 
+ if normalize: + T1 = similarity_from_cameras(camtoworlds) + camtoworlds = transform_cameras(T1, camtoworlds) + + T2 = align_principle_axes(camtoworlds[:, :3, 3]) + camtoworlds = transform_cameras(T2, camtoworlds) + + transform = T2 @ T1 + else: + transform = np.eye(4) + + self.image_names = list(image_names) + self.image_paths = list(image_paths) + self.depth_paths = list(depth_paths) + self.camtoworlds = camtoworlds + self.camera_ids = camera_ids + self.Ks_dict = {cam_id: K} + self.params_dict = {cam_id: np.array([], dtype=np.float32)} # no distortion + self.imsize_dict = {cam_id: (width, height)} + self.mask_dict = {cam_id: None} + self.points = np.zeros((0, 3), dtype=np.float32) + self.points_err = np.zeros((0,), dtype=np.float32) + self.points_rgb = np.zeros((0, 3), dtype=np.uint8) + self.point_indices = {} # Not available in NeRF Studio format + self.transform = transform + + # Resize intrinsics if actual image differs from JSON + actual_image = imageio.imread(self.image_paths[0])[..., :3] + actual_height, actual_width = actual_image.shape[:2] + s_height, s_width = actual_height / height, actual_width / width + for camera_id, K in self.Ks_dict.items(): + K[0, :] *= s_width + K[1, :] *= s_height + self.Ks_dict[camera_id] = K + self.imsize_dict[camera_id] = (int(width * s_width), int(height * s_height)) + + # NeRF Studio has no distortion, so no undistortion mapping needed + self.mapx_dict = {} + self.mapy_dict = {} + self.roi_undist_dict = {} + + # Estimate scene scale from camera positions + camera_locations = self.camtoworlds[:, :3, 3] + scene_center = np.mean(camera_locations, axis=0) + dists = np.linalg.norm(camera_locations - scene_center, axis=1) + self.scene_scale = float(np.max(dists)) + + + print(f"[Parser] Loaded {len(self.image_paths)} NeRF Studio frames.") + + self.camtoworlds[:, :3, :3] = convert_opengl_to_colmap(self.camtoworlds[:, :3, :3]) + + +class NerfDataset: + """A simple dataset class.""" + + def __init__( + self, + parser: NerfParser, + split: str = "train", + patch_size: Optional[int] = None, + load_depths: bool = False, + ): + self.parser = parser + self.split = split + self.patch_size = patch_size + self.load_depths = load_depths + indices = np.arange(len(self.parser.image_names)) + if split == "train": + self.indices = indices[indices % self.parser.test_every != 0] + else: + self.indices = indices[indices % self.parser.test_every == 0] + + + def __len__(self): + return len(self.indices) + + def __getitem__(self, item: int) -> Dict[str, Any]: + index = self.indices[item] + image = imageio.imread(self.parser.image_paths[index])[..., :3] + camera_id = self.parser.camera_ids[index] + K = self.parser.Ks_dict[camera_id].copy() # undistorted K + params = self.parser.params_dict[camera_id] + camtoworlds = self.parser.camtoworlds[index] + mask = self.parser.mask_dict[camera_id] + + if len(params) > 0: + # Images are distorted. Undistort them. + mapx, mapy = ( + self.parser.mapx_dict[camera_id], + self.parser.mapy_dict[camera_id], + ) + image = cv2.remap(image, mapx, mapy, cv2.INTER_LINEAR) + x, y, w, h = self.parser.roi_undist_dict[camera_id] + image = image[y : y + h, x : x + w] + + if self.patch_size is not None: + # Random crop. 
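+            # The principal point in K is shifted by the crop offset below so the
+            # intrinsics stay consistent with the cropped image; the same (x, y)
+            # offsets are reused further down to crop the depth map when
+            # load_depths is enabled.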
+ h, w = image.shape[:2] + x = np.random.randint(0, max(w - self.patch_size, 1)) + y = np.random.randint(0, max(h - self.patch_size, 1)) + image = image[y : y + self.patch_size, x : x + self.patch_size] + K[0, 2] -= x + K[1, 2] -= y + + data = { + "K": torch.from_numpy(K).float(), + "camtoworld": torch.from_numpy(camtoworlds).float(), + "image": torch.from_numpy(image).float(), + "image_id": item, # the index of the image in the dataset + } + if mask is not None: + data["mask"] = torch.from_numpy(mask).bool() + + if self.load_depths: + # Get the full depth path from the parser + depth_filename = self.parser.depth_paths[index] + + if not os.path.exists(depth_filename): + raise FileNotFoundError(f"Depth file not found: {depth_filename}") + + # Load depth + depth = imageio.imread(depth_filename).astype(np.float32) + + # Apply patching if needed + if self.patch_size is not None: + depth = depth[y : y + self.patch_size, x : x + self.patch_size] + + if depth.shape[:2] != image.shape[:2]: + raise ValueError(f"Depth shape {depth.shape} doesn't match image shape {image.shape}") + + data["depth"] = torch.from_numpy(depth).float() + + return data + diff --git a/pruning_utils/normalize.py b/pruning_utils/normalize.py new file mode 100644 index 0000000..e0c4b20 --- /dev/null +++ b/pruning_utils/normalize.py @@ -0,0 +1,154 @@ +""" +=========================================================================== +Unmodified code from gsplat examples +Original source: +https://github.com/nerfstudio-project/gsplat/blob/main/examples/datasets/normalize.py +License: Apache License 2.0 +---------------------------------------------------------------------------- +This file was copied verbatim from the gsplat repository. No modifications were made. +=========================================================================== +""" + +import numpy as np + + +def similarity_from_cameras(c2w, strict_scaling=False, center_method="focus"): + """ + reference: nerf-factory + Get a similarity transform to normalize dataset + from c2w (OpenCV convention) cameras + :param c2w: (N, 4) + :return T (4,4) , scale (float) + """ + t = c2w[:, :3, 3] + R = c2w[:, :3, :3] + + # (1) Rotate the world so that z+ is the up axis + # we estimate the up axis by averaging the camera up axes + ups = np.sum(R * np.array([0, -1.0, 0]), axis=-1) + world_up = np.mean(ups, axis=0) + world_up /= np.linalg.norm(world_up) + + up_camspace = np.array([0.0, -1.0, 0.0]) + c = (up_camspace * world_up).sum() + cross = np.cross(world_up, up_camspace) + skew = np.array( + [ + [0.0, -cross[2], cross[1]], + [cross[2], 0.0, -cross[0]], + [-cross[1], cross[0], 0.0], + ] + ) + if c > -1: + R_align = np.eye(3) + skew + (skew @ skew) * 1 / (1 + c) + else: + # In the unlikely case the original data has y+ up axis, + # rotate 180-deg about x axis + R_align = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) + + # R_align = np.eye(3) # DEBUG + R = R_align @ R + fwds = np.sum(R * np.array([0, 0.0, 1.0]), axis=-1) + t = (R_align @ t[..., None])[..., 0] + + # (2) Recenter the scene. 
+ if center_method == "focus": + # find the closest point to the origin for each camera's center ray + nearest = t + (fwds * -t).sum(-1)[:, None] * fwds + translate = -np.median(nearest, axis=0) + elif center_method == "poses": + # use center of the camera positions + translate = -np.median(t, axis=0) + else: + raise ValueError(f"Unknown center_method {center_method}") + + transform = np.eye(4) + transform[:3, 3] = translate + transform[:3, :3] = R_align + + # (3) Rescale the scene using camera distances + scale_fn = np.max if strict_scaling else np.median + scale = 1.0 / scale_fn(np.linalg.norm(t + translate, axis=-1)) + transform[:3, :] *= scale + + return transform + + +def align_principle_axes(point_cloud): + # Compute centroid + centroid = np.median(point_cloud, axis=0) + + # Translate point cloud to centroid + translated_point_cloud = point_cloud - centroid + + # Compute covariance matrix + covariance_matrix = np.cov(translated_point_cloud, rowvar=False) + + # Compute eigenvectors and eigenvalues + eigenvalues, eigenvectors = np.linalg.eigh(covariance_matrix) + + # Sort eigenvectors by eigenvalues (descending order) so that the z-axis + # is the principal axis with the smallest eigenvalue. + sort_indices = eigenvalues.argsort()[::-1] + eigenvectors = eigenvectors[:, sort_indices] + + # Check orientation of eigenvectors. If the determinant of the eigenvectors is + # negative, then we need to flip the sign of one of the eigenvectors. + if np.linalg.det(eigenvectors) < 0: + eigenvectors[:, 0] *= -1 + + # Create rotation matrix + rotation_matrix = eigenvectors.T + + # Create SE(3) matrix (4x4 transformation matrix) + transform = np.eye(4) + transform[:3, :3] = rotation_matrix + transform[:3, 3] = -rotation_matrix @ centroid + + return transform + + +def transform_points(matrix, points): + """Transform points using an SE(3) matrix. + + Args: + matrix: 4x4 SE(3) matrix + points: Nx3 array of points + + Returns: + Nx3 array of transformed points + """ + assert matrix.shape == (4, 4) + assert len(points.shape) == 2 and points.shape[1] == 3 + return points @ matrix[:3, :3].T + matrix[:3, 3] + + +def transform_cameras(matrix, camtoworlds): + """Transform cameras using an SE(3) matrix. 
+ + Args: + matrix: 4x4 SE(3) matrix + camtoworlds: Nx4x4 array of camera-to-world matrices + + Returns: + Nx4x4 array of transformed camera-to-world matrices + """ + assert matrix.shape == (4, 4) + assert len(camtoworlds.shape) == 3 and camtoworlds.shape[1:] == (4, 4) + camtoworlds = np.einsum("nij, ki -> nkj", camtoworlds, matrix) + scaling = np.linalg.norm(camtoworlds[:, 0, :3], axis=1) + camtoworlds[:, :3, :3] = camtoworlds[:, :3, :3] / scaling[:, None, None] + return camtoworlds + + +def normalize(camtoworlds, points=None): + T1 = similarity_from_cameras(camtoworlds) + camtoworlds = transform_cameras(T1, camtoworlds) + if points is not None: + points = transform_points(T1, points) + T2 = align_principle_axes(points) + camtoworlds = transform_cameras(T2, camtoworlds) + points = transform_points(T2, points) + return camtoworlds, points, T2 @ T1 + else: + return camtoworlds, T1 diff --git a/pruning_utils/open_ply_pipeline.py b/pruning_utils/open_ply_pipeline.py new file mode 100644 index 0000000..a6bd7d8 --- /dev/null +++ b/pruning_utils/open_ply_pipeline.py @@ -0,0 +1,232 @@ +import torch +from pyntcloud import PyntCloud +from typing import Dict, Union, Literal, Optional +import os +import pandas as pd + +def load_splats( + path: str, + device: Union[str, torch.device] = 'cpu' +) -> Dict[str, torch.nn.Parameter]: + """ + Load splat data from a .ply or .ckpt file. + + Args: + path (str): Path to the file (.ply or .ckpt). + device (str): Device to load tensors on. Default is 'cpu'. + + Returns: + splats (dict): Dictionary of torch.nn.Parameter containing splat data. + """ + ext = os.path.splitext(path)[1].lower() + + if ext == '.ply': + print("loading ply file") + cloud = PyntCloud.from_file(path) + data = cloud.points + + # Mapping from column names in the .ply to model parameter names + mapping = { + "_model.gauss_params.means": ["x", "y", "z"], + "_model.gauss_params.opacities": "opacity", + "_model.gauss_params.quats": ["rot_0", 'rot_1', 'rot_2', 'rot_3'], + "_model.gauss_params.scales": ["scale_0", "scale_1", "scale_2"], + "_model.gauss_params.features_dc": [f"f_dc_{i}" for i in range(3)], + "_model.gauss_params.features_rest": [f"f_rest_{i}" for i in range(45)] + } + + ckpt = {"pipeline": {}} + + for key, value in mapping.items(): + if isinstance(value, list): + ckpt["pipeline"][key] = torch.tensor(data[value].values, dtype=torch.float32) + else: + ckpt["pipeline"][key] = torch.tensor(data[value].values, dtype=torch.float32) + + # Reshape tensors + ckpt["pipeline"]["_model.gauss_params.means"] = ckpt["pipeline"]["_model.gauss_params.means"].reshape(-1, 3) + ckpt["pipeline"]["_model.gauss_params.opacities"] = ckpt["pipeline"]["_model.gauss_params.opacities"].reshape(-1) + ckpt["pipeline"]["_model.gauss_params.scales"] = ckpt["pipeline"]["_model.gauss_params.scales"].reshape(-1, 3) + ckpt["pipeline"]["_model.gauss_params.features_dc"] = ckpt["pipeline"]["_model.gauss_params.features_dc"].reshape(-1, 1, 3) + ckpt["pipeline"]["_model.gauss_params.features_rest"] = ckpt["pipeline"]["_model.gauss_params.features_rest"].reshape(-1, 15, 3) + + + + + # Mapping for renaming + rename_map = { + "_model.gauss_params.means": "means", + "_model.gauss_params.opacities": "opacities", + "_model.gauss_params.quats": "quats", + "_model.gauss_params.scales": "scales", + "_model.gauss_params.features_dc": "sh0", + "_model.gauss_params.features_rest": "shN" + } + + step = 0 # ply do not contain step information + + # Convert to splats dictionary + splats = { + new_key: 
torch.nn.Parameter(ckpt["pipeline"][old_key] if new_key == "sh0" + else ckpt["pipeline"][old_key].squeeze() if new_key == "opacities" + else ckpt["pipeline"][old_key]) + for old_key, new_key in rename_map.items() + } + + + + elif ext == '.ckpt': + print("loading ckpt file") + ckpt = torch.load(path, map_location=device) + + # Mapping for renaming + rename_map = { + "_model.gauss_params.means": "means", + "_model.gauss_params.opacities": "opacities", + "_model.gauss_params.quats": "quats", + "_model.gauss_params.scales": "scales", + "_model.gauss_params.features_dc": "sh0", + "_model.gauss_params.features_rest": "shN" + } + + step = ckpt["step"] + + # Convert to splats dictionary + splats = { + new_key: torch.nn.Parameter(ckpt["pipeline"][old_key].unsqueeze(1) if new_key == "sh0" + else ckpt["pipeline"][old_key].squeeze() if new_key == "opacities" + else ckpt["pipeline"][old_key]) + for old_key, new_key in rename_map.items() + } + + + elif ext == '.pt': + print("loading pt file") + + # Load and concatenate splats from checkpoints + path = [path] + ckpts = [ + torch.load(file, map_location=device, weights_only=True) for file in path + ] + + param_names = ckpts[0]["splats"].keys() + + step = ckpts["step"] + + splats = { + name: torch.nn.Parameter( + torch.cat([ckpt["splats"][name] for ckpt in ckpts], dim=0) + ) + for name in param_names + } + + + else: + raise ValueError(f"Unsupported file extension: {ext}") + + + return step, splats + + + +def save_splats( + path: str, + data: Dict[str, Union[Dict[str, torch.nn.Parameter], int]], + file_type: Literal['ply', 'pt', 'ckpt'] = 'ply', + step_value: Optional[int] = None +) -> None: + """ + Export splat data to .ply, .pt, or .ckpt format. + + Args: + path (str): Output path without file extension. + data (dict): Dictionary containing 'splats' and optionally 'step'. + file_type (str): One of 'ply', 'pt', or 'ckpt'. + step_value (int, optional): If provided with file_type='pt', a .ckpt is also saved. 
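+
+    Note:
+        The written .ckpt stores `step_value` (not `data["step"]`) as its
+        `step` field.
+
+    Example (minimal sketch; paths are placeholders):
+        >>> step, splats = load_splats("path/to/model.ckpt")
+        >>> save_splats("path/to/model_pruned", {"step": step, "splats": splats},
+        ...             file_type="ply")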
+ """ + assert file_type in ['ply', 'pt', 'ckpt'], "Unsupported file type" + + if file_type == 'ply': + path += ".ply" + splats = data["splats"] + print("Saving to .ply format") + + export_data = { + "x": splats["means"][:, 0].detach().cpu().numpy(), + "y": splats["means"][:, 1].detach().cpu().numpy(), + "z": splats["means"][:, 2].detach().cpu().numpy(), + "opacity": splats["opacities"].detach().cpu().numpy(), + "rot_0": splats["quats"][:, 0].detach().cpu().numpy(), + "rot_1": splats["quats"][:, 1].detach().cpu().numpy(), + "rot_2": splats["quats"][:, 2].detach().cpu().numpy(), + "rot_3": splats["quats"][:, 3].detach().cpu().numpy(), + "scale_0": splats["scales"][:, 0].detach().cpu().numpy(), + "scale_1": splats["scales"][:, 1].detach().cpu().numpy(), + "scale_2": splats["scales"][:, 2].detach().cpu().numpy(), + } + + for i in range(3): + export_data[f"f_dc_{i}"] = splats["sh0"][:, 0, i].detach().cpu().numpy() + for i in range(45): + c = i % 3 + j = i // 3 + export_data[f"f_rest_{i}"] = splats["shN"][:, j, c].detach().cpu().numpy() + + df = pd.DataFrame(export_data) + cloud = PyntCloud(df) + cloud.to_file(path) + + elif file_type == 'pt': + pt_path = path + ".pt" + print(f"Saving to .pt format at {pt_path}") + torch.save(data, pt_path) + + if step_value is not None: + print("Also generating .ckpt file from .pt content") + _save_ckpt_from_splats(pt_path, path + ".ckpt", step_value) + + elif file_type == 'ckpt': + print(f"Saving to .ckpt format at {path + '.ckpt'}") + _save_ckpt_from_splats(None, path + ".ckpt", step_value, direct_data=data) + + +def _save_ckpt_from_splats( + pt_path: str, + save_path: str, + step_value: int, + direct_data: Optional[Dict[str, Dict[str, torch.nn.Parameter]]] = None +) -> None: + """Internal helper to convert .pt or dict to .ckpt.""" + device = torch.device("cpu") + + if direct_data is not None: + splats = direct_data["splats"] + else: + pt_data = torch.load(pt_path, map_location=device) + splats = pt_data["splats"] + + inverse_rename_map = { + "means": "_model.gauss_params.means", + "opacities": "_model.gauss_params.opacities", + "quats": "_model.gauss_params.quats", + "scales": "_model.gauss_params.scales", + "sh0": "_model.gauss_params.features_dc", + "shN": "_model.gauss_params.features_rest" + } + + pipeline = {} + for new_key, old_key in inverse_rename_map.items(): + tensor = splats[new_key].data + if new_key == "sh0": + tensor = tensor.squeeze(1) + elif new_key == "opacities": + tensor = tensor.unsqueeze(-1) + pipeline[old_key] = tensor + + ckpt = { + "step": step_value, + "pipeline": pipeline + } + + torch.save(ckpt, save_path) + print(f"Saved .ckpt to: {save_path}") diff --git a/pruning_utils/traj.py b/pruning_utils/traj.py new file mode 100644 index 0000000..9e8a2d6 --- /dev/null +++ b/pruning_utils/traj.py @@ -0,0 +1,254 @@ +""" +Code borrowed from + +https://github.com/google-research/multinerf/blob/5b4d4f64608ec8077222c52fdf814d40acc10bc1/internal/camera_utils.py +""" + +import numpy as np +import scipy + + +def normalize(x: np.ndarray) -> np.ndarray: + """Normalization helper function.""" + return x / np.linalg.norm(x) + + +def viewmatrix(lookdir: np.ndarray, up: np.ndarray, position: np.ndarray) -> np.ndarray: + """Construct lookat view matrix.""" + vec2 = normalize(lookdir) + vec0 = normalize(np.cross(up, vec2)) + vec1 = normalize(np.cross(vec2, vec0)) + m = np.stack([vec0, vec1, vec2, position], axis=1) + return m + + +def focus_point_fn(poses: np.ndarray) -> np.ndarray: + """Calculate nearest point to all focal axes in poses.""" + directions, 
origins = poses[:, :3, 2:3], poses[:, :3, 3:4] + m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1]) + mt_m = np.transpose(m, [0, 2, 1]) @ m + focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0] + return focus_pt + + +def average_pose(poses: np.ndarray) -> np.ndarray: + """New pose using average position, z-axis, and up vector of input poses.""" + position = poses[:, :3, 3].mean(0) + z_axis = poses[:, :3, 2].mean(0) + up = poses[:, :3, 1].mean(0) + cam2world = viewmatrix(z_axis, up, position) + return cam2world + + +def generate_spiral_path( + poses, + bounds, + n_frames=120, + n_rots=2, + zrate=0.5, + spiral_scale_f=1.0, + spiral_scale_r=1.0, + focus_distance=0.75, +): + """Calculates a forward facing spiral path for rendering.""" + # Find a reasonable 'focus depth' for this dataset as a weighted average + # of conservative near and far bounds in disparity space. + near_bound = bounds.min() + far_bound = bounds.max() + # All cameras will point towards the world space point (0, 0, -focal). + focal = 1 / (((1 - focus_distance) / near_bound + focus_distance / far_bound)) + focal = focal * spiral_scale_f + + # Get radii for spiral path using 90th percentile of camera positions. + positions = poses[:, :3, 3] + radii = np.percentile(np.abs(positions), 90, 0) + radii = radii * spiral_scale_r + radii = np.concatenate([radii, [1.0]]) + + # Generate poses for spiral path. + render_poses = [] + cam2world = average_pose(poses) + up = poses[:, :3, 1].mean(0) + for theta in np.linspace(0.0, 2.0 * np.pi * n_rots, n_frames, endpoint=False): + t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0] + position = cam2world @ t + lookat = cam2world @ [0, 0, -focal, 1.0] + z_axis = position - lookat + render_poses.append(viewmatrix(z_axis, up, position)) + render_poses = np.stack(render_poses, axis=0) + return render_poses + + +def generate_ellipse_path_z( + poses: np.ndarray, + n_frames: int = 120, + # const_speed: bool = True, + variation: float = 0.0, + phase: float = 0.0, + height: float = 0.0, +) -> np.ndarray: + """Generate an elliptical render path based on the given poses.""" + # Calculate the focal point for the path (cameras point toward this). + center = focus_point_fn(poses) + # Path height sits at z=height (in middle of zero-mean capture pattern). + offset = np.array([center[0], center[1], height]) + + # Calculate scaling for ellipse axes based on input camera positions. + sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0) + # Use ellipse that is symmetric about the focal point in xy. + low = -sc + offset + high = sc + offset + # Optional height variation need not be symmetric + z_low = np.percentile((poses[:, :3, 3]), 10, axis=0) + z_high = np.percentile((poses[:, :3, 3]), 90, axis=0) + + def get_positions(theta): + # Interpolate between bounds with trig functions to get ellipse in x-y. + # Optionally also interpolate in z to change camera height along path. + return np.stack( + [ + low[0] + (high - low)[0] * (np.cos(theta) * 0.5 + 0.5), + low[1] + (high - low)[1] * (np.sin(theta) * 0.5 + 0.5), + variation + * ( + z_low[2] + + (z_high - z_low)[2] + * (np.cos(theta + 2 * np.pi * phase) * 0.5 + 0.5) + ) + + height, + ], + -1, + ) + + theta = np.linspace(0, 2.0 * np.pi, n_frames + 1, endpoint=True) + positions = get_positions(theta) + + # if const_speed: + # # Resample theta angles so that the velocity is closer to constant. 
+ # lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1) + # theta = stepfun.sample(None, theta, np.log(lengths), n_frames + 1) + # positions = get_positions(theta) + + # Throw away duplicated last position. + positions = positions[:-1] + + # Set path's up vector to axis closest to average of input pose up vectors. + avg_up = poses[:, :3, 1].mean(0) + avg_up = avg_up / np.linalg.norm(avg_up) + ind_up = np.argmax(np.abs(avg_up)) + up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up]) + + return np.stack([viewmatrix(center - p, up, p) for p in positions]) + + +def generate_ellipse_path_y( + poses: np.ndarray, + n_frames: int = 120, + # const_speed: bool = True, + variation: float = 0.0, + phase: float = 0.0, + height: float = 0.0, +) -> np.ndarray: + """Generate an elliptical render path based on the given poses.""" + # Calculate the focal point for the path (cameras point toward this). + center = focus_point_fn(poses) + # Path height sits at y=height (in middle of zero-mean capture pattern). + offset = np.array([center[0], height, center[2]]) + + # Calculate scaling for ellipse axes based on input camera positions. + sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0) + # Use ellipse that is symmetric about the focal point in xy. + low = -sc + offset + high = sc + offset + # Optional height variation need not be symmetric + y_low = np.percentile((poses[:, :3, 3]), 10, axis=0) + y_high = np.percentile((poses[:, :3, 3]), 90, axis=0) + + def get_positions(theta): + # Interpolate between bounds with trig functions to get ellipse in x-z. + # Optionally also interpolate in y to change camera height along path. + return np.stack( + [ + low[0] + (high - low)[0] * (np.cos(theta) * 0.5 + 0.5), + variation + * ( + y_low[1] + + (y_high - y_low)[1] + * (np.cos(theta + 2 * np.pi * phase) * 0.5 + 0.5) + ) + + height, + low[2] + (high - low)[2] * (np.sin(theta) * 0.5 + 0.5), + ], + -1, + ) + + theta = np.linspace(0, 2.0 * np.pi, n_frames + 1, endpoint=True) + positions = get_positions(theta) + + # if const_speed: + # # Resample theta angles so that the velocity is closer to constant. + # lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1) + # theta = stepfun.sample(None, theta, np.log(lengths), n_frames + 1) + # positions = get_positions(theta) + + # Throw away duplicated last position. + positions = positions[:-1] + + # Set path's up vector to axis closest to average of input pose up vectors. + avg_up = poses[:, :3, 1].mean(0) + avg_up = avg_up / np.linalg.norm(avg_up) + ind_up = np.argmax(np.abs(avg_up)) + up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up]) + + return np.stack([viewmatrix(p - center, up, p) for p in positions]) + + +def generate_interpolated_path( + poses: np.ndarray, + n_interp: int, + spline_degree: int = 5, + smoothness: float = 0.03, + rot_weight: float = 0.1, +): + """Creates a smooth spline path between input keyframe camera poses. + + Spline is calculated with poses in format (position, lookat-point, up-point). + + Args: + poses: (n, 3, 4) array of input pose keyframes. + n_interp: returned path will have n_interp * (n - 1) total poses. + spline_degree: polynomial degree of B-spline. + smoothness: parameter for spline smoothing, 0 forces exact interpolation. + rot_weight: relative weighting of rotation/translation in spline solve. + + Returns: + Array of new camera poses with shape (n_interp * (n - 1), 3, 4). 
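+
+    Note: the spline degree is clamped internally to (number of keyframes - 1),
+    so short keyframe sequences fall back to a lower-degree spline.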
+ """ + + def poses_to_points(poses, dist): + """Converts from pose matrices to (position, lookat, up) format.""" + pos = poses[:, :3, -1] + lookat = poses[:, :3, -1] - dist * poses[:, :3, 2] + up = poses[:, :3, -1] + dist * poses[:, :3, 1] + return np.stack([pos, lookat, up], 1) + + def points_to_poses(points): + """Converts from (position, lookat, up) format to pose matrices.""" + return np.array([viewmatrix(p - l, u - p, p) for p, l, u in points]) + + def interp(points, n, k, s): + """Runs multidimensional B-spline interpolation on the input points.""" + sh = points.shape + pts = np.reshape(points, (sh[0], -1)) + k = min(k, sh[0] - 1) + tck, _ = scipy.interpolate.splprep(pts.T, k=k, s=s) + u = np.linspace(0, 1, n, endpoint=False) + new_points = np.array(scipy.interpolate.splev(u, tck)) + new_points = np.reshape(new_points.T, (n, sh[1], sh[2])) + return new_points + + points = poses_to_points(poses, dist=rot_weight) + new_points = interp( + points, n_interp * (points.shape[0] - 1), k=spline_degree, s=smoothness + ) + return points_to_poses(new_points) diff --git a/pruning_utils/utils.py b/pruning_utils/utils.py new file mode 100644 index 0000000..152b9ee --- /dev/null +++ b/pruning_utils/utils.py @@ -0,0 +1,235 @@ +""" +=========================================================================== +Unmodified code from gsplat examples +Original source: +https://github.com/nerfstudio-project/gsplat/blob/main/examples/utils.py +License: Apache License 2.0 +---------------------------------------------------------------------------- +This file was copied verbatim from the gsplat repository. No modifications were made. +=========================================================================== +""" + +import random + +import numpy as np +import torch +from sklearn.neighbors import NearestNeighbors +from torch import Tensor +import torch.nn.functional as F +import matplotlib.pyplot as plt +from matplotlib import colormaps + + +class CameraOptModule(torch.nn.Module): + """Camera pose optimization module.""" + + def __init__(self, n: int): + super().__init__() + # Delta positions (3D) + Delta rotations (6D) + self.embeds = torch.nn.Embedding(n, 9) + # Identity rotation in 6D representation + self.register_buffer("identity", torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0])) + + def zero_init(self): + torch.nn.init.zeros_(self.embeds.weight) + + def random_init(self, std: float): + torch.nn.init.normal_(self.embeds.weight, std=std) + + def forward(self, camtoworlds: Tensor, embed_ids: Tensor) -> Tensor: + """Adjust camera pose based on deltas. 
+
+        Args:
+            camtoworlds: (..., 4, 4)
+            embed_ids: (...,)
+
+        Returns:
+            updated camtoworlds: (..., 4, 4)
+        """
+        assert camtoworlds.shape[:-2] == embed_ids.shape
+        batch_shape = camtoworlds.shape[:-2]
+        pose_deltas = self.embeds(embed_ids)  # (..., 9)
+        dx, drot = pose_deltas[..., :3], pose_deltas[..., 3:]
+        rot = rotation_6d_to_matrix(
+            drot + self.identity.expand(*batch_shape, -1)
+        )  # (..., 3, 3)
+        transform = torch.eye(4, device=pose_deltas.device).repeat((*batch_shape, 1, 1))
+        transform[..., :3, :3] = rot
+        transform[..., :3, 3] = dx
+        return torch.matmul(camtoworlds, transform)
+
+
+class AppearanceOptModule(torch.nn.Module):
+    """Appearance optimization module."""
+
+    def __init__(
+        self,
+        n: int,
+        feature_dim: int,
+        embed_dim: int = 16,
+        sh_degree: int = 3,
+        mlp_width: int = 64,
+        mlp_depth: int = 2,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.sh_degree = sh_degree
+        self.embeds = torch.nn.Embedding(n, embed_dim)
+        layers = []
+        layers.append(
+            torch.nn.Linear(embed_dim + feature_dim + (sh_degree + 1) ** 2, mlp_width)
+        )
+        layers.append(torch.nn.ReLU(inplace=True))
+        for _ in range(mlp_depth - 1):
+            layers.append(torch.nn.Linear(mlp_width, mlp_width))
+            layers.append(torch.nn.ReLU(inplace=True))
+        layers.append(torch.nn.Linear(mlp_width, 3))
+        self.color_head = torch.nn.Sequential(*layers)
+
+    def forward(
+        self, features: Tensor, embed_ids: Tensor, dirs: Tensor, sh_degree: int
+    ) -> Tensor:
+        """Adjust appearance based on embeddings.
+
+        Args:
+            features: (N, feature_dim)
+            embed_ids: (C,)
+            dirs: (C, N, 3)
+
+        Returns:
+            colors: (C, N, 3)
+        """
+        from gsplat.cuda._torch_impl import _eval_sh_bases_fast
+
+        C, N = dirs.shape[:2]
+        # Camera embeddings
+        if embed_ids is None:
+            embeds = torch.zeros(C, self.embed_dim, device=features.device)
+        else:
+            embeds = self.embeds(embed_ids)  # [C, D2]
+        embeds = embeds[:, None, :].expand(-1, N, -1)  # [C, N, D2]
+        # GS features
+        features = features[None, :, :].expand(C, -1, -1)  # [C, N, D1]
+        # View directions
+        dirs = F.normalize(dirs, dim=-1)  # [C, N, 3]
+        num_bases_to_use = (sh_degree + 1) ** 2
+        num_bases = (self.sh_degree + 1) ** 2
+        sh_bases = torch.zeros(C, N, num_bases, device=features.device)  # [C, N, K]
+        sh_bases[:, :, :num_bases_to_use] = _eval_sh_bases_fast(num_bases_to_use, dirs)
+        # Get colors
+        if self.embed_dim > 0:
+            h = torch.cat([embeds, features, sh_bases], dim=-1)  # [C, N, D1 + D2 + K]
+        else:
+            h = torch.cat([features, sh_bases], dim=-1)
+        colors = self.color_head(h)
+        return colors
+
+
+def rotation_6d_to_matrix(d6: Tensor) -> Tensor:
+    """
+    Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
+    using Gram--Schmidt orthogonalization per Section B of [1]. Adapted from pytorch3d.
+    Args:
+        d6: 6D rotation representation, of size (*, 6)
+
+    Returns:
+        batch of rotation matrices of size (*, 3, 3)
+
+    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+    On the Continuity of Rotation Representations in Neural Networks.
+    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+    Retrieved from http://arxiv.org/abs/1812.07035
+    """
+
+    a1, a2 = d6[..., :3], d6[..., 3:]
+    b1 = F.normalize(a1, dim=-1)
+    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
+    b2 = F.normalize(b2, dim=-1)
+    b3 = torch.cross(b1, b2, dim=-1)
+    return torch.stack((b1, b2, b3), dim=-2)
+
+
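A small sanity check (a sketch, not part of the copied file) for the 6D parameterisation used by `CameraOptModule`: the buffer `[1, 0, 0, 0, 1, 0]` is the identity in this representation, so zero-initialised pose deltas leave the input camera-to-world matrices unchanged.

```python
import torch

# The identity 6D vector maps to the 3x3 identity rotation.
R = rotation_6d_to_matrix(torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]))
assert torch.allclose(R, torch.eye(3))

# Hence a zero-initialised CameraOptModule is a no-op on the poses it adjusts.
module = CameraOptModule(n=2)
module.zero_init()
camtoworlds = torch.eye(4).expand(2, 4, 4)
adjusted = module(camtoworlds, torch.arange(2))
assert torch.allclose(adjusted, camtoworlds, atol=1e-6)
```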
+def knn(x: Tensor, K: int = 4) -> Tensor:
+    x_np = x.cpu().numpy()
+    model = NearestNeighbors(n_neighbors=K, metric="euclidean").fit(x_np)
+    distances, _ = model.kneighbors(x_np)
+    return torch.from_numpy(distances).to(x)
+
+
+def rgb_to_sh(rgb: Tensor) -> Tensor:
+    C0 = 0.28209479177387814
+    return (rgb - 0.5) / C0
+
+
+def set_random_seed(seed: int):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+
+
+# ref: https://github.com/hbb1/2d-gaussian-splatting/blob/main/utils/general_utils.py#L163
+def colormap(img, cmap="jet"):
+    W, H = img.shape[:2]
+    dpi = 300
+    fig, ax = plt.subplots(1, figsize=(H / dpi, W / dpi), dpi=dpi)
+    im = ax.imshow(img, cmap=cmap)
+    ax.set_axis_off()
+    fig.colorbar(im, ax=ax)
+    fig.tight_layout()
+    fig.canvas.draw()
+    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+    img = torch.from_numpy(data).float().permute(2, 0, 1)
+    plt.close()
+    return img
+
+
+def apply_float_colormap(img: torch.Tensor, colormap: str = "turbo") -> torch.Tensor:
+    """Convert single channel to a color img.
+
+    Args:
+        img (torch.Tensor): (..., 1) float32 single channel image.
+        colormap (str): Colormap for img.
+
+    Returns:
+        (..., 3) colored img with colors in [0, 1].
+    """
+    img = torch.nan_to_num(img, 0)
+    if colormap == "gray":
+        return img.repeat(1, 1, 3)
+    img_long = (img * 255).long()
+    img_long_min = torch.min(img_long)
+    img_long_max = torch.max(img_long)
+    assert img_long_min >= 0, f"the min value is {img_long_min}"
+    assert img_long_max <= 255, f"the max value is {img_long_max}"
+    return torch.tensor(
+        colormaps[colormap].colors,  # type: ignore
+        device=img.device,
+    )[img_long[..., 0]]
+
+
+def apply_depth_colormap(
+    depth: torch.Tensor,
+    acc: torch.Tensor = None,
+    near_plane: float = None,
+    far_plane: float = None,
+) -> torch.Tensor:
+    """Converts a depth image to color for easier analysis.
+
+    Args:
+        depth (torch.Tensor): (..., 1) float32 depth.
+        acc (torch.Tensor | None): (..., 1) optional accumulation mask.
+        near_plane: Closest depth to consider. If None, use min image value.
+        far_plane: Furthest depth to consider. If None, use max image value.
+
+    Returns:
+        (..., 3) colored depth image with colors in [0, 1].
+    """
+    near_plane = near_plane or float(torch.min(depth))
+    far_plane = far_plane or float(torch.max(depth))
+    depth = (depth - near_plane) / (far_plane - near_plane + 1e-10)
+    depth = torch.clip(depth, 0.0, 1.0)
+    img = apply_float_colormap(depth, colormap="turbo")
+    if acc is not None:
+        img = img * acc + (1.0 - acc)
+    return img
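To close, a short usage sketch (not part of the copied file) for the colormap helpers, e.g. when logging depth renders during pruning evaluation; the tensor shapes below are hypothetical.

```python
import torch

# Hypothetical rendered outputs: per-pixel depth and accumulation, (H, W, 1).
depth = torch.rand(120, 160, 1) * 10.0
acc = torch.ones(120, 160, 1)

# Normalise depth between the given planes, map it through the turbo
# colormap, and blend low-accumulation pixels toward white via (1 - acc).
vis = apply_depth_colormap(depth, acc=acc, near_plane=0.1, far_plane=10.0)
print(vis.shape)  # (120, 160, 3)
```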