diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..7d626da
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,34 @@
+FROM ghcr.io/nerfstudio-project/nerfstudio
+
+# Install system dependencies
+RUN apt-get update && \
+ apt-get install -y git gcc-10 g++-10 nvidia-cuda-toolkit ninja-build python3-pip wget && \
+ rm -rf /var/lib/apt/lists/*
+
+# Set environment variables to use gcc-10
+ENV CC=gcc-10
+ENV CXX=g++-10
+ENV CUDA_HOME="/usr/local/cuda"
+# NOTE: ENV does not evaluate shell substitutions, so the Torch CMake prefix is resolved inside the RUN step that needs it (see below).
+ENV TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6+PTX"
+
+# Clone QED-Splatter and install in editable mode
+RUN git clone https://github.com/leggedrobotics/qed-splatter.git /apps/qed-splatter
+RUN CMAKE_PREFIX_PATH="$(python -c 'import torch; print(torch.utils.cmake_prefix_path)')" pip install /apps/qed-splatter
+
+# Register QED-Splatter with nerfstudio
+RUN ns-install-cli
+
+# Expose nvcc under CUDA_HOME (the apt toolkit installs it to /usr/bin)
+RUN mkdir -p /usr/local/cuda/bin && ln -s /usr/bin/nvcc /usr/local/cuda/bin/nvcc
+
+# Install additional Python dependencies
+RUN pip install git+https://github.com/rmbrualla/pycolmap@cc7ea4b7301720ac29287dbe450952511b32125e
+RUN pip install git+https://github.com/rahul-goel/fused-ssim@1272e21a282342e89537159e4bad508b19b34157
+RUN pip install nerfview pyntcloud
+
+# Pre-download AlexNet weights (used by the LPIPS metric) to avoid downloading them at runtime
+RUN mkdir -p /root/.cache/torch/hub/checkpoints && \
+ wget https://download.pytorch.org/models/alexnet-owt-7be5be79.pth \
+ -O /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
+
+# Set working directory
+WORKDIR /workspace
diff --git a/README.md b/README.md
index d72cafe..b862c65 100644
--- a/README.md
+++ b/README.md
@@ -32,3 +32,121 @@ To train the new method, use the following command:
```
ns-train qed-splatter --data [PATH]
```
+
+## Pruning Extension
+
+The pruning extension provides tools for reducing the number of Gaussians in a trained model in order to improve rendering speed. Two main types of pruners are available:
+
+- **Soft pruners** gradually reduce the number of Gaussians during training.
+- **Hard pruners** are post-processing tools applied after training is complete.
+
+Each pruner computes a *pruning score* to evaluate the importance of individual Gaussians. The least important Gaussians are then removed.
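+
+As a rough illustration, hard pruning boils down to dropping the lowest-scoring fraction of Gaussians. The snippet below is a minimal, self-contained sketch of that selection step (the function name is illustrative, not the actual API):
+
+```python
+import torch
+
+def select_pruned(scores: torch.Tensor, pruning_ratio: float) -> torch.Tensor:
+    """Return a boolean mask that is True for the Gaussians to remove."""
+    num_prune = int(len(scores) * pruning_ratio)
+    # The lowest-scoring Gaussians are treated as the least important.
+    _, idx = torch.topk(scores, k=num_prune, largest=False)
+    prune_mask = torch.zeros_like(scores, dtype=torch.bool)
+    prune_mask[idx] = True
+    return prune_mask
+
+# Example: prune 10% of 1000 Gaussians.
+scores = torch.rand(1000)
+print(select_pruned(scores, pruning_ratio=0.1).sum().item())  # -> 100
+```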
+
+Currently, two hard pruning scripts are available: `RGB_hard_pruner.py` and `depth_hard_pruner.py`.
+
+### RGB_hard_pruner
+This pruner computes a pruning score from the RGB reconstruction loss and uses it for hard pruning.
+```
+python3 RGB_hard_pruner.py default --data-dir datasets/park --ckpt results/park/step-000029999.ckpt --pruning-ratio 0.1 --result-dir output
+```
+
+Optional flags:
+- `--eval-only` : only evaluates; nothing is pruned or saved
+- `--pruning-ratio 0.0` : no pruning; the model is re-saved in the chosen output format
+- `--output-format` : `ply` (default), `ckpt` (nerfstudio), or `pt` (gsplat)
+
+## Required Arguments
+
+| Argument | Description |
+|-------------------|-----------------------------------------------------------------------------|
+| `default` | Specifies the run configuration. |
+| `--data-dir`      | Path to the directory containing `transforms.json` (camera poses and intrinsics) and the RGB and depth images. |
+| `--ckpt` | Path to the pretrained model checkpoint (e.g., `results/park/step-XXXXX.ckpt`). |
+| `--pruning-ratio` | Float between `0.0` and `1.0`. Fraction of Gaussians to prune. Example: `0.1` prunes 10% and keeps 90%. |
+| `--result-dir` | Directory where the output (pruned model) will be saved. |
+
+
+## Input Format
+
+The pruners accept checkpoints in several formats; the format is detected automatically from the `--ckpt` file extension.
+- `ply` : expects a Nerfstudio format for the transforms.
+- `ckpt` : expects a Nerfstudio format for the transforms.
+- `pt` : expects a gsplat format for the transforms.
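+
+A simplified sketch of the dispatch logic (illustrative only; the real check lives inside the pruning scripts):
+
+```python
+import os
+
+def pick_parser(ckpt_path: str) -> str:
+    """Map the checkpoint extension to the expected transforms format."""
+    ext = os.path.splitext(ckpt_path)[1].lower()
+    if ext in (".ckpt", ".ply"):
+        return "nerfstudio"  # expects a transforms.json next to the images
+    if ext == ".pt":
+        return "gsplat"      # expects a COLMAP-style sparse/ folder
+    raise TypeError("Invalid model type. Use .pt, .ckpt or .ply files.")
+
+print(pick_parser("results/park/step-000029999.ckpt"))  # -> nerfstudio
+```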
+
+
+
+### GSPlat Dataset Format
+
+For `pt` (gsplat) checkpoints, the dataset must be structured in a COLMAP-like format, which includes camera parameters, image poses, and optionally 3D points. This format is commonly used for 3D reconstruction and novel view synthesis tasks.
+
+#### Folder Structure
+
+Your dataset should be organized like this:
+```
+data_dir/
+├── images/              # All input images
+│   ├── img1.jpg
+│   ├── img2.png
+│   └── ...
+└── sparse/              # Sparse reconstruction data (from COLMAP)
+    ├── cameras.bin      # Camera intrinsics
+    ├── images.bin       # Image poses (extrinsics) and filenames
+    └── points3D.bin     # Optional: 3D point cloud
+```
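+
+For a quick sanity check of the sparse reconstruction, the same `pycolmap` fork used by the pruners (pinned in the Dockerfile) can load it directly. This is a minimal sketch, not part of the pruning scripts:
+
+```python
+from pycolmap import SceneManager  # rmbrualla/pycolmap fork
+
+manager = SceneManager("data_dir/sparse")  # or data_dir/sparse/0
+manager.load_cameras()
+manager.load_images()
+manager.load_points3D()
+print("cameras:", len(manager.cameras), "images:", len(manager.images))
+```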
+
+### Nerfstudio Dataset Format
+
+#### `transforms.json` File
+
+This file must include the following:
+
+- Intrinsic camera parameters:
+ - `"fl_x"`, `"fl_y"`: focal lengths
+ - `"cx"`, `"cy"`: principal point
+ - `"w"`, `"h"`: image dimensions
+
+- A list of frames, each containing:
+  - `file_path`: path to the RGB image (relative to the dataset root)
+  - `depth_file_path`: path to the depth map (relative to the dataset root)
+ - `transform_matrix`: 4x4 camera-to-world matrix
+
+**Example:**
+```json
+{
+ "w": 1920,
+ "h": 1080,
+ "fl_x": 2198.997802734375,
+ "fl_y": 2198.997802734375,
+ "cx": 960.0,
+ "cy": 540.0,
+ "k1": 0,
+ "k2": 0,
+ "p1": 0,
+ "p2": 0,
+ "frames": [
+ {
+ "file_path": "images/frame_0000.png",
+ "depth_file_path": "depths/frame_0000.png",
+ "transform_matrix": [[1,0,0,0], [0,1,0,0], [0,0,1,0], [0,0,0,1]]
+ },
+ {
+ "file_path": "images/frame_0001.png",
+ "depth_file_path": "depths/frame_0001.png",
+ "transform_matrix": [[0,2,0,0], [0,1,3,0], [0,0,1,0], [0,5,0,1]]
+ }
+ ]
+}
+```
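+
+For reference, here is a minimal sketch of how these fields fit together (the pruners use their own parser in `pruning_utils/nerf.py`, so this is illustrative only):
+
+```python
+import json
+import numpy as np
+
+with open("data_dir/transforms.json") as f:
+    meta = json.load(f)
+
+# Shared pinhole intrinsics built from the fields listed above.
+K = np.array([
+    [meta["fl_x"], 0.0,          meta["cx"]],
+    [0.0,          meta["fl_y"], meta["cy"]],
+    [0.0,          0.0,          1.0],
+])
+
+for frame in meta["frames"]:
+    c2w = np.array(frame["transform_matrix"])  # 4x4 camera-to-world
+    print(frame["file_path"], frame["depth_file_path"], c2w.shape)
+```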
+
+
+### depth_hard_pruner
+This pruner computes a pruning score from the depth loss and uses it for hard pruning. It works analogously to the RGB hard pruner, but not all of the RGB pruner's features are available.
+```
+python3 depth_hard_pruner.py default --data-dir datasets/park --ckpt results/park/step-000029999.ckpt --pruning-ratio 0.1 --result-dir output
+```
+
+Optional flags:
+- `--eval-only` : only evaluates; nothing is pruned or saved
+- `--pruning-ratio 0.0` : no pruning; the model is re-saved in the chosen output format
+- `--output-format` : `ply` (default), `ckpt` (nerfstudio), or `pt` (gsplat)
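+
+Conceptually, the depth score for each Gaussian is the gradient magnitude of a depth L1 + SSIM loss with respect to its mean, averaged over the views in which the Gaussian was actually affected. A toy sketch of that accumulation step, with random gradients standing in for the real backward pass:
+
+```python
+import torch
+
+num_gaussians, num_views = 1000, 8
+scores = torch.zeros(num_gaussians, 3)
+view_counts = torch.zeros(num_gaussians, 3)
+
+for _ in range(num_views):
+    grad = torch.rand(num_gaussians, 3) * 1e-2  # stand-in for |d(depth loss)/d(means)|
+    view_counts += (grad > 0.001).float()       # views in which the Gaussian mattered
+    scores += grad
+
+scores = scores / (view_counts + 1e-8)          # average over contributing views
+per_gaussian = scores.mean(dim=1)               # one score per Gaussian
+print(per_gaussian.shape)                       # torch.Size([1000])
+```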
+
+#### Known Issues
+On the Park scene, the model generates black Gaussians to cover the sky; the entire scene ends up encased in these Gaussians.
diff --git a/RGB_hard_pruner.py b/RGB_hard_pruner.py
new file mode 100644
index 0000000..e5efafb
--- /dev/null
+++ b/RGB_hard_pruner.py
@@ -0,0 +1,887 @@
+"""
+RGB hard pruning Script
+
+Part of this code is based on gsplat's `simple_trainer.py`:
+https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py
+
+"""
+import math
+import os
+import time
+from dataclasses import dataclass, field
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple, Union
+
+import imageio
+import nerfview
+import numpy as np
+import torch
+import torch.nn.functional as F
+import tqdm
+import tyro
+import viser
+from pruning_utils.nerf import NerfDataset, NerfParser
+from pruning_utils.colmap import Dataset, Parser
+from pruning_utils.traj import (
+ generate_interpolated_path,
+ generate_ellipse_path_z,
+ generate_spiral_path,
+)
+from torch import Tensor
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
+from fused_ssim import fused_ssim
+from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+from typing_extensions import Literal, assert_never
+from pruning_utils.utils import AppearanceOptModule, CameraOptModule, set_random_seed
+from pruning_utils.lib_bilagrid import (
+ BilateralGrid,
+ color_correct,
+)
+from pruning_utils.open_ply_pipeline import load_splats, save_splats
+
+
+from gsplat.compression import PngCompression
+from gsplat.distributed import cli
+from gsplat.rendering import rasterization
+from gsplat.strategy import DefaultStrategy, MCMCStrategy
+from gsplat.strategy.ops import remove
+
+
+
+@dataclass
+class Config:
+ # Disable viewer
+ disable_viewer: bool = False
+    # Path to the checkpoint files (.pt, .ckpt or .ply). If provided, training is skipped and only evaluation/pruning is run.
+ ckpt: Optional[List[str]] = None
+ # Name of compression strategy to use
+ compression: Optional[Literal["png"]] = None
+ # Render trajectory path
+ render_traj_path: str = "interp"
+
+ # Pruning Ratio
+ pruning_ratio: float = 0.0
+ # Output data format converted
+ output_format: str = "ply"
+ # Path to the Mip-NeRF 360 dataset
+ data_dir: str = "data"
+ # Downsample factor for the dataset
+ data_factor: int = 4
+ # Directory to save results
+ result_dir: str = "./results"
+ # Every N images there is a test image
+ test_every: int = 8
+ # Random crop size for training (experimental)
+ patch_size: Optional[int] = None
+ # A global scaler that applies to the scene size related parameters
+ global_scale: float = 1.0
+ # Normalize the world space
+ normalize_world_space: bool = True
+ # Camera model
+ camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole"
+
+ # Port for the viewer server
+ port: int = 8080
+
+ # Batch size for training. Learning rates are scaled automatically
+ batch_size: int = 1
+ # A global factor to scale the number of training steps
+ steps_scaler: float = 1.0
+
+ # Number of training steps
+ max_steps: int = 10
+ # Steps to evaluate the model
+ eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
+ # Steps to save the model
+ save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
+
+ # Degree of spherical harmonics
+ sh_degree: int = 3
+ # Turn on another SH degree every this steps
+ sh_degree_interval: int = 1000
+ # Initial opacity of GS
+ init_opa: float = 0.1
+ # Initial scale of GS
+ init_scale: float = 1.0
+ # Weight for SSIM loss
+ ssim_lambda: float = 0.2
+
+ # Near plane clipping distance
+ near_plane: float = 0.01
+ # Far plane clipping distance
+ far_plane: float = 1e10
+
+    # Only run evaluation
+ eval_only: bool = False
+
+ # Strategy for GS densification
+ strategy: Union[DefaultStrategy, MCMCStrategy] = field(
+ default_factory=DefaultStrategy
+ )
+    # Use packed mode for rasterization; uses less memory but is slightly slower.
+ packed: bool = False
+ # Use sparse gradients for optimization. (experimental)
+ sparse_grad: bool = False
+ # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
+ antialiased: bool = False
+
+
+ # Opacity regularization
+ opacity_reg: float = 0.0
+ # Scale regularization
+ scale_reg: float = 0.0
+
+ # Enable camera optimization.
+ pose_opt: bool = False
+ # Learning rate for camera optimization
+ pose_opt_lr: float = 1e-5
+ # Regularization for camera optimization as weight decay
+ pose_opt_reg: float = 1e-6
+ # Add noise to camera extrinsics. This is only to test the camera pose optimization.
+ pose_noise: float = 0.0
+
+ # Enable appearance optimization. (experimental)
+ app_opt: bool = False
+ # Appearance embedding dimension
+ app_embed_dim: int = 16
+ # Learning rate for appearance optimization
+ app_opt_lr: float = 1e-3
+ # Regularization for appearance optimization as weight decay
+ app_opt_reg: float = 1e-6
+
+ # Enable bilateral grid. (experimental)
+ use_bilateral_grid: bool = False
+ # Shape of the bilateral grid (X, Y, W)
+ bilateral_grid_shape: Tuple[int, int, int] = (16, 16, 8)
+
+ # Enable depth loss. (experimental)
+ depth_loss: bool = False
+
+ lpips_net: Literal["vgg", "alex"] = "alex"
+
+ def adjust_steps(self, factor: float):
+ self.eval_steps = [int(i * factor) for i in self.eval_steps]
+ self.save_steps = [int(i * factor) for i in self.save_steps]
+ self.max_steps = int(self.max_steps * factor)
+ self.sh_degree_interval = int(self.sh_degree_interval * factor)
+
+ strategy = self.strategy
+ if isinstance(strategy, DefaultStrategy):
+ strategy.refine_start_iter = int(strategy.refine_start_iter * factor)
+ strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
+ strategy.reset_every = int(strategy.reset_every * factor)
+ strategy.refine_every = int(strategy.refine_every * factor)
+ elif isinstance(strategy, MCMCStrategy):
+ strategy.refine_start_iter = int(strategy.refine_start_iter * factor)
+ strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
+ strategy.refine_every = int(strategy.refine_every * factor)
+ else:
+ assert_never(strategy)
+
+
+
+
+class Runner:
+ """Engine for training and testing."""
+
+ def __init__(
+ self, local_rank: int, world_rank: int, world_size: int, cfg: Config
+ ) -> None:
+ set_random_seed(42 + local_rank)
+
+ self.cfg = cfg
+ self.world_rank = world_rank
+ self.local_rank = local_rank
+ self.world_size = world_size
+ self.device = f"cuda:{local_rank}"
+
+ # Where to dump results.
+ os.makedirs(cfg.result_dir, exist_ok=True)
+
+ # Setup output directories.
+ self.ckpt_dir = f"{cfg.result_dir}"
+ os.makedirs(self.ckpt_dir, exist_ok=True)
+ self.stats_dir = f"{cfg.result_dir}/stats"
+ os.makedirs(self.stats_dir, exist_ok=True)
+ self.render_dir = f"{cfg.result_dir}/renders"
+ os.makedirs(self.render_dir, exist_ok=True)
+
+ # Tensorboard
+ self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb")
+
+ # Load data: Training data should contain initial points and colors.
+ ext = os.path.splitext(cfg.ckpt[0])[1].lower()
+ if ext == ".ckpt" or ext == ".ply" :
+ print("this is a ckpt or ply file ")
+ self.parser = NerfParser(
+ data_dir=cfg.data_dir,
+ factor=cfg.data_factor,
+ normalize=cfg.normalize_world_space,
+ test_every=cfg.test_every,
+ )
+ self.trainset = NerfDataset(
+ self.parser,
+ split="train",
+ patch_size=cfg.patch_size,
+ load_depths=cfg.depth_loss,
+ )
+ elif ext == ".pt":
+ print("this is a pt file")
+ self.parser = Parser(
+ data_dir=cfg.data_dir,
+ factor=cfg.data_factor,
+ normalize=cfg.normalize_world_space,
+ test_every=cfg.test_every,
+ )
+ self.trainset = Dataset(
+ self.parser,
+ split="train",
+ patch_size=cfg.patch_size,
+ load_depths=cfg.depth_loss,
+ )
+ else:
+ msg ="Invalid Model type. Use .pt, .ckpt or .ply files."
+ raise TypeError(msg)
+
+ self.valset = Dataset(self.parser, split="val")
+ self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale
+
+ # Model
+ feature_dim = 32 if cfg.app_opt else None
+ self.splats, self.optimizers = self.create_splats_with_optimizers(
+ scene_scale=self.scene_scale,
+ sparse_grad=cfg.sparse_grad,
+ batch_size=cfg.batch_size,
+ device=self.device,
+ world_size=world_size,
+ cfg=self.cfg,
+ )
+
+ print("Model initialized. Number of GS:", len(self.splats["means"]))
+
+ # Densification Strategy
+ self.cfg.strategy.check_sanity(self.splats, self.optimizers)
+
+ if isinstance(self.cfg.strategy, DefaultStrategy):
+ self.strategy_state = self.cfg.strategy.initialize_state(
+ scene_scale=self.scene_scale
+ )
+ elif isinstance(self.cfg.strategy, MCMCStrategy):
+ self.strategy_state = self.cfg.strategy.initialize_state()
+ else:
+ assert_never(self.cfg.strategy)
+
+ # Compression Strategy
+ self.compression_method = None
+ if cfg.compression is not None:
+ if cfg.compression == "png":
+ self.compression_method = PngCompression()
+ else:
+ raise ValueError(f"Unknown compression strategy: {cfg.compression}")
+
+ self.pose_optimizers = []
+ if cfg.pose_opt:
+ self.pose_adjust = CameraOptModule(len(self.trainset)).to(self.device)
+ self.pose_adjust.zero_init()
+ self.pose_optimizers = [
+ torch.optim.Adam(
+ self.pose_adjust.parameters(),
+ lr=cfg.pose_opt_lr * math.sqrt(cfg.batch_size),
+ weight_decay=cfg.pose_opt_reg,
+ )
+ ]
+ if world_size > 1:
+ self.pose_adjust = DDP(self.pose_adjust)
+
+ if cfg.pose_noise > 0.0:
+ self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device)
+ self.pose_perturb.random_init(cfg.pose_noise)
+ if world_size > 1:
+ self.pose_perturb = DDP(self.pose_perturb)
+
+ self.app_optimizers = []
+ if cfg.app_opt:
+ assert feature_dim is not None
+ self.app_module = AppearanceOptModule(
+ len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree
+ ).to(self.device)
+ # initialize the last layer to be zero so that the initial output is zero.
+ torch.nn.init.zeros_(self.app_module.color_head[-1].weight)
+ torch.nn.init.zeros_(self.app_module.color_head[-1].bias)
+ self.app_optimizers = [
+ torch.optim.Adam(
+ self.app_module.embeds.parameters(),
+ lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0,
+ weight_decay=cfg.app_opt_reg,
+ ),
+ torch.optim.Adam(
+ self.app_module.color_head.parameters(),
+ lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size),
+ ),
+ ]
+ if world_size > 1:
+ self.app_module = DDP(self.app_module)
+
+ self.bil_grid_optimizers = []
+ if cfg.use_bilateral_grid:
+ self.bil_grids = BilateralGrid(
+ len(self.trainset),
+ grid_X=cfg.bilateral_grid_shape[0],
+ grid_Y=cfg.bilateral_grid_shape[1],
+ grid_W=cfg.bilateral_grid_shape[2],
+ ).to(self.device)
+ self.bil_grid_optimizers = [
+ torch.optim.Adam(
+ self.bil_grids.parameters(),
+ lr=2e-3 * math.sqrt(cfg.batch_size),
+ eps=1e-15,
+ ),
+ ]
+
+ # Losses & Metrics.
+ self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device)
+ self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device)
+
+ if cfg.lpips_net == "alex":
+ self.lpips = LearnedPerceptualImagePatchSimilarity(
+ net_type="alex", normalize=True
+ ).to(self.device)
+ elif cfg.lpips_net == "vgg":
+ # The 3DGS official repo uses lpips vgg, which is equivalent with the following:
+ self.lpips = LearnedPerceptualImagePatchSimilarity(
+ net_type="vgg", normalize=False
+ ).to(self.device)
+ else:
+ raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}")
+
+ # Viewer
+ if not self.cfg.disable_viewer:
+ self.server = viser.ViserServer(port=cfg.port, verbose=False)
+ self.viewer = nerfview.Viewer(
+ server=self.server,
+ render_fn=self._viewer_render_fn,
+ mode="training",
+ )
+
+ def create_splats_with_optimizers(
+ self,
+ scene_scale: float = 1.0,
+ sparse_grad: bool = False,
+ batch_size: int = 1,
+ device: str = "cuda",
+ world_size: int = 1,
+        cfg: Optional[Config] = None,
+ ) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]:
+
+
+
+ self.steps, splats = load_splats(cfg.ckpt[0], device)
+
+ # Convert to ParameterDict on the correct device
+ splats = torch.nn.ParameterDict(splats).to(device)
+
+
+        # Learning rates are not stored in the checkpoint, so they are defined here
+        # using the gsplat defaults.
+ default_lrs = {
+ "means": 1.6e-4 * scene_scale,
+ "scales": 5e-3,
+ "quats": 1e-3,
+ "opacities": 5e-2,
+ "sh0": 2.5e-3,
+ "shN": 2.5e-3 / 20,
+ "features": 2.5e-3,
+ "colors": 2.5e-3,
+ }
+
+ BS = batch_size * world_size
+ optimizers = {
+ name: (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)(
+ [
+ {
+ "params": splats[name],
+ "lr": default_lrs[name] * math.sqrt(BS),
+ "name": name,
+ }
+ ],
+ eps=1e-15 / math.sqrt(BS),
+ betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)),
+ )
+ for name in splats.keys()
+ if name in default_lrs
+ }
+
+ return splats, optimizers
+
+
+ def reinit_optimizers(self):
+ """Reinitialize optimizers after pruning Gaussians."""
+ cfg = self.cfg
+ BS = cfg.batch_size * self.world_size
+
+ # Recreate the optimizer dictionary with new parameters
+ new_optimizers = {}
+ for name, param in self.splats.items():
+ lr = {
+ "means": 1.6e-4 * self.scene_scale,
+ "scales": 5e-3,
+ "quats": 1e-3,
+ "opacities": 5e-2,
+ "sh0": 2.5e-3,
+ "shN": 2.5e-3 / 20,
+ "features": 2.5e-3,
+ "colors": 2.5e-3,
+ }[name]
+
+ betas = (
+ 1 - BS * (1 - 0.9),
+ 1 - BS * (1 - 0.999),
+ )
+
+ new_optimizers[name] = torch.optim.Adam(
+ [{"params": param, "lr": lr * math.sqrt(BS), "name": name}],
+ eps=1e-15 / math.sqrt(BS),
+ betas=betas,
+ )
+
+ # Replace old optimizers with new ones
+ self.optimizers = new_optimizers
+
+
+ def prune_gaussians(self, prune_ratio: float, scores: torch.Tensor):
+ """Prune Gaussians based on score thresholding."""
+        num_prune = int(len(scores) * prune_ratio)
+        # Keep-mask: True for Gaussians to keep, False for the lowest-scoring ones.
+        _, idx = torch.topk(scores.squeeze(), k=num_prune, largest=False)
+        mask = torch.ones_like(scores, dtype=torch.bool)
+        mask[idx] = False
+
+        # `remove` expects a mask of Gaussians to delete, hence the inversion.
+        remove(
+            params=self.splats,
+            optimizers=self.optimizers,
+            state=self.strategy_state,
+            mask=~mask,
+        )
+
+
+
+
+ @torch.enable_grad()
+ def score_func(
+ self,
+ viewpoint_cam: Dict[str, torch.Tensor],
+ scores: torch.Tensor,
+ mask_views: torch.Tensor
+ ) -> None:
+
+        # Camera matrices arrive batched from the DataLoader (batch size 1).
+        camtoworld = viewpoint_cam["camtoworld"].to(self.device)  # shape: [1, 4, 4]
+        K = viewpoint_cam["K"].to(self.device)  # shape: [1, 3, 3]
+ height, width = viewpoint_cam["image"].shape[1:3]
+
+ # Forward pass
+ colors, _, _ = self.rasterize_splats(
+            camtoworlds=camtoworld,  # already [1, 4, 4]
+            Ks=K,  # already [1, 3, 3]
+ width=width,
+ height=height,
+ sh_degree=self.cfg.sh_degree,
+ image_ids=viewpoint_cam["image_id"].to(self.device)[None],
+ )
+
+ # Compute loss
+ gt_image = viewpoint_cam["image"].to(self.device) / 255.0
+ l1loss = F.l1_loss(colors, gt_image)
+ ssimloss = 1.0 - fused_ssim(
+ colors.permute(0, 3, 1, 2),
+ gt_image.permute(0, 3, 1, 2),
+ padding="valid"
+ )
+ loss = l1loss * (1.0 - self.cfg.ssim_lambda) + ssimloss * self.cfg.ssim_lambda
+
+ # Backward pass
+ loss.backward()
+
+
+
+ # Opacity gradient
+ opacity_grad = self.splats["opacities"].grad.abs().squeeze() # [N]
+
+ # Means gradient - reduce across channel dim (assumes shape [N, 3])
+ means_grad = self.splats["means"].grad.abs().mean(dim=1).squeeze() # [N]
+
+ # Scales gradient
+ scales_grad = self.splats["scales"].grad.abs().mean(dim=1).squeeze() # [N]
+
+ # SH0 gradient - make sure you are reducing the right dimension
+ sh0_grad = self.splats["sh0"].grad.abs().view(self.splats["sh0"].grad.shape[0], -1).mean(dim=1).squeeze()
+
+ # SHN gradient - often [N, K, 3], so reduce last two dims
+ shN_grad = self.splats["shN"].grad.abs().view(self.splats["shN"].grad.shape[0], -1).mean(dim=1).squeeze() # [N]
+
+
+ # Combine all scores
+ combined = opacity_grad + means_grad + scales_grad + sh0_grad + shN_grad
+
+        # Count this view as affecting a Gaussian when its combined gradient exceeds a small heuristic threshold (5e-7).
+        mask_views += combined > 50 * 1e-8
+
+ # Accumulate for scoring
+ with torch.no_grad():
+ scores += combined
+
+
+ @torch.no_grad()
+ def prune(self, prune_ratio: float):
+ print("Running pruning...")
+ scores = torch.zeros_like(self.splats["opacities"])
+ mask_views = torch.zeros_like(self.splats["opacities"])
+
+ trainloader = torch.utils.data.DataLoader(
+ self.trainset,
+ batch_size=1,
+ shuffle=False,
+ num_workers=1,
+ pin_memory=True,
+ )
+
+        pbar = tqdm.tqdm(trainloader, desc="Computing pruning scores")
+        for data in pbar:
+            self.score_func(data, scores, mask_views)
+
+        # Dump the per-view hit counts and accumulated scores to text files for inspection.
+        np.savetxt('mask_views_full.txt', mask_views.cpu().numpy(), fmt='%.18e')
+        np.savetxt('scores_txt.txt', scores.cpu().numpy(), fmt='%.18e')
+
+ # Prune Gaussians
+ self.prune_gaussians(prune_ratio, scores)
+
+
+
+ def rasterize_splats(
+ self,
+ camtoworlds: Tensor,
+ Ks: Tensor,
+ width: int,
+ height: int,
+ masks: Optional[Tensor] = None,
+ **kwargs,
+ ) -> Tuple[Tensor, Tensor, Dict]:
+ means = self.splats["means"] # [N, 3]
+ quats = self.splats["quats"] # [N, 4]
+ scales = torch.exp(self.splats["scales"]) # [N, 3]
+ opacities = torch.sigmoid(self.splats["opacities"]) # [N,]
+
+ image_ids = kwargs.pop("image_ids", None)
+ if self.cfg.app_opt:
+ colors = self.app_module(
+ features=self.splats["features"],
+ embed_ids=image_ids,
+ dirs=means[None, :, :] - camtoworlds[:, None, :3, 3],
+ sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree),
+ )
+ colors = colors + self.splats["colors"]
+ colors = torch.sigmoid(colors)
+ else:
+ colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3]
+
+ rasterize_mode = "antialiased" if self.cfg.antialiased else "classic"
+ render_colors, render_alphas, info = rasterization(
+ means=means,
+ quats=quats,
+ scales=scales,
+ opacities=opacities,
+ colors=colors,
+ viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4]
+ Ks=Ks, # [C, 3, 3]
+ width=width,
+ height=height,
+ packed=self.cfg.packed,
+ absgrad=(
+ self.cfg.strategy.absgrad
+ if isinstance(self.cfg.strategy, DefaultStrategy)
+ else False
+ ),
+ sparse_grad=self.cfg.sparse_grad,
+ rasterize_mode=rasterize_mode,
+ distributed=self.world_size > 1,
+ camera_model=self.cfg.camera_model,
+ **kwargs,
+ )
+ if masks is not None:
+ render_colors[~masks] = 0
+ return render_colors, render_alphas, info
+
+
+ @torch.no_grad()
+ def eval(self, step: int, stage: str = "val"):
+ """Entry for evaluation."""
+ print("Running evaluation...")
+ cfg = self.cfg
+ device = self.device
+ world_rank = self.world_rank
+
+ valloader = torch.utils.data.DataLoader(
+ self.valset, batch_size=1, shuffle=False, num_workers=1
+ )
+ ellipse_time = 0
+ metrics = defaultdict(list)
+ for i, data in enumerate(valloader):
+ camtoworlds = data["camtoworld"].to(device)
+ Ks = data["K"].to(device)
+ pixels = data["image"].to(device) / 255.0
+ masks = data["mask"].to(device) if "mask" in data else None
+ height, width = pixels.shape[1:3]
+
+ torch.cuda.synchronize()
+ tic = time.time()
+ colors, _, _ = self.rasterize_splats(
+ camtoworlds=camtoworlds,
+ Ks=Ks,
+ width=width,
+ height=height,
+ sh_degree=cfg.sh_degree,
+ near_plane=cfg.near_plane,
+ far_plane=cfg.far_plane,
+ masks=masks,
+ ) # [1, H, W, 3]
+ torch.cuda.synchronize()
+ ellipse_time += time.time() - tic
+
+ colors = torch.clamp(colors, 0.0, 1.0)
+ canvas_list = [pixels, colors]
+
+ if world_rank == 0:
+ # write images
+ canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy()
+ canvas = (canvas * 255).astype(np.uint8)
+ imageio.imwrite(
+ f"{self.render_dir}/{stage}_step{step}_{i:04d}.png",
+ canvas,
+ )
+
+ pixels_p = pixels.permute(0, 3, 1, 2) # [1, 3, H, W]
+ colors_p = colors.permute(0, 3, 1, 2) # [1, 3, H, W]
+ metrics["psnr"].append(self.psnr(colors_p, pixels_p))
+ metrics["ssim"].append(self.ssim(colors_p, pixels_p))
+ metrics["lpips"].append(self.lpips(colors_p, pixels_p))
+ if cfg.use_bilateral_grid:
+ cc_colors = color_correct(colors, pixels)
+ cc_colors_p = cc_colors.permute(0, 3, 1, 2) # [1, 3, H, W]
+ metrics["cc_psnr"].append(self.psnr(cc_colors_p, pixels_p))
+
+ if world_rank == 0:
+ ellipse_time /= len(valloader)
+
+ stats = {k: torch.stack(v).mean().item() for k, v in metrics.items()}
+ stats.update(
+ {
+ "ellipse_time": ellipse_time,
+ "num_GS": len(self.splats["means"]),
+ }
+ )
+ print(
+ f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} "
+ f"Time: {stats['ellipse_time']:.3f}s/image "
+ f"Number of GS: {stats['num_GS']}"
+ )
+
+
+
+ @torch.no_grad()
+ def render_traj(self, step: int):
+ """Entry for trajectory rendering."""
+ print("Running trajectory rendering...")
+ cfg = self.cfg
+ device = self.device
+
+ camtoworlds_all = self.parser.camtoworlds[5:-5]
+ if cfg.render_traj_path == "interp":
+ camtoworlds_all = generate_interpolated_path(
+ camtoworlds_all, 1
+ ) # [N, 3, 4]
+ elif cfg.render_traj_path == "ellipse":
+ height = camtoworlds_all[:, 2, 3].mean()
+ camtoworlds_all = generate_ellipse_path_z(
+ camtoworlds_all, height=height
+ ) # [N, 3, 4]
+ elif cfg.render_traj_path == "spiral":
+ camtoworlds_all = generate_spiral_path(
+ camtoworlds_all,
+ bounds=self.parser.bounds * self.scene_scale,
+ spiral_scale_r=self.parser.extconf["spiral_radius_scale"],
+ )
+ else:
+ raise ValueError(
+ f"Render trajectory type not supported: {cfg.render_traj_path}"
+ )
+
+ camtoworlds_all = np.concatenate(
+ [
+ camtoworlds_all,
+ np.repeat(
+ np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds_all), axis=0
+ ),
+ ],
+ axis=1,
+ ) # [N, 4, 4]
+
+ camtoworlds_all = torch.from_numpy(camtoworlds_all).float().to(device)
+ K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device)
+ width, height = list(self.parser.imsize_dict.values())[0]
+
+ # save to video
+ video_dir = f"{cfg.result_dir}/videos"
+ os.makedirs(video_dir, exist_ok=True)
+ writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30)
+ for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"):
+ camtoworlds = camtoworlds_all[i : i + 1]
+ Ks = K[None]
+
+ renders, _, _ = self.rasterize_splats(
+ camtoworlds=camtoworlds,
+ Ks=Ks,
+ width=width,
+ height=height,
+ sh_degree=cfg.sh_degree,
+ near_plane=cfg.near_plane,
+ far_plane=cfg.far_plane,
+ render_mode="RGB+ED",
+ ) # [1, H, W, 4]
+ colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3]
+ depths = renders[..., 3:4] # [1, H, W, 1]
+ depths = (depths - depths.min()) / (depths.max() - depths.min())
+ canvas_list = [colors, depths.repeat(1, 1, 1, 3)]
+
+ # write images
+ canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy()
+ canvas = (canvas * 255).astype(np.uint8)
+ writer.append_data(canvas)
+ writer.close()
+ print(f"Video saved to {video_dir}/traj_{step}.mp4")
+
+ @torch.no_grad()
+ def run_compression(self, step: int):
+ """Entry for running compression."""
+ print("Running compression...")
+ world_rank = self.world_rank
+
+ compress_dir = f"{cfg.result_dir}/compression/rank{world_rank}"
+ os.makedirs(compress_dir, exist_ok=True)
+
+ self.compression_method.compress(compress_dir, self.splats)
+
+ # evaluate compression
+ splats_c = self.compression_method.decompress(compress_dir)
+ for k in splats_c.keys():
+ self.splats[k].data = splats_c[k].to(self.device)
+ self.eval(step=step, stage="compress")
+
+ @torch.no_grad()
+ def _viewer_render_fn(
+ self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int]
+ ):
+ """Callable function for the viewer."""
+ W, H = img_wh
+ c2w = camera_state.c2w
+ K = camera_state.get_K(img_wh)
+ c2w = torch.from_numpy(c2w).float().to(self.device)
+ K = torch.from_numpy(K).float().to(self.device)
+
+ render_colors, _, _ = self.rasterize_splats(
+ camtoworlds=c2w[None],
+ Ks=K[None],
+ width=W,
+ height=H,
+            sh_degree=self.cfg.sh_degree,  # activate all SH degrees
+ radius_clip=3.0, # skip GSs that have small image radius (in pixels)
+ ) # [1, H, W, 3]
+ return render_colors[0].cpu().numpy()
+
+
+def main(local_rank: int, world_rank, world_size: int, cfg: Config):
+ if world_size > 1 and not cfg.disable_viewer:
+ cfg.disable_viewer = True
+ if world_rank == 0:
+ print("Viewer is disabled in distributed training.")
+
+ if cfg.ckpt is not None:
+ # run eval only
+ runner = Runner(local_rank, world_rank, world_size, cfg)
+
+ step = runner.steps
+ if cfg.eval_only:
+ runner.eval(step=step)
+ else:
+ print("Hard pruning in progress...")
+ if cfg.pruning_ratio != 0:
+ runner.prune(prune_ratio=cfg.pruning_ratio)
+ print("The size of gaussian is:", runner.splats["means"].shape)
+
+ # save checkpoint after hard pruning
+ name = os.path.splitext(os.path.basename(cfg.ckpt[0]))[0]
+
+ data = {"step": step, "splats": runner.splats.state_dict()}
+ if cfg.pose_opt:
+ if world_size > 1:
+ data["pose_adjust"] = runner.pose_adjust.module.state_dict()
+ else:
+ data["pose_adjust"] = runner.pose_adjust.state_dict()
+ if cfg.app_opt:
+ if world_size > 1:
+ data["app_module"] = runner.app_module.module.state_dict()
+ else:
+ data["app_module"] = runner.app_module.state_dict()
+ suffix = str(cfg.pruning_ratio).replace("0.", "")
+
+ print("output format", cfg.output_format)
+
+
+ save_splats(f"{runner.ckpt_dir}/{name}_pruned_test{suffix}", data, cfg.output_format)
+
+ else:
+ raise ValueError("ckpt cant be None")
+
+ if not cfg.disable_viewer:
+ print("Viewer running... Ctrl+C to exit.")
+ time.sleep(1000000)
+
+
+if __name__ == "__main__":
+ # Config objects we can choose between.
+ # Each is a tuple of (CLI description, config object).
+ configs = {
+ "default": (
+ "Gaussian splatting training using densification heuristics from the original paper.",
+ Config(
+ strategy=DefaultStrategy(verbose=True),
+ ),
+ ),
+ "mcmc": (
+ "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.",
+ Config(
+ init_opa=0.5,
+ init_scale=0.1,
+ opacity_reg=0.01,
+ scale_reg=0.01,
+ strategy=MCMCStrategy(verbose=True),
+ ),
+ ),
+ }
+ cfg = tyro.extras.overridable_config_cli(configs)
+ cfg.adjust_steps(cfg.steps_scaler)
+
+ # try import extra dependencies
+ if cfg.compression == "png":
+ try:
+ import plas
+ import torchpq
+        except ImportError:
+ raise ImportError(
+ "To use PNG compression, you need to install "
+ "torchpq (instruction at https://github.com/DeMoriarty/TorchPQ?tab=readme-ov-file#install) "
+ "and plas (via 'pip install git+https://github.com/fraunhoferhhi/PLAS.git') "
+ )
+
+ cli(main, cfg, verbose=True)
\ No newline at end of file
diff --git a/depth_hard_pruner.py b/depth_hard_pruner.py
new file mode 100644
index 0000000..bd3cb14
--- /dev/null
+++ b/depth_hard_pruner.py
@@ -0,0 +1,853 @@
+"""
+Depth hard pruning Script
+
+Part of this code is based on gsplat's `simple_trainer.py`:
+https://github.com/nerfstudio-project/gsplat/blob/main/examples/simple_trainer.py
+
+"""
+
+
+import math
+import os
+import time
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Tuple, Union
+
+import imageio
+import nerfview
+import numpy as np
+import torch
+import torch.nn.functional as F
+import tqdm
+import tyro
+import viser
+
+from pruning_utils.nerf import NerfDataset, NerfParser
+from pruning_utils.colmap import Dataset, Parser
+from pruning_utils.traj import (
+ generate_interpolated_path,
+ generate_ellipse_path_z,
+ generate_spiral_path,
+)
+from torch import Tensor
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
+from fused_ssim import fused_ssim
+from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+from typing_extensions import Literal, assert_never
+from pruning_utils.utils import AppearanceOptModule, CameraOptModule, set_random_seed
+from pruning_utils.lib_bilagrid import (
+ BilateralGrid,
+)
+from pruning_utils.open_ply_pipeline import load_splats, save_splats
+
+
+from gsplat.compression import PngCompression
+from gsplat.distributed import cli
+from gsplat.rendering import rasterization
+from gsplat.strategy import DefaultStrategy, MCMCStrategy
+from gsplat.strategy.ops import remove
+
+from PIL import Image
+
+
+@dataclass
+class Config:
+ # Disable viewer
+ disable_viewer: bool = False
+    # Path to the checkpoint files (.pt, .ckpt or .ply). If provided, training is skipped and only pruning is run.
+ ckpt: Optional[List[str]] = None
+ # Name of compression strategy to use
+ compression: Optional[Literal["png"]] = None
+ # Render trajectory path
+ render_traj_path: str = "interp"
+
+ # Pruning Ratio
+ pruning_ratio: float = 0.0
+ # Output data format converted
+ output_format: str = "ply"
+ # Path to the Mip-NeRF 360 dataset
+ data_dir: str = "data"
+ # Downsample factor for the dataset
+ data_factor: int = 4
+ # Directory to save results
+ result_dir: str = "./results"
+ # Every N images there is a test image
+ test_every: int = 8
+ # Random crop size for training (experimental)
+ patch_size: Optional[int] = None
+ # A global scaler that applies to the scene size related parameters
+ global_scale: float = 1.0
+ # Normalize the world space
+ normalize_world_space: bool = False
+ # Camera model
+ camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole"
+
+ # Port for the viewer server
+ port: int = 8080
+
+ # Batch size for training. Learning rates are scaled automatically
+ batch_size: int = 1
+ # A global factor to scale the number of training steps
+ steps_scaler: float = 1.0
+
+ # Number of training steps
+ max_steps: int = 10
+ # Steps to evaluate the model
+ eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
+ # Steps to save the model
+ save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
+
+ # Initialization strategy
+ init_type: str = "sfm"
+ # Initial number of GSs. Ignored if using sfm
+ init_num_pts: int = 100_000
+ # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm
+ init_extent: float = 3.0
+ # Degree of spherical harmonics
+ sh_degree: int = 3
+ # Turn on another SH degree every this steps
+ sh_degree_interval: int = 1000
+ # Initial opacity of GS
+ init_opa: float = 0.1
+ # Initial scale of GS
+ init_scale: float = 1.0
+ # Weight for SSIM loss
+ ssim_lambda: float = 0.2
+
+ # Near plane clipping distance
+ near_plane: float = 0.01
+ # Far plane clipping distance
+ far_plane: float = 1e10
+
+ # Strategy for GS densification
+ strategy: Union[DefaultStrategy, MCMCStrategy] = field(
+ default_factory=DefaultStrategy
+ )
+    # Use packed mode for rasterization; uses less memory but is slightly slower.
+ packed: bool = False
+ # Use sparse gradients for optimization. (experimental)
+ sparse_grad: bool = False
+ # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
+ antialiased: bool = False
+
+ # Opacity regularization
+ opacity_reg: float = 0.0
+ # Scale regularization
+ scale_reg: float = 0.0
+
+ # Enable camera optimization.
+ pose_opt: bool = False
+ # Learning rate for camera optimization
+ pose_opt_lr: float = 1e-5
+ # Regularization for camera optimization as weight decay
+ pose_opt_reg: float = 1e-6
+ # Add noise to camera extrinsics. This is only to test the camera pose optimization.
+ pose_noise: float = 0.0
+
+ # Enable appearance optimization. (experimental)
+ app_opt: bool = False
+ # Appearance embedding dimension
+ app_embed_dim: int = 16
+ # Learning rate for appearance optimization
+ app_opt_lr: float = 1e-3
+ # Regularization for appearance optimization as weight decay
+ app_opt_reg: float = 1e-6
+
+ # Enable bilateral grid. (experimental)
+ use_bilateral_grid: bool = False
+ # Shape of the bilateral grid (X, Y, W)
+ bilateral_grid_shape: Tuple[int, int, int] = (16, 16, 8)
+
+ # Enable depth loss. (experimental)
+ depth_loss: bool = True
+
+ lpips_net: Literal["vgg", "alex"] = "alex"
+
+ def adjust_steps(self, factor: float):
+ self.eval_steps = [int(i * factor) for i in self.eval_steps]
+ self.save_steps = [int(i * factor) for i in self.save_steps]
+ self.max_steps = int(self.max_steps * factor)
+ self.sh_degree_interval = int(self.sh_degree_interval * factor)
+
+ strategy = self.strategy
+ if isinstance(strategy, DefaultStrategy):
+ strategy.refine_start_iter = int(strategy.refine_start_iter * factor)
+ strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
+ strategy.reset_every = int(strategy.reset_every * factor)
+ strategy.refine_every = int(strategy.refine_every * factor)
+ elif isinstance(strategy, MCMCStrategy):
+ strategy.refine_start_iter = int(strategy.refine_start_iter * factor)
+ strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
+ strategy.refine_every = int(strategy.refine_every * factor)
+ else:
+ assert_never(strategy)
+
+
+
+
+class Runner:
+ """Engine for training and testing."""
+
+ def __init__(
+ self, local_rank: int, world_rank: int, world_size: int, cfg: Config
+ ) -> None:
+ set_random_seed(42 + local_rank)
+
+ self.cfg = cfg
+ self.world_rank = world_rank
+ self.local_rank = local_rank
+ self.world_size = world_size
+ self.device = f"cuda:{local_rank}"
+
+ # Where to dump results.
+ os.makedirs(cfg.result_dir, exist_ok=True)
+
+ # Setup output directories.
+ self.ckpt_dir = f"{cfg.result_dir}"
+ os.makedirs(self.ckpt_dir, exist_ok=True)
+ self.stats_dir = f"{cfg.result_dir}/stats"
+ os.makedirs(self.stats_dir, exist_ok=True)
+ self.render_dir = f"{cfg.result_dir}/renders"
+ os.makedirs(self.render_dir, exist_ok=True)
+
+ # Tensorboard
+ self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb")
+
+ # Load data: Training data should contain initial points and colors.
+ ext = os.path.splitext(cfg.ckpt[0])[1].lower()
+ if ext == ".ckpt" or ext == ".ply" :
+ print("this is a ckpt or ply file ")
+ self.parser = NerfParser(
+ data_dir=cfg.data_dir,
+ factor=cfg.data_factor,
+ normalize=cfg.normalize_world_space,
+ test_every=cfg.test_every,
+ )
+ self.trainset = NerfDataset(
+ self.parser,
+ split="train",
+ patch_size=cfg.patch_size,
+ load_depths=cfg.depth_loss,
+ )
+ elif ext == ".pt":
+ print("this is a pt file")
+ self.parser = Parser(
+ data_dir=cfg.data_dir,
+ factor=cfg.data_factor,
+ normalize=cfg.normalize_world_space,
+ test_every=cfg.test_every,
+ )
+ self.trainset = Dataset(
+ self.parser,
+ split="train",
+ patch_size=cfg.patch_size,
+ load_depths=cfg.depth_loss,
+ )
+ else:
+ msg ="Invalid Model type. Use .pt, .ckpt or .ply files."
+ raise TypeError(msg)
+
+ self.valset = Dataset(self.parser, split="val")
+ self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale
+
+ # Model
+ feature_dim = 32 if cfg.app_opt else None
+ self.splats, self.optimizers = self.create_splats_with_optimizers(
+ scene_scale=self.scene_scale,
+ sparse_grad=cfg.sparse_grad,
+ batch_size=cfg.batch_size,
+ device=self.device,
+ world_size=world_size,
+ cfg=self.cfg,
+ )
+ print("Model initialized. Number of GS:", len(self.splats["means"]))
+
+ # Densification Strategy
+ self.cfg.strategy.check_sanity(self.splats, self.optimizers)
+
+ if isinstance(self.cfg.strategy, DefaultStrategy):
+ self.strategy_state = self.cfg.strategy.initialize_state(
+ scene_scale=self.scene_scale
+ )
+ elif isinstance(self.cfg.strategy, MCMCStrategy):
+ self.strategy_state = self.cfg.strategy.initialize_state()
+ else:
+ assert_never(self.cfg.strategy)
+
+ # Compression Strategy
+ self.compression_method = None
+ if cfg.compression is not None:
+ if cfg.compression == "png":
+ self.compression_method = PngCompression()
+ else:
+ raise ValueError(f"Unknown compression strategy: {cfg.compression}")
+
+ self.pose_optimizers = []
+ if cfg.pose_opt:
+ self.pose_adjust = CameraOptModule(len(self.trainset)).to(self.device)
+ self.pose_adjust.zero_init()
+ self.pose_optimizers = [
+ torch.optim.Adam(
+ self.pose_adjust.parameters(),
+ lr=cfg.pose_opt_lr * math.sqrt(cfg.batch_size),
+ weight_decay=cfg.pose_opt_reg,
+ )
+ ]
+ if world_size > 1:
+ self.pose_adjust = DDP(self.pose_adjust)
+
+ if cfg.pose_noise > 0.0:
+ self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device)
+ self.pose_perturb.random_init(cfg.pose_noise)
+ if world_size > 1:
+ self.pose_perturb = DDP(self.pose_perturb)
+
+ self.app_optimizers = []
+ if cfg.app_opt:
+ assert feature_dim is not None
+ self.app_module = AppearanceOptModule(
+ len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree
+ ).to(self.device)
+ # initialize the last layer to be zero so that the initial output is zero.
+ torch.nn.init.zeros_(self.app_module.color_head[-1].weight)
+ torch.nn.init.zeros_(self.app_module.color_head[-1].bias)
+ self.app_optimizers = [
+ torch.optim.Adam(
+ self.app_module.embeds.parameters(),
+ lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0,
+ weight_decay=cfg.app_opt_reg,
+ ),
+ torch.optim.Adam(
+ self.app_module.color_head.parameters(),
+ lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size),
+ ),
+ ]
+ if world_size > 1:
+ self.app_module = DDP(self.app_module)
+
+ self.bil_grid_optimizers = []
+ if cfg.use_bilateral_grid:
+ self.bil_grids = BilateralGrid(
+ len(self.trainset),
+ grid_X=cfg.bilateral_grid_shape[0],
+ grid_Y=cfg.bilateral_grid_shape[1],
+ grid_W=cfg.bilateral_grid_shape[2],
+ ).to(self.device)
+ self.bil_grid_optimizers = [
+ torch.optim.Adam(
+ self.bil_grids.parameters(),
+ lr=2e-3 * math.sqrt(cfg.batch_size),
+ eps=1e-15,
+ ),
+ ]
+
+ # Losses & Metrics.
+ self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device)
+ self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device)
+
+ if cfg.lpips_net == "alex":
+ self.lpips = LearnedPerceptualImagePatchSimilarity(
+ net_type="alex", normalize=True
+ ).to(self.device)
+ elif cfg.lpips_net == "vgg":
+ # The 3DGS official repo uses lpips vgg, which is equivalent with the following:
+ self.lpips = LearnedPerceptualImagePatchSimilarity(
+ net_type="vgg", normalize=False
+ ).to(self.device)
+ else:
+ raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}")
+
+ # Viewer
+ if not self.cfg.disable_viewer:
+ self.server = viser.ViserServer(port=cfg.port, verbose=False)
+ self.viewer = nerfview.Viewer(
+ server=self.server,
+ render_fn=self._viewer_render_fn,
+ mode="training",
+ )
+
+ def create_splats_with_optimizers(
+ self,
+ scene_scale: float = 1.0,
+ sparse_grad: bool = False,
+ batch_size: int = 1,
+ device: str = "cuda",
+ world_size: int = 1,
+        cfg: Optional[Config] = None,
+ ) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]:
+
+
+ self.steps, splats = load_splats(cfg.ckpt[0], device)
+
+ # Convert to ParameterDict on the correct device
+ splats = torch.nn.ParameterDict(splats).to(device)
+
+
+        # Learning rates are not stored in the checkpoint, so they are defined here
+        # using the gsplat defaults.
+ default_lrs = {
+ "means": 1.6e-4 * scene_scale,
+ "scales": 5e-3,
+ "quats": 1e-3,
+ "opacities": 5e-2,
+ "sh0": 2.5e-3,
+ "shN": 2.5e-3 / 20,
+ "features": 2.5e-3,
+ "colors": 2.5e-3,
+ }
+
+ BS = batch_size * world_size
+ optimizers = {
+ name: (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)(
+ [
+ {
+ "params": splats[name],
+ "lr": default_lrs[name] * math.sqrt(BS),
+ "name": name,
+ }
+ ],
+ eps=1e-15 / math.sqrt(BS),
+ betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)),
+ )
+ for name in splats.keys()
+ if name in default_lrs
+ }
+
+ return splats, optimizers
+
+
+ def reinit_optimizers(self):
+ """Reinitialize optimizers after pruning Gaussians."""
+ cfg = self.cfg
+ BS = cfg.batch_size * self.world_size
+
+ # Recreate the optimizer dictionary with new parameters
+ new_optimizers = {}
+ for name, param in self.splats.items():
+ lr = {
+ "means": 1.6e-4 * self.scene_scale,
+ "scales": 5e-3,
+ "quats": 1e-3,
+ "opacities": 5e-2,
+ "sh0": 2.5e-3,
+ "shN": 2.5e-3 / 20,
+ "features": 2.5e-3,
+ "colors": 2.5e-3,
+ }[name]
+
+ betas = (
+ 1 - BS * (1 - 0.9),
+ 1 - BS * (1 - 0.999),
+ )
+
+ new_optimizers[name] = torch.optim.Adam(
+ [{"params": param, "lr": lr * math.sqrt(BS), "name": name}],
+ eps=1e-15 / math.sqrt(BS),
+ betas=betas,
+ )
+
+ # Replace old optimizers with new ones
+ self.optimizers = new_optimizers
+
+
+ def prune_gaussians(self, prune_ratio: float, scores: torch.Tensor):
+ """Prune Gaussians based on score thresholding."""
+        # Reduce the per-axis (x, y, z) scores to a single score per Gaussian.
+        reduced_scores = scores.mean(dim=1)
+
+        num_prune = int(len(reduced_scores) * prune_ratio)
+        _, idx = torch.topk(reduced_scores, k=num_prune, largest=False)
+        mask = torch.ones_like(reduced_scores, dtype=torch.bool)
+        mask[idx] = False
+
+        # `remove` expects a mask of Gaussians to delete, hence the inversion.
+        remove(
+            params=self.splats,
+            optimizers=self.optimizers,
+            state=self.strategy_state,
+            mask=~mask,
+        )
+
+
+
+ def save_tensors_side_by_side(self, rendered_tensor, target_tensor, filename="depth_debug/depth_comparison.png"):
+ """
+ Saves two PyTorch tensors (e.g., depth maps) as side-by-side PNG images.
+ If the file exists, appends a number to avoid overwriting (e.g., _1, _2, ...).
+
+ Args:
+ rendered_tensor (torch.Tensor): Predicted tensor (e.g., rendered depth).
+ target_tensor (torch.Tensor): Ground truth tensor (e.g., image depth).
+ filename (str): Base output filename (will be modified if it exists).
+ """
+ def tensor_to_image(tensor):
+ # Ensure tensor is on CPU and detach gradients
+ tensor = tensor.cpu().detach()
+
+ # Remove batch dimension if present (e.g., [1, H, W, 1] -> [H, W, 1])
+ while tensor.dim() > 3 and tensor.shape[0] == 1:
+ tensor = tensor.squeeze(0)
+
+ # At this point, should be [H, W] or [H, W, 1]
+ if tensor.dim() == 3 and tensor.shape[-1] == 1:
+ tensor = tensor.squeeze(-1) # Remove channel dim
+
+ # Now it should be [H, W]
+ if tensor.dim() != 2:
+ raise ValueError(f"Unexpected tensor shape after squeezing: {tensor.shape}")
+
+ # Normalize to [0, 1]
+ tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min())
+
+ # Convert to uint8 [0, 255]
+ image = (tensor.numpy() * 255).astype(np.uint8)
+
+ return Image.fromarray(image, mode='L'), image # L = grayscale
+
+ # Convert both tensors to PIL images
+ rendered_img, rendered_image = tensor_to_image(rendered_tensor)
+ target_img, target_image = tensor_to_image(target_tensor)
+
+
+ # Create a new image that combines them horizontally
+ total_width = rendered_img.width + target_img.width
+ max_height = max(rendered_img.height, target_img.height)
+
+ combined_img = Image.new('L', (total_width, max_height)) # 'L' for grayscale
+ combined_img.paste(rendered_img, (0, 0))
+ combined_img.paste(target_img, (rendered_img.width, 0))
+
+ # Save to disk
+ combined_img.save(filename)
+
+ return rendered_image, target_image
+
+
+ @torch.enable_grad()
+ def score_func(
+ self,
+ viewpoint_cam: Dict[str, torch.Tensor],
+ scores: torch.Tensor,
+ mask_views: torch.Tensor
+ ) -> None:
+
+        # Camera matrices arrive batched from the DataLoader (batch size 1).
+        camtoworld = viewpoint_cam["camtoworld"].to(self.device)  # shape: [1, 4, 4]
+        K = viewpoint_cam["K"].to(self.device)  # shape: [1, 3, 3]
+ height, width = viewpoint_cam["image"].shape[1:3]
+
+ # Forward pass
+ render, _, info = self.rasterize_splats(
+            camtoworlds=camtoworld,  # already [1, 4, 4]
+            Ks=K,  # already [1, 3, 3]
+ width=width,
+ height=height,
+ sh_degree=self.cfg.sh_degree,
+ image_ids=viewpoint_cam["image_id"].to(self.device)[None],
+ render_mode="RGB+ED"
+ )
+
+ # Compute loss
+ rendered_depth = render[..., 3:4]
+
+ image_depth = self.trainset[viewpoint_cam["image_id"].item()]["depth"].to(self.device)[None].unsqueeze(-1)
+
+ # rendered_image, target_image = self.save_tensors_side_by_side(rendered_np, image_np, f"depth_debug/00{self.trainset.indices[viewpoint_cam['image_id'].item()]}.png")
+
+ l1loss = F.l1_loss(rendered_depth, image_depth)
+ ssimloss = 1.0 - fused_ssim(
+ rendered_depth.permute(0, 3, 1, 2),
+ image_depth.permute(0, 3, 1, 2),
+ padding="valid"
+ )
+ loss = l1loss * (1.0 - self.cfg.ssim_lambda) + ssimloss * self.cfg.ssim_lambda
+
+ # Backward pass
+ loss.backward()
+
+ mask_views += self.splats["means"].grad.abs().squeeze() > 0.001
+
+        # Accumulate the gradient magnitude on the Gaussian means as the score.
+        with torch.no_grad():
+            scores += self.splats["means"].grad.abs().squeeze()
+
+
+ @torch.no_grad()
+ def prune(self, prune_ratio: float):
+ print("Running pruning...")
+ scores = torch.zeros_like(self.splats["means"])
+ mask_views = torch.zeros_like(self.splats["means"])
+
+ trainloader = torch.utils.data.DataLoader(
+ self.trainset,
+ batch_size=1,
+ shuffle=False,
+ num_workers=1,
+ pin_memory=True,
+ )
+
+        pbar = tqdm.tqdm(trainloader, desc="Computing pruning scores")
+        for data in pbar:
+            self.score_func(data, scores, mask_views)
+
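+        # Average each Gaussian's accumulated gradient over the number of views
+        # in which it crossed the gradient threshold above.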
+ scores = scores / (mask_views + 1e-8)
+
+ self.prune_gaussians(prune_ratio, scores)
+
+
+
+ def rasterize_splats(
+ self,
+ camtoworlds: Tensor,
+ Ks: Tensor,
+ width: int,
+ height: int,
+ masks: Optional[Tensor] = None,
+ render_mode: Optional[str] = "RGB",
+ **kwargs,
+ ) -> Tuple[Tensor, Tensor, Dict]:
+ means = self.splats["means"] # [N, 3]
+ quats = self.splats["quats"] # [N, 4]
+ scales = torch.exp(self.splats["scales"]) # [N, 3]
+ opacities = torch.sigmoid(self.splats["opacities"]) # [N,]
+
+ image_ids = kwargs.pop("image_ids", None)
+ if self.cfg.app_opt:
+ colors = self.app_module(
+ features=self.splats["features"],
+ embed_ids=image_ids,
+ dirs=means[None, :, :] - camtoworlds[:, None, :3, 3],
+ sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree),
+ )
+ colors = colors + self.splats["colors"]
+ colors = torch.sigmoid(colors)
+ else:
+ colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3]
+
+ rasterize_mode = "antialiased" if self.cfg.antialiased else "classic"
+
+ render_colors, render_alphas, info = rasterization(
+ means=means,
+ quats=quats,
+ scales=scales,
+ opacities=opacities,
+ colors=colors,
+ viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4]
+ Ks=Ks, # [C, 3, 3]
+ width=width,
+ height=height,
+ packed=self.cfg.packed,
+ absgrad=(
+ self.cfg.strategy.absgrad
+ if isinstance(self.cfg.strategy, DefaultStrategy)
+ else False
+ ),
+ sparse_grad=self.cfg.sparse_grad,
+ rasterize_mode=rasterize_mode,
+ distributed=self.world_size > 1,
+ camera_model=self.cfg.camera_model,
+ render_mode=render_mode,
+ **kwargs,
+ )
+ if masks is not None:
+ render_colors[~masks] = 0
+ return render_colors, render_alphas, info
+
+
+ @torch.no_grad()
+ def render_traj(self, step: int):
+ """Entry for trajectory rendering."""
+ print("Running trajectory rendering...")
+ cfg = self.cfg
+ device = self.device
+
+ camtoworlds_all = self.parser.camtoworlds[5:-5]
+ if cfg.render_traj_path == "interp":
+ camtoworlds_all = generate_interpolated_path(
+ camtoworlds_all, 1
+ ) # [N, 3, 4]
+ elif cfg.render_traj_path == "ellipse":
+ height = camtoworlds_all[:, 2, 3].mean()
+ camtoworlds_all = generate_ellipse_path_z(
+ camtoworlds_all, height=height
+ ) # [N, 3, 4]
+ elif cfg.render_traj_path == "spiral":
+ camtoworlds_all = generate_spiral_path(
+ camtoworlds_all,
+ bounds=self.parser.bounds * self.scene_scale,
+ spiral_scale_r=self.parser.extconf["spiral_radius_scale"],
+ )
+ else:
+ raise ValueError(
+ f"Render trajectory type not supported: {cfg.render_traj_path}"
+ )
+
+ camtoworlds_all = np.concatenate(
+ [
+ camtoworlds_all,
+ np.repeat(
+ np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds_all), axis=0
+ ),
+ ],
+ axis=1,
+ ) # [N, 4, 4]
+
+ camtoworlds_all = torch.from_numpy(camtoworlds_all).float().to(device)
+ K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device)
+ width, height = list(self.parser.imsize_dict.values())[0]
+
+ # save to video
+ video_dir = f"{cfg.result_dir}/videos"
+ os.makedirs(video_dir, exist_ok=True)
+ writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30)
+ for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"):
+ camtoworlds = camtoworlds_all[i : i + 1]
+ Ks = K[None]
+
+ renders, _, _ = self.rasterize_splats(
+ camtoworlds=camtoworlds,
+ Ks=Ks,
+ width=width,
+ height=height,
+ sh_degree=cfg.sh_degree,
+ near_plane=cfg.near_plane,
+ far_plane=cfg.far_plane,
+ render_mode="RGB+D",
+ ) # [1, H, W, 4]
+ colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3]
+ depths = renders[..., 3:4] # [1, H, W, 1]
+ depths = (depths - depths.min()) / (depths.max() - depths.min())
+ canvas_list = [colors, depths.repeat(1, 1, 1, 3)]
+
+ # write images
+ canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy()
+ canvas = (canvas * 255).astype(np.uint8)
+ writer.append_data(canvas)
+ writer.close()
+ print(f"Video saved to {video_dir}/traj_{step}.mp4")
+
+ @torch.no_grad()
+ def run_compression(self, step: int):
+ """Entry for running compression."""
+ print("Running compression...")
+ world_rank = self.world_rank
+
+ compress_dir = f"{cfg.result_dir}/compression/rank{world_rank}"
+ os.makedirs(compress_dir, exist_ok=True)
+
+ self.compression_method.compress(compress_dir, self.splats)
+
+ # evaluate compression
+ splats_c = self.compression_method.decompress(compress_dir)
+ for k in splats_c.keys():
+ self.splats[k].data = splats_c[k].to(self.device)
+ self.eval(step=step, stage="compress")
+
+ @torch.no_grad()
+ def _viewer_render_fn(
+ self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int]
+ ):
+ """Callable function for the viewer."""
+ W, H = img_wh
+ c2w = camera_state.c2w
+ K = camera_state.get_K(img_wh)
+ c2w = torch.from_numpy(c2w).float().to(self.device)
+ K = torch.from_numpy(K).float().to(self.device)
+
+ render_colors, _, _ = self.rasterize_splats(
+ camtoworlds=c2w[None],
+ Ks=K[None],
+ width=W,
+ height=H,
+            sh_degree=self.cfg.sh_degree,  # activate all SH degrees
+ radius_clip=3.0, # skip GSs that have small image radius (in pixels)
+ ) # [1, H, W, 3]
+ return render_colors[0].cpu().numpy()
+
+
+def main(local_rank: int, world_rank, world_size: int, cfg: Config):
+ if world_size > 1 and not cfg.disable_viewer:
+ cfg.disable_viewer = True
+ if world_rank == 0:
+ print("Viewer is disabled in distributed training.")
+
+ if cfg.ckpt is not None:
+        # build the runner from the checkpoint and run pruning / export
+ runner = Runner(local_rank, world_rank, world_size, cfg)
+
+
+ if cfg.pruning_ratio != 0:
+ print("Hard pruning in progress...")
+ runner.prune(prune_ratio=cfg.pruning_ratio)
+ print("The size of gaussian is:", runner.splats["means"].shape)
+
+ # save checkpoint after hard pruning
+ step = runner.steps
+
+ name = os.path.splitext(os.path.basename(cfg.ckpt[0]))[0]
+
+ data = {"step": step, "splats": runner.splats.state_dict()}
+ if cfg.pose_opt:
+ if world_size > 1:
+ data["pose_adjust"] = runner.pose_adjust.module.state_dict()
+ else:
+ data["pose_adjust"] = runner.pose_adjust.state_dict()
+ if cfg.app_opt:
+ if world_size > 1:
+ data["app_module"] = runner.app_module.module.state_dict()
+ else:
+ data["app_module"] = runner.app_module.state_dict()
+ suffix = str(cfg.pruning_ratio).replace("0.", "")
+
+ print("output format", cfg.output_format)
+
+ save_splats(f"{runner.ckpt_dir}/{name}_depth_pruned_{suffix}", data, cfg.output_format)
+
+ else:
+ raise ValueError("ckpt cant be None")
+
+ if not cfg.disable_viewer:
+ print("Viewer running... Ctrl+C to exit.")
+ time.sleep(1000000)
+
+
+if __name__ == "__main__":
+ # Config objects we can choose between.
+ # Each is a tuple of (CLI description, config object).
+ configs = {
+ "default": (
+ "Gaussian splatting training using densification heuristics from the original paper.",
+ Config(
+ strategy=DefaultStrategy(verbose=True),
+ ),
+ ),
+ "mcmc": (
+ "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.",
+ Config(
+ init_opa=0.5,
+ init_scale=0.1,
+ opacity_reg=0.01,
+ scale_reg=0.01,
+ strategy=MCMCStrategy(verbose=True),
+ ),
+ ),
+ }
+ cfg = tyro.extras.overridable_config_cli(configs)
+ cfg.adjust_steps(cfg.steps_scaler)
+
+ # try import extra dependencies
+ if cfg.compression == "png":
+ try:
+ import plas
+ import torchpq
+        except ImportError:
+ raise ImportError(
+ "To use PNG compression, you need to install "
+ "torchpq (instruction at https://github.com/DeMoriarty/TorchPQ?tab=readme-ov-file#install) "
+ "and plas (via 'pip install git+https://github.com/fraunhoferhhi/PLAS.git') "
+ )
+
+ cli(main, cfg, verbose=True)
\ No newline at end of file
diff --git a/pruning_utils/colmap.py b/pruning_utils/colmap.py
new file mode 100644
index 0000000..9bf0ccb
--- /dev/null
+++ b/pruning_utils/colmap.py
@@ -0,0 +1,425 @@
+"""
+===========================================================================
+Unmodified code from gsplat examples
+Original source:
+https://github.com/nerfstudio-project/gsplat/blob/main/examples/datasets/colmap.py
+License: Apache License 2.0
+----------------------------------------------------------------------------
+This file was copied verbatim from the gsplat repository. No modifications were made.
+===========================================================================
+"""
+
+import os
+import json
+from typing import Any, Dict, List, Optional
+from typing_extensions import assert_never
+
+import cv2
+import imageio.v2 as imageio
+import numpy as np
+import torch
+from pycolmap import SceneManager
+
+from .normalize import (
+ align_principle_axes,
+ similarity_from_cameras,
+ transform_cameras,
+ transform_points,
+)
+
+
+def _get_rel_paths(path_dir: str) -> List[str]:
+ """Recursively get relative paths of files in a directory."""
+ paths = []
+ for dp, dn, fn in os.walk(path_dir):
+ for f in fn:
+ paths.append(os.path.relpath(os.path.join(dp, f), path_dir))
+ return paths
+
+
+class Parser:
+ """COLMAP parser."""
+
+ def __init__(
+ self,
+ data_dir: str,
+ factor: int = 1,
+ normalize: bool = False,
+ test_every: int = 8,
+ ):
+ self.data_dir = data_dir
+ self.factor = factor
+ self.normalize = normalize
+ self.test_every = test_every
+
+ colmap_dir = os.path.join(data_dir, "sparse/0/")
+ if not os.path.exists(colmap_dir):
+ colmap_dir = os.path.join(data_dir, "sparse")
+ assert os.path.exists(
+ colmap_dir
+ ), f"COLMAP directory {colmap_dir} does not exist."
+
+ manager = SceneManager(colmap_dir)
+ manager.load_cameras()
+ manager.load_images()
+ manager.load_points3D()
+
+ # Extract extrinsic matrices in world-to-camera format.
+ imdata = manager.images
+ w2c_mats = []
+ camera_ids = []
+ Ks_dict = dict()
+ params_dict = dict()
+ imsize_dict = dict() # width, height
+ mask_dict = dict()
+ bottom = np.array([0, 0, 0, 1]).reshape(1, 4)
+ for k in imdata:
+ im = imdata[k]
+ rot = im.R()
+ trans = im.tvec.reshape(3, 1)
+ w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0)
+ w2c_mats.append(w2c)
+
+ # support different camera intrinsics
+ camera_id = im.camera_id
+ camera_ids.append(camera_id)
+
+ # camera intrinsics
+ cam = manager.cameras[camera_id]
+ fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy
+ K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
+ K[:2, :] /= factor
+ Ks_dict[camera_id] = K
+
+ # Get distortion parameters.
+ type_ = cam.camera_type
+ if type_ == 0 or type_ == "SIMPLE_PINHOLE":
+ params = np.empty(0, dtype=np.float32)
+ camtype = "perspective"
+ elif type_ == 1 or type_ == "PINHOLE":
+ params = np.empty(0, dtype=np.float32)
+ camtype = "perspective"
+ if type_ == 2 or type_ == "SIMPLE_RADIAL":
+ params = np.array([cam.k1, 0.0, 0.0, 0.0], dtype=np.float32)
+ camtype = "perspective"
+ elif type_ == 3 or type_ == "RADIAL":
+ params = np.array([cam.k1, cam.k2, 0.0, 0.0], dtype=np.float32)
+ camtype = "perspective"
+ elif type_ == 4 or type_ == "OPENCV":
+ params = np.array([cam.k1, cam.k2, cam.p1, cam.p2], dtype=np.float32)
+ camtype = "perspective"
+ elif type_ == 5 or type_ == "OPENCV_FISHEYE":
+ params = np.array([cam.k1, cam.k2, cam.k3, cam.k4], dtype=np.float32)
+ camtype = "fisheye"
+ assert (
+ camtype == "perspective" or camtype == "fisheye"
+ ), f"Only perspective and fisheye cameras are supported, got {type_}"
+
+ params_dict[camera_id] = params
+ imsize_dict[camera_id] = (cam.width // factor, cam.height // factor)
+ mask_dict[camera_id] = None
+ print(
+ f"[Parser] {len(imdata)} images, taken by {len(set(camera_ids))} cameras."
+ )
+
+ if len(imdata) == 0:
+ raise ValueError("No images found in COLMAP.")
+ if not (type_ == 0 or type_ == 1):
+ print("Warning: COLMAP Camera is not PINHOLE. Images have distortion.")
+
+ w2c_mats = np.stack(w2c_mats, axis=0)
+
+ # Convert extrinsics to camera-to-world.
+ camtoworlds = np.linalg.inv(w2c_mats)
+
+ # Image names from COLMAP. No need for permuting the poses according to
+ # image names anymore.
+ image_names = [imdata[k].name for k in imdata]
+
+ # Previous Nerf results were generated with images sorted by filename,
+ # ensure metrics are reported on the same test set.
+ inds = np.argsort(image_names)
+ image_names = [image_names[i] for i in inds]
+ camtoworlds = camtoworlds[inds]
+ camera_ids = [camera_ids[i] for i in inds]
+
+ # Load extended metadata. Used by Bilarf dataset.
+ self.extconf = {
+ "spiral_radius_scale": 1.0,
+ "no_factor_suffix": False,
+ }
+ extconf_file = os.path.join(data_dir, "ext_metadata.json")
+ if os.path.exists(extconf_file):
+ with open(extconf_file) as f:
+ self.extconf.update(json.load(f))
+
+ # Load bounds if possible (only used in forward facing scenes).
+ self.bounds = np.array([0.01, 1.0])
+ posefile = os.path.join(data_dir, "poses_bounds.npy")
+ if os.path.exists(posefile):
+ self.bounds = np.load(posefile)[:, -2:]
+
+ # Load images.
+ if factor > 1 and not self.extconf["no_factor_suffix"]:
+ image_dir_suffix = f"_{factor}"
+ else:
+ image_dir_suffix = ""
+ colmap_image_dir = os.path.join(data_dir, "images")
+ image_dir = os.path.join(data_dir, "images" + image_dir_suffix)
+ for d in [image_dir, colmap_image_dir]:
+ if not os.path.exists(d):
+ raise ValueError(f"Image folder {d} does not exist.")
+
+ # Downsampled images may have different names vs images used for COLMAP,
+ # so we need to map between the two sorted lists of files.
+ colmap_files = sorted(_get_rel_paths(colmap_image_dir))
+ image_files = sorted(_get_rel_paths(image_dir))
+ colmap_to_image = dict(zip(colmap_files, image_files))
+ image_paths = [os.path.join(image_dir, colmap_to_image[f]) for f in image_names]
+
+ # 3D points and {image_name -> [point_idx]}
+ points = manager.points3D.astype(np.float32)
+ points_err = manager.point3D_errors.astype(np.float32)
+ points_rgb = manager.point3D_colors.astype(np.uint8)
+ point_indices = dict()
+
+ image_id_to_name = {v: k for k, v in manager.name_to_image_id.items()}
+ for point_id, data in manager.point3D_id_to_images.items():
+ for image_id, _ in data:
+ image_name = image_id_to_name[image_id]
+ point_idx = manager.point3D_id_to_point3D_idx[point_id]
+ point_indices.setdefault(image_name, []).append(point_idx)
+ point_indices = {
+ k: np.array(v).astype(np.int32) for k, v in point_indices.items()
+ }
+
+ # Normalize the world space.
+ if normalize:
+ T1 = similarity_from_cameras(camtoworlds)
+ camtoworlds = transform_cameras(T1, camtoworlds)
+ points = transform_points(T1, points)
+
+ T2 = align_principle_axes(points)
+ camtoworlds = transform_cameras(T2, camtoworlds)
+ points = transform_points(T2, points)
+
+ transform = T2 @ T1
+ else:
+ transform = np.eye(4)
+
+ self.image_names = image_names # List[str], (num_images,)
+ self.image_paths = image_paths # List[str], (num_images,)
+ self.camtoworlds = camtoworlds # np.ndarray, (num_images, 4, 4)
+ self.camera_ids = camera_ids # List[int], (num_images,)
+ self.Ks_dict = Ks_dict # Dict of camera_id -> K
+ self.params_dict = params_dict # Dict of camera_id -> params
+ self.imsize_dict = imsize_dict # Dict of camera_id -> (width, height)
+ self.mask_dict = mask_dict # Dict of camera_id -> mask
+ self.points = points # np.ndarray, (num_points, 3)
+ self.points_err = points_err # np.ndarray, (num_points,)
+ self.points_rgb = points_rgb # np.ndarray, (num_points, 3)
+ self.point_indices = point_indices # Dict[str, np.ndarray], image_name -> [M,]
+ self.transform = transform # np.ndarray, (4, 4)
+
+ # load one image to check the size. In the case of tanksandtemples dataset, the
+ # intrinsics stored in COLMAP corresponds to 2x upsampled images.
+ actual_image = imageio.imread(self.image_paths[0])[..., :3]
+ actual_height, actual_width = actual_image.shape[:2]
+ colmap_width, colmap_height = self.imsize_dict[self.camera_ids[0]]
+ s_height, s_width = actual_height / colmap_height, actual_width / colmap_width
+ for camera_id, K in self.Ks_dict.items():
+ K[0, :] *= s_width
+ K[1, :] *= s_height
+ self.Ks_dict[camera_id] = K
+ width, height = self.imsize_dict[camera_id]
+ self.imsize_dict[camera_id] = (int(width * s_width), int(height * s_height))
+
+ # undistortion
+ self.mapx_dict = dict()
+ self.mapy_dict = dict()
+ self.roi_undist_dict = dict()
+ for camera_id in self.params_dict.keys():
+ params = self.params_dict[camera_id]
+ if len(params) == 0:
+ continue # no distortion
+ assert camera_id in self.Ks_dict, f"Missing K for camera {camera_id}"
+ assert (
+ camera_id in self.params_dict
+ ), f"Missing params for camera {camera_id}"
+ K = self.Ks_dict[camera_id]
+ width, height = self.imsize_dict[camera_id]
+
+ if camtype == "perspective":
+ K_undist, roi_undist = cv2.getOptimalNewCameraMatrix(
+ K, params, (width, height), 0
+ )
+ mapx, mapy = cv2.initUndistortRectifyMap(
+ K, params, None, K_undist, (width, height), cv2.CV_32FC1
+ )
+ mask = None
+ elif camtype == "fisheye":
+ fx = K[0, 0]
+ fy = K[1, 1]
+ cx = K[0, 2]
+ cy = K[1, 2]
+ grid_x, grid_y = np.meshgrid(
+ np.arange(width, dtype=np.float32),
+ np.arange(height, dtype=np.float32),
+ indexing="xy",
+ )
+ x1 = (grid_x - cx) / fx
+ y1 = (grid_y - cy) / fy
+ theta = np.sqrt(x1**2 + y1**2)
+ r = (
+ 1.0
+ + params[0] * theta**2
+ + params[1] * theta**4
+ + params[2] * theta**6
+ + params[3] * theta**8
+ )
+ mapx = fx * x1 * r + width // 2
+ mapy = fy * y1 * r + height // 2
+
+ # Use mask to define ROI
+ mask = np.logical_and(
+ np.logical_and(mapx > 0, mapy > 0),
+ np.logical_and(mapx < width - 1, mapy < height - 1),
+ )
+ y_indices, x_indices = np.nonzero(mask)
+ y_min, y_max = y_indices.min(), y_indices.max() + 1
+ x_min, x_max = x_indices.min(), x_indices.max() + 1
+ mask = mask[y_min:y_max, x_min:x_max]
+ K_undist = K.copy()
+ K_undist[0, 2] -= x_min
+ K_undist[1, 2] -= y_min
+ roi_undist = [x_min, y_min, x_max - x_min, y_max - y_min]
+ else:
+ assert_never(camtype)
+
+ self.mapx_dict[camera_id] = mapx
+ self.mapy_dict[camera_id] = mapy
+ self.Ks_dict[camera_id] = K_undist
+ self.roi_undist_dict[camera_id] = roi_undist
+ self.imsize_dict[camera_id] = (roi_undist[2], roi_undist[3])
+ self.mask_dict[camera_id] = mask
+
+ # size of the scene measured by cameras
+ camera_locations = camtoworlds[:, :3, 3]
+ scene_center = np.mean(camera_locations, axis=0)
+ dists = np.linalg.norm(camera_locations - scene_center, axis=1)
+ self.scene_scale = np.max(dists)
+
+
+class Dataset:
+ """A simple dataset class."""
+
+ def __init__(
+ self,
+ parser: Parser,
+ split: str = "train",
+ patch_size: Optional[int] = None,
+ load_depths: bool = False,
+ ):
+ self.parser = parser
+ self.split = split
+ self.patch_size = patch_size
+ self.load_depths = load_depths
+ indices = np.arange(len(self.parser.image_names))
+ if split == "train":
+ self.indices = indices[indices % self.parser.test_every != 0]
+ else:
+ self.indices = indices[indices % self.parser.test_every == 0]
+
+ def __len__(self):
+ return len(self.indices)
+
+ def __getitem__(self, item: int) -> Dict[str, Any]:
+ index = self.indices[item]
+ image = imageio.imread(self.parser.image_paths[index])[..., :3]
+ camera_id = self.parser.camera_ids[index]
+ K = self.parser.Ks_dict[camera_id].copy() # undistorted K
+ params = self.parser.params_dict[camera_id]
+ camtoworlds = self.parser.camtoworlds[index]
+ mask = self.parser.mask_dict[camera_id]
+
+ if len(params) > 0:
+ # Images are distorted. Undistort them.
+ mapx, mapy = (
+ self.parser.mapx_dict[camera_id],
+ self.parser.mapy_dict[camera_id],
+ )
+ image = cv2.remap(image, mapx, mapy, cv2.INTER_LINEAR)
+ x, y, w, h = self.parser.roi_undist_dict[camera_id]
+ image = image[y : y + h, x : x + w]
+
+ if self.patch_size is not None:
+ # Random crop.
+ h, w = image.shape[:2]
+ x = np.random.randint(0, max(w - self.patch_size, 1))
+ y = np.random.randint(0, max(h - self.patch_size, 1))
+ image = image[y : y + self.patch_size, x : x + self.patch_size]
+ K[0, 2] -= x
+ K[1, 2] -= y
+
+ data = {
+ "K": torch.from_numpy(K).float(),
+ "camtoworld": torch.from_numpy(camtoworlds).float(),
+ "image": torch.from_numpy(image).float(),
+ "image_id": item, # the index of the image in the dataset
+ }
+ if mask is not None:
+ data["mask"] = torch.from_numpy(mask).bool()
+
+ if self.load_depths:
+ # projected points to image plane to get depths
+ worldtocams = np.linalg.inv(camtoworlds)
+ image_name = self.parser.image_names[index]
+ point_indices = self.parser.point_indices[image_name]
+ points_world = self.parser.points[point_indices]
+ points_cam = (worldtocams[:3, :3] @ points_world.T + worldtocams[:3, 3:4]).T
+ points_proj = (K @ points_cam.T).T
+ points = points_proj[:, :2] / points_proj[:, 2:3] # (M, 2)
+ depths = points_cam[:, 2] # (M,)
+ # filter out points outside the image
+ selector = (
+ (points[:, 0] >= 0)
+ & (points[:, 0] < image.shape[1])
+ & (points[:, 1] >= 0)
+ & (points[:, 1] < image.shape[0])
+ & (depths > 0)
+ )
+ points = points[selector]
+ depths = depths[selector]
+ data["points"] = torch.from_numpy(points).float()
+ data["depths"] = torch.from_numpy(depths).float()
+
+ return data
+
+
+if __name__ == "__main__":
+ import argparse
+
+ import imageio.v2 as imageio
+ import tqdm
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--data_dir", type=str, default="data/360_v2/garden")
+ parser.add_argument("--factor", type=int, default=4)
+ args = parser.parse_args()
+
+ # Parse COLMAP data.
+ parser = Parser(
+ data_dir=args.data_dir, factor=args.factor, normalize=True, test_every=8
+ )
+ dataset = Dataset(parser, split="train", load_depths=True)
+ print(f"Dataset: {len(dataset)} images.")
+
+ writer = imageio.get_writer("results/points.mp4", fps=30)
+ for data in tqdm.tqdm(dataset, desc="Plotting points"):
+ image = data["image"].numpy().astype(np.uint8)
+ points = data["points"].numpy()
+ depths = data["depths"].numpy()
+ for x, y in points:
+ cv2.circle(image, (int(x), int(y)), 2, (255, 0, 0), -1)
+ writer.append_data(image)
+ writer.close()
diff --git a/pruning_utils/lib_bilagrid.py b/pruning_utils/lib_bilagrid.py
new file mode 100644
index 0000000..1b10a20
--- /dev/null
+++ b/pruning_utils/lib_bilagrid.py
@@ -0,0 +1,573 @@
+# # Copyright 2024 Yuehao Wang (https://github.com/yuehaowang). This part of code is borrowed from ["Bilateral Guided Radiance Field Processing"](https://bilarfpro.github.io/).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This is a standalone PyTorch implementation of 3D bilateral grid and CP-decomposed 4D bilateral grid.
+To use this module, you can download the "lib_bilagrid.py" file and simply put it in your project directory.
+
+For the details, please check our research project: ["Bilateral Guided Radiance Field Processing"](https://bilarfpro.github.io/).
+
+#### Dependencies
+
+In addition to PyTorch and Numpy, please install [tensorly](https://github.com/tensorly/tensorly).
+We have tested this module on Python 3.9.18, PyTorch 2.0.1 (CUDA 11), tensorly 0.8.1, and Numpy 1.25.2.
+
+#### Overview
+
+- For bilateral guided training, you need to construct a `BilateralGrid` instance, which can hold multiple bilateral grids
+ for input views. Then, use `slice` function to obtain transformed RGB output and the corresponding affine transformations.
+
+- For bilateral guided finishing, you need to instantiate a `BilateralGridCP4D` object and use `slice4d`.
+
+#### Examples
+
+- Bilateral grid for approximating ISP:
+
+
+
+- Low-rank 4D bilateral grid for MR enhancement:
+
+
+
+
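+- A minimal slicing sketch (illustrative only; the shapes and values below are hypothetical and
+  not part of the original examples):
+
+```
+grids = BilateralGrid(num=3)                       # one 3D grid per view
+xy = torch.rand(1024, 2)                           # pixel coordinates in [0, 1]
+rgb = torch.rand(1024, 3)                          # input colors in [0, 1]
+grid_idx = torch.zeros(1024, 1, dtype=torch.long)  # every pixel comes from view 0
+out = slice(grids, xy, rgb, grid_idx)
+# out["rgb"]: (1024, 3); out["rgb_affine_mats"]: (1024, 3, 4)
+```
+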
+Below is the API reference.
+
+"""
+
+import tensorly as tl
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+tl.set_backend("pytorch")
+
+
+def color_correct(
+ img: torch.Tensor, ref: torch.Tensor, num_iters: int = 5, eps: float = 0.5 / 255
+) -> torch.Tensor:
+ """
+ Warp `img` to match the colors in `ref_img` using iterative color matching.
+
+ This function performs color correction by warping the colors of the input image
+ to match those of a reference image. It uses a least squares method to find a
+ transformation that maps the input image's colors to the reference image's colors.
+
+ The algorithm iteratively solves a system of linear equations, updating the set of
+ unsaturated pixels in each iteration. This approach helps handle non-linear color
+ transformations and reduces the impact of clipping.
+
+ Args:
+ img (torch.Tensor): Input image to be color corrected. Shape: [..., num_channels]
+ ref (torch.Tensor): Reference image to match colors. Shape: [..., num_channels]
+ num_iters (int, optional): Number of iterations for the color matching process.
+ Default is 5.
+ eps (float, optional): Small value to determine the range of unclipped pixels.
+ Default is 0.5 / 255.
+
+ Returns:
+ torch.Tensor: Color corrected image with the same shape as the input image.
+
+ Note:
+ - Both input and reference images should be in the range [0, 1].
+ - The function works with any number of channels, but typically used with 3 (RGB).
+ """
+ if img.shape[-1] != ref.shape[-1]:
+ raise ValueError(
+ f"img's {img.shape[-1]} and ref's {ref.shape[-1]} channels must match"
+ )
+ num_channels = img.shape[-1]
+ img_mat = img.reshape([-1, num_channels])
+ ref_mat = ref.reshape([-1, num_channels])
+
+ def is_unclipped(z):
+ return (z >= eps) & (z <= 1 - eps) # z \in [eps, 1-eps].
+
+ mask0 = is_unclipped(img_mat)
+ # Because the set of saturated pixels may change after solving for a
+ # transformation, we repeatedly solve a system `num_iters` times and update
+ # our estimate of which pixels are saturated.
+ for _ in range(num_iters):
+ # Construct the left hand side of a linear system that contains a quadratic
+ # expansion of each pixel of `img`.
+ a_mat = []
+ for c in range(num_channels):
+ a_mat.append(img_mat[:, c : (c + 1)] * img_mat[:, c:]) # Quadratic term.
+ a_mat.append(img_mat) # Linear term.
+ a_mat.append(torch.ones_like(img_mat[:, :1])) # Bias term.
+ a_mat = torch.cat(a_mat, dim=-1)
+ warp = []
+ for c in range(num_channels):
+ # Construct the right hand side of a linear system containing each color
+ # of `ref`.
+ b = ref_mat[:, c]
+ # Ignore rows of the linear system that were saturated in the input or are
+ # saturated in the current corrected color estimate.
+ mask = mask0[:, c] & is_unclipped(img_mat[:, c]) & is_unclipped(b)
+ ma_mat = torch.where(mask[:, None], a_mat, torch.zeros_like(a_mat))
+ mb = torch.where(mask, b, torch.zeros_like(b))
+ w = torch.linalg.lstsq(ma_mat, mb, rcond=-1)[0]
+ assert torch.all(torch.isfinite(w))
+ warp.append(w)
+ warp = torch.stack(warp, dim=-1)
+ # Apply the warp to update img_mat.
+ img_mat = torch.clip(torch.matmul(a_mat, warp), 0, 1)
+ corrected_img = torch.reshape(img_mat, img.shape)
+ return corrected_img
+
+
+def bilateral_grid_tv_loss(model, config):
+ """Computes total variations of bilateral grids."""
+ total_loss = 0.0
+
+ for bil_grids in model.bil_grids:
+ total_loss += config.bilgrid_tv_loss_mult * total_variation_loss(
+ bil_grids.grids
+ )
+
+ return total_loss
+
+
+def color_affine_transform(affine_mats, rgb):
+ """Applies color affine transformations.
+
+ Args:
+ affine_mats (torch.Tensor): Affine transformation matrices. Supported shape: $(..., 3, 4)$.
+ rgb (torch.Tensor): Input RGB values. Supported shape: $(..., 3)$.
+
+ Returns:
+ Output transformed colors of shape $(..., 3)$.
+ """
+ return (
+ torch.matmul(affine_mats[..., :3], rgb.unsqueeze(-1)).squeeze(-1)
+ + affine_mats[..., 3]
+ )
+
+
+def _num_tensor_elems(t):
+ return max(torch.prod(torch.tensor(t.size()[1:]).float()).item(), 1.0)
+
+
+def total_variation_loss(x): # noqa: F811
+ """Returns total variation on multi-dimensional tensors.
+
+ Args:
+ x (torch.Tensor): The input tensor with shape $(B, C, ...)$, where $B$ is the batch size and $C$ is the channel size.
+ """
+ batch_size = x.shape[0]
+ tv = 0
+ for i in range(2, len(x.shape)):
+ n_res = x.shape[i]
+ idx1 = torch.arange(1, n_res, device=x.device)
+ idx2 = torch.arange(0, n_res - 1, device=x.device)
+ x1 = x.index_select(i, idx1)
+ x2 = x.index_select(i, idx2)
+ count = _num_tensor_elems(x1)
+ tv += torch.pow((x1 - x2), 2).sum() / count
+ return tv / batch_size
+
+
+def slice(bil_grids, xy, rgb, grid_idx):
+ """Slices a batch of 3D bilateral grids by pixel coordinates `xy` and gray-scale guidances of pixel colors `rgb`.
+
+ Supports 2-D, 3-D, and 4-D input shapes. The first dimension of the input is the batch size
+ and the last dimension is 2 for `xy`, 3 for `rgb`, and 1 for `grid_idx`.
+
+ The return value is a dictionary containing the affine transformations `affine_mats` sliced from bilateral grids and
+    the output color `rgb_out` after applying the affine transformations.
+
+ In the 2-D input case, `xy` is a $(N, 2)$ tensor, `rgb` is a $(N, 3)$ tensor, and `grid_idx` is a $(N, 1)$ tensor.
+ Then `affine_mats[i]` can be obtained via slicing the bilateral grid indexed at `grid_idx[i]` by `xy[i, :]` and `rgb2gray(rgb[i, :])`.
+ For 3-D and 4-D input cases, the behavior of indexing bilateral grids and coordinates is the same with the 2-D case.
+
+ .. note::
+ This function can be regarded as a wrapper of `color_affine_transform` and `BilateralGrid` with a slight performance improvement.
+        When `grid_idx` contains a unique index, only a single bilateral grid will be used during the slicing. In this case, this function will not
+ perform tensor indexing to avoid data copy and extra memory
+ (see [this](https://discuss.pytorch.org/t/does-indexing-a-tensor-return-a-copy-of-it/164905)).
+
+ Args:
+ bil_grids (`BilateralGrid`): An instance of $N$ bilateral grids.
+ xy (torch.Tensor): The x-y coordinates of shape $(..., 2)$ in the range of $[0,1]$.
+ rgb (torch.Tensor): The RGB values of shape $(..., 3)$ for computing the guidance coordinates, ranging in $[0,1]$.
+ grid_idx (torch.Tensor): The indices of bilateral grids for each slicing. Shape: $(..., 1)$.
+
+ Returns:
+ A dictionary with keys and values as follows:
+ ```
+ {
+ "rgb": Transformed RGB colors. Shape: (..., 3),
+ "rgb_affine_mats": The sliced affine transformation matrices from bilateral grids. Shape: (..., 3, 4)
+ }
+ ```
+ """
+
+ sh_ = rgb.shape
+
+ grid_idx_unique = torch.unique(grid_idx)
+ if len(grid_idx_unique) == 1:
+ # All pixels are from a single view.
+ grid_idx = grid_idx_unique # (1,)
+ xy = xy.unsqueeze(0) # (1, ..., 2)
+ rgb = rgb.unsqueeze(0) # (1, ..., 3)
+ else:
+ # Pixels are randomly sampled from different views.
+ if len(grid_idx.shape) == 4:
+ grid_idx = grid_idx[:, 0, 0, 0] # (chunk_size,)
+ elif len(grid_idx.shape) == 3:
+ grid_idx = grid_idx[:, 0, 0] # (chunk_size,)
+ elif len(grid_idx.shape) == 2:
+ grid_idx = grid_idx[:, 0] # (chunk_size,)
+ else:
+ raise ValueError(
+ "The input to bilateral grid slicing is not supported yet."
+ )
+
+ affine_mats = bil_grids(xy, rgb, grid_idx)
+ rgb = color_affine_transform(affine_mats, rgb)
+
+ return {
+ "rgb": rgb.reshape(*sh_),
+ "rgb_affine_mats": affine_mats.reshape(
+ *sh_[:-1], affine_mats.shape[-2], affine_mats.shape[-1]
+ ),
+ }
+
+
+class BilateralGrid(nn.Module):
+ """Class for 3D bilateral grids.
+
+    Holds one or more bilateral grids.
+ """
+
+ def __init__(self, num, grid_X=16, grid_Y=16, grid_W=8):
+ """
+ Args:
+ num (int): The number of bilateral grids (i.e., the number of views).
+ grid_X (int): Defines grid width $W$.
+ grid_Y (int): Defines grid height $H$.
+ grid_W (int): Defines grid guidance dimension $L$.
+ """
+ super(BilateralGrid, self).__init__()
+
+ self.grid_width = grid_X
+ """Grid width. Type: int."""
+ self.grid_height = grid_Y
+ """Grid height. Type: int."""
+ self.grid_guidance = grid_W
+ """Grid guidance dimension. Type: int."""
+
+ # Initialize grids.
+ grid = self._init_identity_grid()
+ self.grids = nn.Parameter(grid.tile(num, 1, 1, 1, 1)) # (N, 12, L, H, W)
+ """ A 5-D tensor of shape $(N, 12, L, H, W)$."""
+
+ # Weights of BT601 RGB-to-gray.
+ self.register_buffer("rgb2gray_weight", torch.Tensor([[0.299, 0.587, 0.114]]))
+ self.rgb2gray = lambda rgb: (rgb @ self.rgb2gray_weight.T) * 2.0 - 1.0
+ """ A function that converts RGB to gray-scale guidance in $[-1, 1]$."""
+
+ def _init_identity_grid(self):
+ grid = torch.tensor(
+ [
+ 1.0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1.0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1.0,
+ 0,
+ ]
+ ).float()
+ grid = grid.repeat(
+ [self.grid_guidance * self.grid_height * self.grid_width, 1]
+ ) # (L * H * W, 12)
+ grid = grid.reshape(
+ 1, self.grid_guidance, self.grid_height, self.grid_width, -1
+ ) # (1, L, H, W, 12)
+ grid = grid.permute(0, 4, 1, 2, 3) # (1, 12, L, H, W)
+ return grid
+
+ def tv_loss(self):
+ """Computes and returns total variation loss on the bilateral grids."""
+ return total_variation_loss(self.grids)
+
+ def forward(self, grid_xy, rgb, idx=None):
+ """Bilateral grid slicing. Supports 2-D, 3-D, 4-D, and 5-D input.
+ For the 2-D, 3-D, and 4-D cases, please refer to `slice`.
+ For the 5-D cases, `idx` will be unused and the first dimension of `xy` should be
+ equal to the number of bilateral grids. Then this function becomes PyTorch's
+ [`F.grid_sample`](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html).
+
+ Args:
+ grid_xy (torch.Tensor): The x-y coordinates in the range of $[0,1]$.
+ rgb (torch.Tensor): The RGB values in the range of $[0,1]$.
+ idx (torch.Tensor): The bilateral grid indices.
+
+ Returns:
+ Sliced affine matrices of shape $(..., 3, 4)$.
+ """
+
+ grids = self.grids
+ input_ndims = len(grid_xy.shape)
+ assert len(rgb.shape) == input_ndims
+
+ if input_ndims > 1 and input_ndims < 5:
+ # Convert input into 5D
+ for i in range(5 - input_ndims):
+ grid_xy = grid_xy.unsqueeze(1)
+ rgb = rgb.unsqueeze(1)
+ assert idx is not None
+ elif input_ndims != 5:
+ raise ValueError(
+ "Bilateral grid slicing only takes either 2D, 3D, 4D and 5D inputs"
+ )
+
+ grids = self.grids
+ if idx is not None:
+ grids = grids[idx]
+ assert grids.shape[0] == grid_xy.shape[0]
+
+ # Generate slicing coordinates.
+ grid_xy = (grid_xy - 0.5) * 2 # Rescale to [-1, 1].
+ grid_z = self.rgb2gray(rgb)
+
+ # print(grid_xy.shape, grid_z.shape)
+ # exit()
+ grid_xyz = torch.cat([grid_xy, grid_z], dim=-1) # (N, m, h, w, 3)
+
+ affine_mats = F.grid_sample(
+ grids, grid_xyz, mode="bilinear", align_corners=True, padding_mode="border"
+ ) # (N, 12, m, h, w)
+ affine_mats = affine_mats.permute(0, 2, 3, 4, 1) # (N, m, h, w, 12)
+ affine_mats = affine_mats.reshape(
+ *affine_mats.shape[:-1], 3, 4
+ ) # (N, m, h, w, 3, 4)
+
+ for _ in range(5 - input_ndims):
+ affine_mats = affine_mats.squeeze(1)
+
+ return affine_mats
+
+
+def slice4d(bil_grid4d, xyz, rgb):
+ """Slices a 4D bilateral grid by point coordinates `xyz` and gray-scale guidances of radiance colors `rgb`.
+
+ Args:
+ bil_grid4d (`BilateralGridCP4D`): The input 4D bilateral grid.
+ xyz (torch.Tensor): The xyz coordinates with shape $(..., 3)$.
+ rgb (torch.Tensor): The RGB values with shape $(..., 3)$.
+
+ Returns:
+ A dictionary with keys and values as follows:
+ ```
+ {
+ "rgb": Transformed radiance RGB colors. Shape: (..., 3),
+ "rgb_affine_mats": The sliced affine transformation matrices from the 4D bilateral grid. Shape: (..., 3, 4)
+ }
+ ```
+ """
+
+ affine_mats = bil_grid4d(xyz, rgb)
+ rgb = color_affine_transform(affine_mats, rgb)
+
+ return {"rgb": rgb, "rgb_affine_mats": affine_mats}
+
+
+class _ScaledTanh(nn.Module):
+ def __init__(self, s=2.0):
+ super().__init__()
+ self.scaler = s
+
+ def forward(self, x):
+ return torch.tanh(self.scaler * x)
+
+
+class BilateralGridCP4D(nn.Module):
+ """Class for low-rank 4D bilateral grids."""
+
+ def __init__(
+ self,
+ grid_X=16,
+ grid_Y=16,
+ grid_Z=16,
+ grid_W=8,
+ rank=5,
+ learn_gray=True,
+ gray_mlp_width=8,
+ gray_mlp_depth=2,
+ init_noise_scale=1e-6,
+ bound=2.0,
+ ):
+ """
+ Args:
+ grid_X (int): Defines grid width.
+ grid_Y (int): Defines grid height.
+ grid_Z (int): Defines grid depth.
+ grid_W (int): Defines grid guidance dimension.
+ rank (int): Rank of the 4D bilateral grid.
+ learn_gray (bool): If True, an MLP will be learned to convert RGB colors to gray-scale guidances.
+ gray_mlp_width (int): The MLP width for learnable guidance.
+ gray_mlp_depth (int): The number of MLP layers for learnable guidance.
+ init_noise_scale (float): The noise scale of the initialized factors.
+ bound (float): The bound of the xyz coordinates.
+ """
+ super(BilateralGridCP4D, self).__init__()
+
+ self.grid_X = grid_X
+ """Grid width. Type: int."""
+ self.grid_Y = grid_Y
+ """Grid height. Type: int."""
+ self.grid_Z = grid_Z
+ """Grid depth. Type: int."""
+ self.grid_W = grid_W
+ """Grid guidance dimension. Type: int."""
+ self.rank = rank
+ """Rank of the 4D bilateral grid. Type: int."""
+ self.learn_gray = learn_gray
+ """Flags of learnable guidance is used. Type: bool."""
+ self.gray_mlp_width = gray_mlp_width
+ """The MLP width for learnable guidance. Type: int."""
+ self.gray_mlp_depth = gray_mlp_depth
+ """The MLP depth for learnable guidance. Type: int."""
+ self.init_noise_scale = init_noise_scale
+ """The noise scale of the initialized factors. Type: float."""
+ self.bound = bound
+ """The bound of the xyz coordinates. Type: float."""
+
+ self._init_cp_factors_parafac()
+
+ self.rgb2gray = None
+ """ A function that converts RGB to gray-scale guidances in $[-1, 1]$.
+ If `learn_gray` is True, this will be an MLP network."""
+
+ if self.learn_gray:
+
+ def rgb2gray_mlp_linear(layer):
+ return nn.Linear(
+ self.gray_mlp_width,
+ self.gray_mlp_width if layer < self.gray_mlp_depth - 1 else 1,
+ )
+
+ def rgb2gray_mlp_actfn(_):
+ return nn.ReLU(inplace=True)
+
+ self.rgb2gray = nn.Sequential(
+ *(
+ [nn.Linear(3, self.gray_mlp_width)]
+ + [
+ nn_module(layer)
+ for layer in range(1, self.gray_mlp_depth)
+ for nn_module in [rgb2gray_mlp_actfn, rgb2gray_mlp_linear]
+ ]
+ + [_ScaledTanh(2.0)]
+ )
+ )
+ else:
+ # Weights of BT601/BT470 RGB-to-gray.
+ self.register_buffer(
+ "rgb2gray_weight", torch.Tensor([[0.299, 0.587, 0.114]])
+ )
+ self.rgb2gray = lambda rgb: (rgb @ self.rgb2gray_weight.T) * 2.0 - 1.0
+
+ def _init_identity_grid(self):
+ grid = torch.tensor(
+ [
+ 1.0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1.0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1.0,
+ 0,
+ ]
+ ).float()
+ grid = grid.repeat([self.grid_W * self.grid_Z * self.grid_Y * self.grid_X, 1])
+ grid = grid.reshape(self.grid_W, self.grid_Z, self.grid_Y, self.grid_X, -1)
+ grid = grid.permute(4, 0, 1, 2, 3) # (12, grid_W, grid_Z, grid_Y, grid_X)
+ return grid
+
+ def _init_cp_factors_parafac(self):
+ # Initialize identity grids.
+ init_grids = self._init_identity_grid()
+ # Random noises are added to avoid singularity.
+ init_grids = torch.randn_like(init_grids) * self.init_noise_scale + init_grids
+ from tensorly.decomposition import parafac
+
+ # Initialize grid CP factors
+ _, facs = parafac(init_grids.clone().detach(), rank=self.rank)
+
+ self.num_facs = len(facs)
+
+ self.fac_0 = nn.Linear(facs[0].shape[0], facs[0].shape[1], bias=False)
+ self.fac_0.weight = nn.Parameter(facs[0]) # (12, rank)
+
+ for i in range(1, self.num_facs):
+ fac = facs[i].T # (rank, grid_size)
+ fac = fac.view(1, fac.shape[0], fac.shape[1], 1) # (1, rank, grid_size, 1)
+ self.register_buffer(f"fac_{i}_init", fac)
+
+ fac_resid = torch.zeros_like(fac)
+ self.register_parameter(f"fac_{i}", nn.Parameter(fac_resid))
+
+ def tv_loss(self):
+ """Computes and returns total variation loss on the factors of the low-rank 4D bilateral grids."""
+
+ total_loss = 0
+ for i in range(1, self.num_facs):
+ fac = self.get_parameter(f"fac_{i}")
+ total_loss += total_variation_loss(fac)
+
+ return total_loss
+
+ def forward(self, xyz, rgb):
+ """Low-rank 4D bilateral grid slicing.
+
+ Args:
+ xyz (torch.Tensor): The xyz coordinates with shape $(..., 3)$.
+ rgb (torch.Tensor): The corresponding RGB values with shape $(..., 3)$.
+
+ Returns:
+ Sliced affine matrices with shape $(..., 3, 4)$.
+ """
+ sh_ = xyz.shape
+ xyz = xyz.reshape(-1, 3) # flatten (N, 3)
+ rgb = rgb.reshape(-1, 3) # flatten (N, 3)
+
+ xyz = xyz / self.bound
+ assert self.rgb2gray is not None
+ gray = self.rgb2gray(rgb)
+ xyzw = torch.cat([xyz, gray], dim=-1) # (N, 4)
+ xyzw = xyzw.transpose(0, 1) # (4, N)
+ coords = torch.stack([torch.zeros_like(xyzw), xyzw], dim=-1) # (4, N, 2)
+ coords = coords.unsqueeze(1) # (4, 1, N, 2)
+
+ coef = 1.0
+ for i in range(1, self.num_facs):
+ fac = self.get_parameter(f"fac_{i}") + self.get_buffer(f"fac_{i}_init")
+ coef = coef * F.grid_sample(
+ fac, coords[[i - 1]], align_corners=True, padding_mode="border"
+ ) # [1, rank, 1, N]
+ coef = coef.squeeze([0, 2]).transpose(0, 1) # (N, rank) #type: ignore
+ mat = self.fac_0(coef)
+ return mat.reshape(*sh_[:-1], 3, 4)
diff --git a/pruning_utils/nerf.py b/pruning_utils/nerf.py
new file mode 100644
index 0000000..07e4693
--- /dev/null
+++ b/pruning_utils/nerf.py
@@ -0,0 +1,229 @@
+import os
+import json
+import torch
+from typing import Any, Dict, Optional
+import imageio.v2 as imageio
+import numpy as np
+import cv2
+
+from .normalize import (
+ align_principle_axes,
+ similarity_from_cameras,
+ transform_cameras,
+)
+
+
+def convert_opengl_to_colmap(R):
+    """Flip the y and z camera axes to convert rotations from the OpenGL/Blender
+    convention (y up, z back) used by transforms.json to the COLMAP/OpenCV
+    convention (y down, z forward)."""
+    M_flip = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])
+    return R @ M_flip
+
+
+class NerfParser:
+ """NeRF Studio parser (transforms.json)."""
+ def __init__(self, data_dir: str, factor: int = 1, normalize: bool = False, test_every: int = 8):
+ self.data_dir = data_dir
+ self.factor = factor
+ self.normalize = normalize
+ self.test_every = test_every
+
+ # Load transforms.json
+ transform_path = os.path.join(data_dir, "transforms.json")
+ if not os.path.exists(transform_path):
+ raise FileNotFoundError(f"Could not find transforms.json at {transform_path}")
+ with open(transform_path, "r") as f:
+ meta = json.load(f)
+
+ # Store camera poses and intrinsics
+ camtoworlds = []
+ image_paths = []
+ depth_paths = []
+
+ camera_ids = []
+ image_names = []
+
+ fl_x = meta["fl_x"]
+ fl_y = meta["fl_y"]
+ cx = meta["cx"]
+ cy = meta["cy"]
+ w = meta["w"]
+ h = meta["h"]
+ fx, fy = fl_x / factor, fl_y / factor
+ cx, cy = cx / factor, cy / factor
+ width, height = w // factor, h // factor
+
+ K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
+ cam_id = 0 # Single intrinsics shared across all images
+
+ for i, frame in enumerate(meta["frames"]):
+ c2w = np.array(frame["transform_matrix"], dtype=np.float32)
+ camtoworlds.append(c2w)
+
+ rel_image_path = frame["file_path"]
+ if not rel_image_path.lower().endswith((".png", ".jpg")):
+ rel_image_path += ".png" # default fallback
+
+ image_path = os.path.join(data_dir, rel_image_path)
+ if not os.path.exists(image_path):
+ raise FileNotFoundError(f"Image {image_path} not found.")
+ image_paths.append(image_path)
+ image_names.append(os.path.basename(rel_image_path))
+ camera_ids.append(cam_id)
+
+ rel_depth_path = frame.get("depth_file_path")
+ if rel_depth_path is None:
+ raise ValueError(f"No depth_file_path found for frame {i}")
+ if not rel_depth_path.lower().endswith((".png", ".jpg", ".exr")):
+ rel_depth_path += ".png"
+ depth_path = os.path.join(data_dir, rel_depth_path)
+ if not os.path.exists(depth_path):
+ raise FileNotFoundError(f"Depth image not found: {depth_path}")
+ depth_paths.append(depth_path)
+
+ camtoworlds = np.stack(camtoworlds, axis=0)
+ image_paths = np.array(image_paths)
+ image_names = np.array(image_names)
+ depth_paths = np.array(depth_paths)
+
+ # Normalize world space if requested
+ if normalize:
+ T1 = similarity_from_cameras(camtoworlds)
+ camtoworlds = transform_cameras(T1, camtoworlds)
+
+ T2 = align_principle_axes(camtoworlds[:, :3, 3])
+ camtoworlds = transform_cameras(T2, camtoworlds)
+
+ transform = T2 @ T1
+ else:
+ transform = np.eye(4)
+
+ self.image_names = list(image_names)
+ self.image_paths = list(image_paths)
+ self.depth_paths = list(depth_paths)
+ self.camtoworlds = camtoworlds
+ self.camera_ids = camera_ids
+ self.Ks_dict = {cam_id: K}
+ self.params_dict = {cam_id: np.array([], dtype=np.float32)} # no distortion
+ self.imsize_dict = {cam_id: (width, height)}
+ self.mask_dict = {cam_id: None}
+ self.points = np.zeros((0, 3), dtype=np.float32)
+ self.points_err = np.zeros((0,), dtype=np.float32)
+ self.points_rgb = np.zeros((0, 3), dtype=np.uint8)
+ self.point_indices = {} # Not available in NeRF Studio format
+ self.transform = transform
+
+ # Resize intrinsics if actual image differs from JSON
+ actual_image = imageio.imread(self.image_paths[0])[..., :3]
+ actual_height, actual_width = actual_image.shape[:2]
+ s_height, s_width = actual_height / height, actual_width / width
+ for camera_id, K in self.Ks_dict.items():
+ K[0, :] *= s_width
+ K[1, :] *= s_height
+ self.Ks_dict[camera_id] = K
+ self.imsize_dict[camera_id] = (int(width * s_width), int(height * s_height))
+
+ # NeRF Studio has no distortion, so no undistortion mapping needed
+ self.mapx_dict = {}
+ self.mapy_dict = {}
+ self.roi_undist_dict = {}
+
+ # Estimate scene scale from camera positions
+ camera_locations = self.camtoworlds[:, :3, 3]
+ scene_center = np.mean(camera_locations, axis=0)
+ dists = np.linalg.norm(camera_locations - scene_center, axis=1)
+ self.scene_scale = float(np.max(dists))
+
+
+ print(f"[Parser] Loaded {len(self.image_paths)} NeRF Studio frames.")
+
+ self.camtoworlds[:, :3, :3] = convert_opengl_to_colmap(self.camtoworlds[:, :3, :3])
+
+
+class NerfDataset:
+ """A simple dataset class."""
+
+ def __init__(
+ self,
+ parser: NerfParser,
+ split: str = "train",
+ patch_size: Optional[int] = None,
+ load_depths: bool = False,
+ ):
+ self.parser = parser
+ self.split = split
+ self.patch_size = patch_size
+ self.load_depths = load_depths
+ indices = np.arange(len(self.parser.image_names))
+ if split == "train":
+ self.indices = indices[indices % self.parser.test_every != 0]
+ else:
+ self.indices = indices[indices % self.parser.test_every == 0]
+
+
+ def __len__(self):
+ return len(self.indices)
+
+ def __getitem__(self, item: int) -> Dict[str, Any]:
+ index = self.indices[item]
+ image = imageio.imread(self.parser.image_paths[index])[..., :3]
+ camera_id = self.parser.camera_ids[index]
+ K = self.parser.Ks_dict[camera_id].copy() # undistorted K
+ params = self.parser.params_dict[camera_id]
+ camtoworlds = self.parser.camtoworlds[index]
+ mask = self.parser.mask_dict[camera_id]
+
+ if len(params) > 0:
+ # Images are distorted. Undistort them.
+ mapx, mapy = (
+ self.parser.mapx_dict[camera_id],
+ self.parser.mapy_dict[camera_id],
+ )
+ image = cv2.remap(image, mapx, mapy, cv2.INTER_LINEAR)
+ x, y, w, h = self.parser.roi_undist_dict[camera_id]
+ image = image[y : y + h, x : x + w]
+
+ if self.patch_size is not None:
+ # Random crop.
+ h, w = image.shape[:2]
+ x = np.random.randint(0, max(w - self.patch_size, 1))
+ y = np.random.randint(0, max(h - self.patch_size, 1))
+ image = image[y : y + self.patch_size, x : x + self.patch_size]
+ K[0, 2] -= x
+ K[1, 2] -= y
+
+ data = {
+ "K": torch.from_numpy(K).float(),
+ "camtoworld": torch.from_numpy(camtoworlds).float(),
+ "image": torch.from_numpy(image).float(),
+ "image_id": item, # the index of the image in the dataset
+ }
+ if mask is not None:
+ data["mask"] = torch.from_numpy(mask).bool()
+
+ if self.load_depths:
+ # Get the full depth path from the parser
+ depth_filename = self.parser.depth_paths[index]
+
+ if not os.path.exists(depth_filename):
+ raise FileNotFoundError(f"Depth file not found: {depth_filename}")
+
+ # Load depth
+ depth = imageio.imread(depth_filename).astype(np.float32)
+
+ # Apply patching if needed
+ if self.patch_size is not None:
+ depth = depth[y : y + self.patch_size, x : x + self.patch_size]
+
+ if depth.shape[:2] != image.shape[:2]:
+ raise ValueError(f"Depth shape {depth.shape} doesn't match image shape {image.shape}")
+
+ data["depth"] = torch.from_numpy(depth).float()
+
+ return data
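+
+# Example usage (a hypothetical sketch; the data directory is a placeholder and must contain a
+# Nerfstudio-style transforms.json with RGB and depth file paths):
+#
+#   parser = NerfParser(data_dir="path/to/dataset", factor=1, normalize=True, test_every=8)
+#   trainset = NerfDataset(parser, split="train", load_depths=True)
+#   sample = trainset[0]
+#   # sample["image"]: (H, W, 3) float tensor; sample["depth"]: per-pixel depth, typically (H, W);
+#   # sample["camtoworld"]: (4, 4) pose with the rotation flipped to the COLMAP convention.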
+
diff --git a/pruning_utils/normalize.py b/pruning_utils/normalize.py
new file mode 100644
index 0000000..e0c4b20
--- /dev/null
+++ b/pruning_utils/normalize.py
@@ -0,0 +1,154 @@
+"""
+===========================================================================
+Unmodified code from gsplat examples
+Original source:
+https://github.com/nerfstudio-project/gsplat/blob/main/examples/datasets/normalize.py
+License: Apache License 2.0
+----------------------------------------------------------------------------
+This file was copied verbatim from the gsplat repository. No modifications were made.
+===========================================================================
+"""
+
+import numpy as np
+
+
+def similarity_from_cameras(c2w, strict_scaling=False, center_method="focus"):
+ """
+ reference: nerf-factory
+ Get a similarity transform to normalize dataset
+ from c2w (OpenCV convention) cameras
+ :param c2w: (N, 4)
+ :return T (4,4) , scale (float)
+ """
+ t = c2w[:, :3, 3]
+ R = c2w[:, :3, :3]
+
+ # (1) Rotate the world so that z+ is the up axis
+ # we estimate the up axis by averaging the camera up axes
+ ups = np.sum(R * np.array([0, -1.0, 0]), axis=-1)
+ world_up = np.mean(ups, axis=0)
+ world_up /= np.linalg.norm(world_up)
+
+ up_camspace = np.array([0.0, -1.0, 0.0])
+ c = (up_camspace * world_up).sum()
+ cross = np.cross(world_up, up_camspace)
+ skew = np.array(
+ [
+ [0.0, -cross[2], cross[1]],
+ [cross[2], 0.0, -cross[0]],
+ [-cross[1], cross[0], 0.0],
+ ]
+ )
+ if c > -1:
+ R_align = np.eye(3) + skew + (skew @ skew) * 1 / (1 + c)
+ else:
+ # In the unlikely case the original data has y+ up axis,
+ # rotate 180-deg about x axis
+ R_align = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
+
+ # R_align = np.eye(3) # DEBUG
+ R = R_align @ R
+ fwds = np.sum(R * np.array([0, 0.0, 1.0]), axis=-1)
+ t = (R_align @ t[..., None])[..., 0]
+
+ # (2) Recenter the scene.
+ if center_method == "focus":
+ # find the closest point to the origin for each camera's center ray
+ nearest = t + (fwds * -t).sum(-1)[:, None] * fwds
+ translate = -np.median(nearest, axis=0)
+ elif center_method == "poses":
+ # use center of the camera positions
+ translate = -np.median(t, axis=0)
+ else:
+ raise ValueError(f"Unknown center_method {center_method}")
+
+ transform = np.eye(4)
+ transform[:3, 3] = translate
+ transform[:3, :3] = R_align
+
+ # (3) Rescale the scene using camera distances
+ scale_fn = np.max if strict_scaling else np.median
+ scale = 1.0 / scale_fn(np.linalg.norm(t + translate, axis=-1))
+ transform[:3, :] *= scale
+
+ return transform
+
+
+def align_principle_axes(point_cloud):
+ # Compute centroid
+ centroid = np.median(point_cloud, axis=0)
+
+ # Translate point cloud to centroid
+ translated_point_cloud = point_cloud - centroid
+
+ # Compute covariance matrix
+ covariance_matrix = np.cov(translated_point_cloud, rowvar=False)
+
+ # Compute eigenvectors and eigenvalues
+ eigenvalues, eigenvectors = np.linalg.eigh(covariance_matrix)
+
+ # Sort eigenvectors by eigenvalues (descending order) so that the z-axis
+ # is the principal axis with the smallest eigenvalue.
+ sort_indices = eigenvalues.argsort()[::-1]
+ eigenvectors = eigenvectors[:, sort_indices]
+
+ # Check orientation of eigenvectors. If the determinant of the eigenvectors is
+ # negative, then we need to flip the sign of one of the eigenvectors.
+ if np.linalg.det(eigenvectors) < 0:
+ eigenvectors[:, 0] *= -1
+
+ # Create rotation matrix
+ rotation_matrix = eigenvectors.T
+
+ # Create SE(3) matrix (4x4 transformation matrix)
+ transform = np.eye(4)
+ transform[:3, :3] = rotation_matrix
+ transform[:3, 3] = -rotation_matrix @ centroid
+
+ return transform
+
+
+def transform_points(matrix, points):
+ """Transform points using an SE(3) matrix.
+
+ Args:
+ matrix: 4x4 SE(3) matrix
+ points: Nx3 array of points
+
+ Returns:
+ Nx3 array of transformed points
+ """
+ assert matrix.shape == (4, 4)
+ assert len(points.shape) == 2 and points.shape[1] == 3
+ return points @ matrix[:3, :3].T + matrix[:3, 3]
+
+
+def transform_cameras(matrix, camtoworlds):
+ """Transform cameras using an SE(3) matrix.
+
+ Args:
+ matrix: 4x4 SE(3) matrix
+ camtoworlds: Nx4x4 array of camera-to-world matrices
+
+ Returns:
+ Nx4x4 array of transformed camera-to-world matrices
+ """
+ assert matrix.shape == (4, 4)
+ assert len(camtoworlds.shape) == 3 and camtoworlds.shape[1:] == (4, 4)
+ camtoworlds = np.einsum("nij, ki -> nkj", camtoworlds, matrix)
+ scaling = np.linalg.norm(camtoworlds[:, 0, :3], axis=1)
+ camtoworlds[:, :3, :3] = camtoworlds[:, :3, :3] / scaling[:, None, None]
+ return camtoworlds
+
+
+def normalize(camtoworlds, points=None):
+ T1 = similarity_from_cameras(camtoworlds)
+ camtoworlds = transform_cameras(T1, camtoworlds)
+ if points is not None:
+ points = transform_points(T1, points)
+ T2 = align_principle_axes(points)
+ camtoworlds = transform_cameras(T2, camtoworlds)
+ points = transform_points(T2, points)
+ return camtoworlds, points, T2 @ T1
+ else:
+ return camtoworlds, T1
diff --git a/pruning_utils/open_ply_pipeline.py b/pruning_utils/open_ply_pipeline.py
new file mode 100644
index 0000000..a6bd7d8
--- /dev/null
+++ b/pruning_utils/open_ply_pipeline.py
@@ -0,0 +1,232 @@
+import torch
+from pyntcloud import PyntCloud
+from typing import Dict, Union, Literal, Optional
+import os
+import pandas as pd
+
+def load_splats(
+ path: str,
+ device: Union[str, torch.device] = 'cpu'
+) -> Dict[str, torch.nn.Parameter]:
+ """
+    Load splat data from a .ply, .ckpt, or .pt file.
+
+    Args:
+        path (str): Path to the file (.ply, .ckpt, or .pt).
+        device (str): Device to load tensors on. Default is 'cpu'.
+
+    Returns:
+        step (int): Training step stored in the checkpoint (0 for .ply files, which carry no step).
+        splats (dict): Dictionary of torch.nn.Parameter containing the splat data.
+ """
+ ext = os.path.splitext(path)[1].lower()
+
+ if ext == '.ply':
+ print("loading ply file")
+ cloud = PyntCloud.from_file(path)
+ data = cloud.points
+
+ # Mapping from column names in the .ply to model parameter names
+ mapping = {
+ "_model.gauss_params.means": ["x", "y", "z"],
+ "_model.gauss_params.opacities": "opacity",
+ "_model.gauss_params.quats": ["rot_0", 'rot_1', 'rot_2', 'rot_3'],
+ "_model.gauss_params.scales": ["scale_0", "scale_1", "scale_2"],
+ "_model.gauss_params.features_dc": [f"f_dc_{i}" for i in range(3)],
+ "_model.gauss_params.features_rest": [f"f_rest_{i}" for i in range(45)]
+ }
+
+ ckpt = {"pipeline": {}}
+
+        for key, value in mapping.items():
+            # PyntCloud exposes the .ply data as a pandas DataFrame; indexing with a
+            # single column name or a list of columns both yield `.values` arrays.
+            ckpt["pipeline"][key] = torch.tensor(data[value].values, dtype=torch.float32)
+
+ # Reshape tensors
+ ckpt["pipeline"]["_model.gauss_params.means"] = ckpt["pipeline"]["_model.gauss_params.means"].reshape(-1, 3)
+ ckpt["pipeline"]["_model.gauss_params.opacities"] = ckpt["pipeline"]["_model.gauss_params.opacities"].reshape(-1)
+ ckpt["pipeline"]["_model.gauss_params.scales"] = ckpt["pipeline"]["_model.gauss_params.scales"].reshape(-1, 3)
+ ckpt["pipeline"]["_model.gauss_params.features_dc"] = ckpt["pipeline"]["_model.gauss_params.features_dc"].reshape(-1, 1, 3)
+ ckpt["pipeline"]["_model.gauss_params.features_rest"] = ckpt["pipeline"]["_model.gauss_params.features_rest"].reshape(-1, 15, 3)
+
+ # Mapping for renaming
+ rename_map = {
+ "_model.gauss_params.means": "means",
+ "_model.gauss_params.opacities": "opacities",
+ "_model.gauss_params.quats": "quats",
+ "_model.gauss_params.scales": "scales",
+ "_model.gauss_params.features_dc": "sh0",
+ "_model.gauss_params.features_rest": "shN"
+ }
+
+        step = 0  # .ply files do not store the training step
+
+ # Convert to splats dictionary
+ splats = {
+ new_key: torch.nn.Parameter(ckpt["pipeline"][old_key] if new_key == "sh0"
+ else ckpt["pipeline"][old_key].squeeze() if new_key == "opacities"
+ else ckpt["pipeline"][old_key])
+ for old_key, new_key in rename_map.items()
+ }
+
+ elif ext == '.ckpt':
+ print("loading ckpt file")
+ ckpt = torch.load(path, map_location=device)
+
+ # Mapping for renaming
+ rename_map = {
+ "_model.gauss_params.means": "means",
+ "_model.gauss_params.opacities": "opacities",
+ "_model.gauss_params.quats": "quats",
+ "_model.gauss_params.scales": "scales",
+ "_model.gauss_params.features_dc": "sh0",
+ "_model.gauss_params.features_rest": "shN"
+ }
+
+ step = ckpt["step"]
+
+ # Convert to splats dictionary
+ splats = {
+ new_key: torch.nn.Parameter(ckpt["pipeline"][old_key].unsqueeze(1) if new_key == "sh0"
+ else ckpt["pipeline"][old_key].squeeze() if new_key == "opacities"
+ else ckpt["pipeline"][old_key])
+ for old_key, new_key in rename_map.items()
+ }
+
+
+ elif ext == '.pt':
+ print("loading pt file")
+
+ # Load and concatenate splats from checkpoints
+ path = [path]
+ ckpts = [
+ torch.load(file, map_location=device, weights_only=True) for file in path
+ ]
+
+ param_names = ckpts[0]["splats"].keys()
+
+        step = ckpts[0]["step"]
+
+ splats = {
+ name: torch.nn.Parameter(
+ torch.cat([ckpt["splats"][name] for ckpt in ckpts], dim=0)
+ )
+ for name in param_names
+ }
+
+
+ else:
+ raise ValueError(f"Unsupported file extension: {ext}")
+
+
+ return step, splats
+
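+# Example usage (a hypothetical sketch; the checkpoint path is a placeholder):
+#
+#   step, splats = load_splats("checkpoints/model.ckpt")
+#   print(step, splats["means"].shape)               # training step and (N, 3) Gaussian centers
+#   print(splats["sh0"].shape, splats["shN"].shape)  # (N, 1, 3) and (N, 15, 3) SH coefficients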
+
+
+def save_splats(
+ path: str,
+ data: Dict[str, Union[Dict[str, torch.nn.Parameter], int]],
+ file_type: Literal['ply', 'pt', 'ckpt'] = 'ply',
+ step_value: Optional[int] = None
+) -> None:
+ """
+ Export splat data to .ply, .pt, or .ckpt format.
+
+ Args:
+ path (str): Output path without file extension.
+ data (dict): Dictionary containing 'splats' and optionally 'step'.
+ file_type (str): One of 'ply', 'pt', or 'ckpt'.
+ step_value (int, optional): If provided with file_type='pt', a .ckpt is also saved.
+ """
+ assert file_type in ['ply', 'pt', 'ckpt'], "Unsupported file type"
+
+ if file_type == 'ply':
+ path += ".ply"
+ splats = data["splats"]
+ print("Saving to .ply format")
+
+ export_data = {
+ "x": splats["means"][:, 0].detach().cpu().numpy(),
+ "y": splats["means"][:, 1].detach().cpu().numpy(),
+ "z": splats["means"][:, 2].detach().cpu().numpy(),
+ "opacity": splats["opacities"].detach().cpu().numpy(),
+ "rot_0": splats["quats"][:, 0].detach().cpu().numpy(),
+ "rot_1": splats["quats"][:, 1].detach().cpu().numpy(),
+ "rot_2": splats["quats"][:, 2].detach().cpu().numpy(),
+ "rot_3": splats["quats"][:, 3].detach().cpu().numpy(),
+ "scale_0": splats["scales"][:, 0].detach().cpu().numpy(),
+ "scale_1": splats["scales"][:, 1].detach().cpu().numpy(),
+ "scale_2": splats["scales"][:, 2].detach().cpu().numpy(),
+ }
+
+ for i in range(3):
+ export_data[f"f_dc_{i}"] = splats["sh0"][:, 0, i].detach().cpu().numpy()
+ for i in range(45):
+ c = i % 3
+ j = i // 3
+ export_data[f"f_rest_{i}"] = splats["shN"][:, j, c].detach().cpu().numpy()
+
+ df = pd.DataFrame(export_data)
+ cloud = PyntCloud(df)
+ cloud.to_file(path)
+
+ elif file_type == 'pt':
+ pt_path = path + ".pt"
+ print(f"Saving to .pt format at {pt_path}")
+ torch.save(data, pt_path)
+
+ if step_value is not None:
+ print("Also generating .ckpt file from .pt content")
+ _save_ckpt_from_splats(pt_path, path + ".ckpt", step_value)
+
+ elif file_type == 'ckpt':
+ print(f"Saving to .ckpt format at {path + '.ckpt'}")
+ _save_ckpt_from_splats(None, path + ".ckpt", step_value, direct_data=data)
+
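+# Example usage (a hypothetical sketch continuing the load_splats example above; paths are placeholders):
+#
+#   data = {"step": step, "splats": splats}
+#   save_splats("output/model_pruned", data, file_type="ply")                    # -> output/model_pruned.ply
+#   save_splats("output/model_pruned", data, file_type="ckpt", step_value=step)  # -> output/model_pruned.ckpt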
+
+def _save_ckpt_from_splats(
+    pt_path: Optional[str],
+    save_path: str,
+    step_value: Optional[int],
+ direct_data: Optional[Dict[str, Dict[str, torch.nn.Parameter]]] = None
+) -> None:
+ """Internal helper to convert .pt or dict to .ckpt."""
+ device = torch.device("cpu")
+
+ if direct_data is not None:
+ splats = direct_data["splats"]
+ else:
+ pt_data = torch.load(pt_path, map_location=device)
+ splats = pt_data["splats"]
+
+ inverse_rename_map = {
+ "means": "_model.gauss_params.means",
+ "opacities": "_model.gauss_params.opacities",
+ "quats": "_model.gauss_params.quats",
+ "scales": "_model.gauss_params.scales",
+ "sh0": "_model.gauss_params.features_dc",
+ "shN": "_model.gauss_params.features_rest"
+ }
+
+ pipeline = {}
+ for new_key, old_key in inverse_rename_map.items():
+ tensor = splats[new_key].data
+ if new_key == "sh0":
+ tensor = tensor.squeeze(1)
+ elif new_key == "opacities":
+ tensor = tensor.unsqueeze(-1)
+ pipeline[old_key] = tensor
+
+ ckpt = {
+ "step": step_value,
+ "pipeline": pipeline
+ }
+
+ torch.save(ckpt, save_path)
+ print(f"Saved .ckpt to: {save_path}")
diff --git a/pruning_utils/traj.py b/pruning_utils/traj.py
new file mode 100644
index 0000000..9e8a2d6
--- /dev/null
+++ b/pruning_utils/traj.py
@@ -0,0 +1,254 @@
+"""
+Code borrowed from
+
+https://github.com/google-research/multinerf/blob/5b4d4f64608ec8077222c52fdf814d40acc10bc1/internal/camera_utils.py
+"""
+
+import numpy as np
+import scipy
+
+
+def normalize(x: np.ndarray) -> np.ndarray:
+ """Normalization helper function."""
+ return x / np.linalg.norm(x)
+
+
+def viewmatrix(lookdir: np.ndarray, up: np.ndarray, position: np.ndarray) -> np.ndarray:
+ """Construct lookat view matrix."""
+ vec2 = normalize(lookdir)
+ vec0 = normalize(np.cross(up, vec2))
+ vec1 = normalize(np.cross(vec2, vec0))
+ m = np.stack([vec0, vec1, vec2, position], axis=1)
+ return m
+
+
+def focus_point_fn(poses: np.ndarray) -> np.ndarray:
+ """Calculate nearest point to all focal axes in poses."""
+ directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4]
+ m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1])
+ mt_m = np.transpose(m, [0, 2, 1]) @ m
+ focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0]
+ return focus_pt
+
+
+def average_pose(poses: np.ndarray) -> np.ndarray:
+ """New pose using average position, z-axis, and up vector of input poses."""
+ position = poses[:, :3, 3].mean(0)
+ z_axis = poses[:, :3, 2].mean(0)
+ up = poses[:, :3, 1].mean(0)
+ cam2world = viewmatrix(z_axis, up, position)
+ return cam2world
+
+
+def generate_spiral_path(
+ poses,
+ bounds,
+ n_frames=120,
+ n_rots=2,
+ zrate=0.5,
+ spiral_scale_f=1.0,
+ spiral_scale_r=1.0,
+ focus_distance=0.75,
+):
+ """Calculates a forward facing spiral path for rendering."""
+ # Find a reasonable 'focus depth' for this dataset as a weighted average
+ # of conservative near and far bounds in disparity space.
+ near_bound = bounds.min()
+ far_bound = bounds.max()
+ # All cameras will point towards the world space point (0, 0, -focal).
+ focal = 1 / (((1 - focus_distance) / near_bound + focus_distance / far_bound))
+ focal = focal * spiral_scale_f
+
+ # Get radii for spiral path using 90th percentile of camera positions.
+ positions = poses[:, :3, 3]
+ radii = np.percentile(np.abs(positions), 90, 0)
+ radii = radii * spiral_scale_r
+ radii = np.concatenate([radii, [1.0]])
+
+ # Generate poses for spiral path.
+ render_poses = []
+ cam2world = average_pose(poses)
+ up = poses[:, :3, 1].mean(0)
+ for theta in np.linspace(0.0, 2.0 * np.pi * n_rots, n_frames, endpoint=False):
+ t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0]
+ position = cam2world @ t
+ lookat = cam2world @ [0, 0, -focal, 1.0]
+ z_axis = position - lookat
+ render_poses.append(viewmatrix(z_axis, up, position))
+ render_poses = np.stack(render_poses, axis=0)
+ return render_poses
+
+
+def generate_ellipse_path_z(
+ poses: np.ndarray,
+ n_frames: int = 120,
+ # const_speed: bool = True,
+ variation: float = 0.0,
+ phase: float = 0.0,
+ height: float = 0.0,
+) -> np.ndarray:
+ """Generate an elliptical render path based on the given poses."""
+ # Calculate the focal point for the path (cameras point toward this).
+ center = focus_point_fn(poses)
+ # Path height sits at z=height (in middle of zero-mean capture pattern).
+ offset = np.array([center[0], center[1], height])
+
+ # Calculate scaling for ellipse axes based on input camera positions.
+ sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0)
+ # Use ellipse that is symmetric about the focal point in xy.
+ low = -sc + offset
+ high = sc + offset
+ # Optional height variation need not be symmetric
+ z_low = np.percentile((poses[:, :3, 3]), 10, axis=0)
+ z_high = np.percentile((poses[:, :3, 3]), 90, axis=0)
+
+ def get_positions(theta):
+ # Interpolate between bounds with trig functions to get ellipse in x-y.
+ # Optionally also interpolate in z to change camera height along path.
+ return np.stack(
+ [
+ low[0] + (high - low)[0] * (np.cos(theta) * 0.5 + 0.5),
+ low[1] + (high - low)[1] * (np.sin(theta) * 0.5 + 0.5),
+ variation
+ * (
+ z_low[2]
+ + (z_high - z_low)[2]
+ * (np.cos(theta + 2 * np.pi * phase) * 0.5 + 0.5)
+ )
+ + height,
+ ],
+ -1,
+ )
+
+ theta = np.linspace(0, 2.0 * np.pi, n_frames + 1, endpoint=True)
+ positions = get_positions(theta)
+
+ # if const_speed:
+ # # Resample theta angles so that the velocity is closer to constant.
+ # lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
+ # theta = stepfun.sample(None, theta, np.log(lengths), n_frames + 1)
+ # positions = get_positions(theta)
+
+ # Throw away duplicated last position.
+ positions = positions[:-1]
+
+ # Set path's up vector to axis closest to average of input pose up vectors.
+ avg_up = poses[:, :3, 1].mean(0)
+ avg_up = avg_up / np.linalg.norm(avg_up)
+ ind_up = np.argmax(np.abs(avg_up))
+ up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])
+
+ return np.stack([viewmatrix(center - p, up, p) for p in positions])
+
+
+def generate_ellipse_path_y(
+ poses: np.ndarray,
+ n_frames: int = 120,
+ # const_speed: bool = True,
+ variation: float = 0.0,
+ phase: float = 0.0,
+ height: float = 0.0,
+) -> np.ndarray:
+ """Generate an elliptical render path based on the given poses."""
+ # Calculate the focal point for the path (cameras point toward this).
+ center = focus_point_fn(poses)
+ # Path height sits at y=height (in middle of zero-mean capture pattern).
+ offset = np.array([center[0], height, center[2]])
+
+ # Calculate scaling for ellipse axes based on input camera positions.
+ sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0)
+ # Use ellipse that is symmetric about the focal point in xz.
+ low = -sc + offset
+ high = sc + offset
+ # Optional height variation need not be symmetric
+ y_low = np.percentile((poses[:, :3, 3]), 10, axis=0)
+ y_high = np.percentile((poses[:, :3, 3]), 90, axis=0)
+
+ def get_positions(theta):
+ # Interpolate between bounds with trig functions to get ellipse in x-z.
+ # Optionally also interpolate in y to change camera height along path.
+ return np.stack(
+ [
+ low[0] + (high - low)[0] * (np.cos(theta) * 0.5 + 0.5),
+ variation
+ * (
+ y_low[1]
+ + (y_high - y_low)[1]
+ * (np.cos(theta + 2 * np.pi * phase) * 0.5 + 0.5)
+ )
+ + height,
+ low[2] + (high - low)[2] * (np.sin(theta) * 0.5 + 0.5),
+ ],
+ -1,
+ )
+
+ theta = np.linspace(0, 2.0 * np.pi, n_frames + 1, endpoint=True)
+ positions = get_positions(theta)
+
+ # if const_speed:
+ # # Resample theta angles so that the velocity is closer to constant.
+ # lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1)
+ # theta = stepfun.sample(None, theta, np.log(lengths), n_frames + 1)
+ # positions = get_positions(theta)
+
+ # Throw away duplicated last position.
+ positions = positions[:-1]
+
+ # Set path's up vector to axis closest to average of input pose up vectors.
+ avg_up = poses[:, :3, 1].mean(0)
+ avg_up = avg_up / np.linalg.norm(avg_up)
+ ind_up = np.argmax(np.abs(avg_up))
+ up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])
+
+ return np.stack([viewmatrix(p - center, up, p) for p in positions])
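+
+# Note (comment added here, not in the upstream file): generate_ellipse_path_z
+# traces an ellipse in the x-y plane at a fixed z height (for z-up captures),
+# while generate_ellipse_path_y traces it in the x-z plane at a fixed y height
+# (for y-up captures); the two differ only in which axis plays the height role
+# and in the sign of the look direction passed to viewmatrix.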
+
+
+def generate_interpolated_path(
+ poses: np.ndarray,
+ n_interp: int,
+ spline_degree: int = 5,
+ smoothness: float = 0.03,
+ rot_weight: float = 0.1,
+):
+ """Creates a smooth spline path between input keyframe camera poses.
+
+ Spline is calculated with poses in format (position, lookat-point, up-point).
+
+ Args:
+ poses: (n, 3, 4) array of input pose keyframes.
+ n_interp: returned path will have n_interp * (n - 1) total poses.
+ spline_degree: polynomial degree of B-spline.
+ smoothness: parameter for spline smoothing, 0 forces exact interpolation.
+ rot_weight: relative weighting of rotation/translation in spline solve.
+
+ Returns:
+ Array of new camera poses with shape (n_interp * (n - 1), 3, 4).
+ """
+
+ def poses_to_points(poses, dist):
+ """Converts from pose matrices to (position, lookat, up) format."""
+ pos = poses[:, :3, -1]
+ lookat = poses[:, :3, -1] - dist * poses[:, :3, 2]
+ up = poses[:, :3, -1] + dist * poses[:, :3, 1]
+ return np.stack([pos, lookat, up], 1)
+
+ def points_to_poses(points):
+ """Converts from (position, lookat, up) format to pose matrices."""
+ return np.array([viewmatrix(p - l, u - p, p) for p, l, u in points])
+
+ def interp(points, n, k, s):
+ """Runs multidimensional B-spline interpolation on the input points."""
+ sh = points.shape
+ pts = np.reshape(points, (sh[0], -1))
+ k = min(k, sh[0] - 1)
+ tck, _ = scipy.interpolate.splprep(pts.T, k=k, s=s)
+ u = np.linspace(0, 1, n, endpoint=False)
+ new_points = np.array(scipy.interpolate.splev(u, tck))
+ new_points = np.reshape(new_points.T, (n, sh[1], sh[2]))
+ return new_points
+
+ points = poses_to_points(poses, dist=rot_weight)
+ new_points = interp(
+ points, n_interp * (points.shape[0] - 1), k=spline_degree, s=smoothness
+ )
+ return points_to_poses(new_points)
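+
+# Usage sketch (illustrative, not part of the upstream file; `keyframes` is a
+# placeholder for a (n, 3, 4) subset of training poses):
+#     keyframes = train_poses[::10]
+#     path = generate_interpolated_path(keyframes, n_interp=30)
+#     # path has shape (30 * (len(keyframes) - 1), 3, 4)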
diff --git a/pruning_utils/utils.py b/pruning_utils/utils.py
new file mode 100644
index 0000000..152b9ee
--- /dev/null
+++ b/pruning_utils/utils.py
@@ -0,0 +1,235 @@
+"""
+===========================================================================
+Unmodified code from gsplat examples
+Original source:
+https://github.com/nerfstudio-project/gsplat/blob/main/examples/utils.py
+License: Apache License 2.0
+----------------------------------------------------------------------------
+This file was copied verbatim from the gsplat repository. No modifications were made.
+===========================================================================
+"""
+
+import random
+
+import numpy as np
+import torch
+from sklearn.neighbors import NearestNeighbors
+from torch import Tensor
+import torch.nn.functional as F
+import matplotlib.pyplot as plt
+from matplotlib import colormaps
+
+
+class CameraOptModule(torch.nn.Module):
+ """Camera pose optimization module."""
+
+ def __init__(self, n: int):
+ super().__init__()
+ # Delta positions (3D) + Delta rotations (6D)
+ self.embeds = torch.nn.Embedding(n, 9)
+ # Identity rotation in 6D representation
+ self.register_buffer("identity", torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]))
+
+ def zero_init(self):
+ torch.nn.init.zeros_(self.embeds.weight)
+
+ def random_init(self, std: float):
+ torch.nn.init.normal_(self.embeds.weight, std=std)
+
+ def forward(self, camtoworlds: Tensor, embed_ids: Tensor) -> Tensor:
+ """Adjust camera pose based on deltas.
+
+ Args:
+ camtoworlds: (..., 4, 4)
+ embed_ids: (...,)
+
+ Returns:
+ updated camtoworlds: (..., 4, 4)
+ """
+ assert camtoworlds.shape[:-2] == embed_ids.shape
+ batch_shape = camtoworlds.shape[:-2]
+ pose_deltas = self.embeds(embed_ids) # (..., 9)
+ dx, drot = pose_deltas[..., :3], pose_deltas[..., 3:]
+ rot = rotation_6d_to_matrix(
+ drot + self.identity.expand(*batch_shape, -1)
+ ) # (..., 3, 3)
+ transform = torch.eye(4, device=pose_deltas.device).repeat((*batch_shape, 1, 1))
+ transform[..., :3, :3] = rot
+ transform[..., :3, 3] = dx
+ return torch.matmul(camtoworlds, transform)
+
+
+class AppearanceOptModule(torch.nn.Module):
+ """Appearance optimization module."""
+
+ def __init__(
+ self,
+ n: int,
+ feature_dim: int,
+ embed_dim: int = 16,
+ sh_degree: int = 3,
+ mlp_width: int = 64,
+ mlp_depth: int = 2,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.sh_degree = sh_degree
+ self.embeds = torch.nn.Embedding(n, embed_dim)
+ layers = []
+ layers.append(
+ torch.nn.Linear(embed_dim + feature_dim + (sh_degree + 1) ** 2, mlp_width)
+ )
+ layers.append(torch.nn.ReLU(inplace=True))
+ for _ in range(mlp_depth - 1):
+ layers.append(torch.nn.Linear(mlp_width, mlp_width))
+ layers.append(torch.nn.ReLU(inplace=True))
+ layers.append(torch.nn.Linear(mlp_width, 3))
+ self.color_head = torch.nn.Sequential(*layers)
+
+ def forward(
+ self, features: Tensor, embed_ids: Tensor, dirs: Tensor, sh_degree: int
+ ) -> Tensor:
+ """Adjust appearance based on embeddings.
+
+ Args:
+ features: (N, feature_dim)
+ embed_ids: (C,)
+ dirs: (C, N, 3)
+
+ Returns:
+ colors: (C, N, 3)
+ """
+ from gsplat.cuda._torch_impl import _eval_sh_bases_fast
+
+ C, N = dirs.shape[:2]
+ # Camera embeddings
+ if embed_ids is None:
+ embeds = torch.zeros(C, self.embed_dim, device=features.device)
+ else:
+ embeds = self.embeds(embed_ids) # [C, D2]
+ embeds = embeds[:, None, :].expand(-1, N, -1) # [C, N, D2]
+ # GS features
+ features = features[None, :, :].expand(C, -1, -1) # [C, N, D1]
+ # View directions
+ dirs = F.normalize(dirs, dim=-1) # [C, N, 3]
+ num_bases_to_use = (sh_degree + 1) ** 2
+ num_bases = (self.sh_degree + 1) ** 2
+ sh_bases = torch.zeros(C, N, num_bases, device=features.device) # [C, N, K]
+ sh_bases[:, :, :num_bases_to_use] = _eval_sh_bases_fast(num_bases_to_use, dirs)
+ # Get colors
+ if self.embed_dim > 0:
+ h = torch.cat([embeds, features, sh_bases], dim=-1) # [C, N, D1 + D2 + K]
+ else:
+ h = torch.cat([features, sh_bases], dim=-1)
+ colors = self.color_head(h)
+ return colors
+
+
+def rotation_6d_to_matrix(d6: Tensor) -> Tensor:
+ """
+ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
+ using Gram--Schmidt orthogonalization per Section B of [1]. Adapted from pytorch3d.
+ Args:
+ d6: 6D rotation representation, of size (*, 6)
+
+ Returns:
+ batch of rotation matrices of size (*, 3, 3)
+
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+ On the Continuity of Rotation Representations in Neural Networks.
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+ Retrieved from http://arxiv.org/abs/1812.07035
+ """
+
+ a1, a2 = d6[..., :3], d6[..., 3:]
+ b1 = F.normalize(a1, dim=-1)
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
+ b2 = F.normalize(b2, dim=-1)
+ b3 = torch.cross(b1, b2, dim=-1)
+ return torch.stack((b1, b2, b3), dim=-2)
+
+
+def knn(x: Tensor, K: int = 4) -> Tensor:
+ x_np = x.cpu().numpy()
+ model = NearestNeighbors(n_neighbors=K, metric="euclidean").fit(x_np)
+ distances, _ = model.kneighbors(x_np)
+ return torch.from_numpy(distances).to(x)
+
+
+def rgb_to_sh(rgb: Tensor) -> Tensor:
+ C0 = 0.28209479177387814
+ return (rgb - 0.5) / C0
+
+
+def set_random_seed(seed: int):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
+
+# ref: https://github.com/hbb1/2d-gaussian-splatting/blob/main/utils/general_utils.py#L163
+def colormap(img, cmap="jet"):
+ W, H = img.shape[:2]
+ dpi = 300
+ fig, ax = plt.subplots(1, figsize=(H / dpi, W / dpi), dpi=dpi)
+ im = ax.imshow(img, cmap=cmap)
+ ax.set_axis_off()
+ fig.colorbar(im, ax=ax)
+ fig.tight_layout()
+ fig.canvas.draw()
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+ img = torch.from_numpy(data).float().permute(2, 0, 1)
+ plt.close()
+ return img
+
+
+def apply_float_colormap(img: torch.Tensor, colormap: str = "turbo") -> torch.Tensor:
+ """Convert single channel to a color img.
+
+ Args:
+ img (torch.Tensor): (..., 1) float32 single channel image.
+ colormap (str): Colormap for img.
+
+ Returns:
+ (..., 3) colored img with colors in [0, 1].
+ """
+ img = torch.nan_to_num(img, 0)
+ if colormap == "gray":
+ return img.repeat(1, 1, 3)
+ img_long = (img * 255).long()
+ img_long_min = torch.min(img_long)
+ img_long_max = torch.max(img_long)
+ assert img_long_min >= 0, f"the min value is {img_long_min}"
+ assert img_long_max <= 255, f"the max value is {img_long_max}"
+ return torch.tensor(
+ colormaps[colormap].colors, # type: ignore
+ device=img.device,
+ )[img_long[..., 0]]
+
+
+def apply_depth_colormap(
+ depth: torch.Tensor,
+ acc: torch.Tensor = None,
+ near_plane: float = None,
+ far_plane: float = None,
+) -> torch.Tensor:
+ """Converts a depth image to color for easier analysis.
+
+ Args:
+ depth (torch.Tensor): (..., 1) float32 depth.
+ acc (torch.Tensor | None): (..., 1) optional accumulation mask.
+ near_plane: Closest depth to consider. If None, use min image value.
+ far_plane: Furthest depth to consider. If None, use max image value.
+
+ Returns:
+ (..., 3) colored depth image with colors in [0, 1].
+ """
+ near_plane = near_plane or float(torch.min(depth))
+ far_plane = far_plane or float(torch.max(depth))
+ depth = (depth - near_plane) / (far_plane - near_plane + 1e-10)
+ depth = torch.clip(depth, 0.0, 1.0)
+ img = apply_float_colormap(depth, colormap="turbo")
+ if acc is not None:
+ img = img * acc + (1.0 - acc)
+ return img