133 changes: 133 additions & 0 deletions nerfstudio/data/dataparsers/mock_dataparser.py
@@ -0,0 +1,133 @@
# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Mock dataparser for sharing model checkpoints without requiring original training data.
Creates dummy camera poses and image paths for inference/viewing purposes only.
"""

from dataclasses import dataclass, field
from pathlib import Path
from typing import Type

import torch
import numpy as np

from nerfstudio.cameras.cameras import Cameras, CameraType
from nerfstudio.data.dataparsers.base_dataparser import DataParser, DataParserConfig, DataparserOutputs
from nerfstudio.data.scene_box import SceneBox

Check failure on line 29 (GitHub Actions / build): Ruff I001 at nerfstudio/data/dataparsers/mock_dataparser.py:20:1: Import block is un-sorted or un-formatted


@dataclass
class MockDataParserConfig(DataParserConfig):
"""Mock dataset config for inference without original data"""

_target: Type = field(default_factory=lambda: MockDataParser)
"""target class to instantiate"""
num_cameras: int = 100
"""number of dummy cameras to generate"""
image_height: int = 800
"""height of dummy images"""
image_width: int = 800
"""width of dummy images"""
focal_length: float = 800.0
"""focal length for dummy cameras"""
scene_scale: float = 1.0
"""scene scale"""


@dataclass
class MockDataParser(DataParser):
"""Mock DataParser that generates dummy data for inference/viewing without original training data"""

config: MockDataParserConfig

def _generate_dataparser_outputs(self, split="train"):
"""Generate mock dataparser outputs with dummy camera poses and image paths"""

# Generate dummy image filenames - these don't need to exist since we're only doing inference
image_filenames = [Path(f"mock_image_{i:04d}.jpg") for i in range(self.config.num_cameras)]

# Generate camera poses in a reasonable sphere around the scene
poses = self._generate_spherical_poses(self.config.num_cameras)

# Create camera intrinsics
fx = fy = self.config.focal_length
cx = self.config.image_width / 2.0
cy = self.config.image_height / 2.0

cameras = Cameras(
fx=torch.full((self.config.num_cameras,), fx),
fy=torch.full((self.config.num_cameras,), fy),
cx=torch.full((self.config.num_cameras,), cx),
cy=torch.full((self.config.num_cameras,), cy),
height=torch.full((self.config.num_cameras,), self.config.image_height),
width=torch.full((self.config.num_cameras,), self.config.image_width),
camera_to_worlds=poses,
camera_type=torch.full((self.config.num_cameras,), CameraType.PERSPECTIVE.value),
)

# Default scene box
scene_box = SceneBox(aabb=torch.tensor([[-self.config.scene_scale, -self.config.scene_scale, -self.config.scene_scale],
[self.config.scene_scale, self.config.scene_scale, self.config.scene_scale]]))

dataparser_outputs = DataparserOutputs(
image_filenames=image_filenames,
cameras=cameras,
scene_box=scene_box,
dataparser_transform=torch.eye(4)[:3, :],
dataparser_scale=1.0,
)

return dataparser_outputs

def _generate_spherical_poses(self, num_poses: int) -> torch.Tensor:
"""Generate camera poses distributed on a sphere looking at the origin"""
poses = []

# Generate poses on a sphere
for i in range(num_poses):
# Spherical coordinates
theta = 2 * np.pi * i / num_poses # azimuth
phi = np.pi / 4 # elevation (45 degrees)
radius = 4.0

# Convert to Cartesian
x = radius * np.sin(phi) * np.cos(theta)
y = radius * np.sin(phi) * np.sin(theta)
z = radius * np.cos(phi)

# Look at origin
camera_position = np.array([x, y, z])
look_at = np.array([0.0, 0.0, 0.0])
up = np.array([0.0, 0.0, 1.0])

# Create camera-to-world matrix
forward = look_at - camera_position
forward = forward / np.linalg.norm(forward)

right = np.cross(forward, up)
right = right / np.linalg.norm(right)

up_corrected = np.cross(right, forward)

pose = np.eye(4)
pose[:3, 0] = right
pose[:3, 1] = up_corrected
            pose[:3, 2] = -forward  # camera looks along -Z (OpenGL/nerfstudio convention)
pose[:3, 3] = camera_position

poses.append(pose[:3, :4])

return torch.from_numpy(np.stack(poses)).float()
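
For a quick sanity check of the new parser in isolation, a minimal sketch along these lines should work (assuming a nerfstudio environment where the module above is importable; setup() and get_dataparser_outputs() are the standard nerfstudio config/dataparser entry points):

from nerfstudio.data.dataparsers.mock_dataparser import MockDataParserConfig

# Instantiate the parser the way nerfstudio does, via the config's _target.
parser = MockDataParserConfig(num_cameras=10).setup()
outputs = parser.get_dataparser_outputs(split="train")

# Ten dummy filenames and ten (3, 4) camera-to-world matrices.
assert len(outputs.image_filenames) == 10
assert outputs.cameras.camera_to_worlds.shape == (10, 3, 4)

# Every camera sits on a radius-4.0 circle at 45 degrees elevation, looking at the origin.
positions = outputs.cameras.camera_to_worlds[:, :3, 3]
print(positions.norm(dim=-1))  # all approximately 4.0
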
14 changes: 12 additions & 2 deletions nerfstudio/scripts/viewer/run_viewer.py
@@ -87,8 +87,18 @@ def _start_viewer(config: TrainerConfig, pipeline: Pipeline, step: int):
pipeline: Pipeline instance of which to load weights
step: Step at which the pipeline was saved
"""
base_dir = config.get_base_dir()
viewer_log_path = base_dir / config.viewer.relative_log_filename
# Check if we're using a shared checkpoint (load_dir is outside the standard experiment structure)
try:
base_dir = config.get_base_dir()
# If get_base_dir() would create a path that doesn't exist, we're likely using a shared checkpoint
if not base_dir.parent.exists():
# Use the checkpoint directory as the base for log files
viewer_log_path = config.load_dir / config.viewer.relative_log_filename
else:
viewer_log_path = base_dir / config.viewer.relative_log_filename
except (FileNotFoundError, OSError):
# Fallback to using the checkpoint directory for shared checkpoints
viewer_log_path = config.load_dir / config.viewer.relative_log_filename
banner_messages = None
viewer_state = None
viewer_callback_lock = Lock()
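
The branch above is easier to follow pulled out on its own. A standalone sketch of the same decision (a hypothetical helper, not part of the PR; the three arguments stand in for config.get_base_dir(), config.load_dir, and config.viewer.relative_log_filename):

from pathlib import Path

def resolve_viewer_log_path(base_dir: Path, load_dir: Path, relative_log_filename: str) -> Path:
    # If the experiment tree was never created on this machine, assume a
    # shared checkpoint and keep the viewer log next to the checkpoint files.
    try:
        if not base_dir.parent.exists():
            return load_dir / relative_log_filename
        return base_dir / relative_log_filename
    except OSError:  # FileNotFoundError is a subclass of OSError
        return load_dir / relative_log_filename

# A checkpoint unpacked to ./shared_ckpt with no outputs/ tree present:
print(resolve_viewer_log_path(
    Path("outputs/exp/method/ts"), Path("shared_ckpt/nerfstudio_models"), "viewer_log.txt"
))  # -> shared_ckpt/nerfstudio_models/viewer_log.txt
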
47 changes: 45 additions & 2 deletions nerfstudio/utils/eval_utils.py
@@ -16,22 +16,56 @@
Evaluation utils
"""

from __future__ import annotations

import os
import sys
from pathlib import Path
from typing import Callable, Literal, Optional, Tuple

import torch
import yaml

from nerfstudio.configs.method_configs import all_methods
from nerfstudio.engine.trainer import TrainerConfig
from nerfstudio.data.dataparsers.mock_dataparser import MockDataParserConfig
from nerfstudio.pipelines.base_pipeline import Pipeline
from nerfstudio.utils.rich_utils import CONSOLE

Check failure on line 33 (GitHub Actions / build): Ruff I001 at nerfstudio/utils/eval_utils.py:19:1: Import block is un-sorted or un-formatted


def patch_config_for_mock_data(config: TrainerConfig) -> TrainerConfig:
"""Patch config to use mock data if original data path doesn't exist.

This allows viewing checkpoints without requiring the original training data.

Args:
config: The original trainer config

Returns:
Modified config that uses mock data if original data is not available
"""
# Determine the actual data path that will be used by the dataparser
dataparser_data = config.pipeline.datamanager.dataparser.data
datamanager_data = getattr(config.pipeline.datamanager, 'data', None)

# The dataparser will use its own data field if it's meaningful, otherwise it inherits from datamanager
# TODO: this is a hack, but I really need to change this
if dataparser_data and str(dataparser_data) != "." and dataparser_data.name != "":
actual_data_path = dataparser_data
else:
actual_data_path = datamanager_data

# Check if the actual data path exists
if not actual_data_path or not actual_data_path.exists():
CONSOLE.print(f"[yellow]Original data path {actual_data_path} not found. Using mock data for inference.[/yellow]")

# Replace the dataparser with MockDataParserConfig
config.pipeline.datamanager.dataparser = MockDataParserConfig()
CONSOLE.print("[green]Successfully switched to mock dataparser for inference.[/green]")

return config


def eval_load_checkpoint(config: TrainerConfig, pipeline: Pipeline) -> Tuple[Path, int]:
## TODO: ideally eventually want to get this to be the same as whatever is used to load train checkpoint too
"""Helper function to load checkpointed pipeline
@@ -97,9 +131,18 @@
if update_config_callback is not None:
config = update_config_callback(config)

# Patch config to use mock data if original data is not available
config = patch_config_for_mock_data(config)

# load checkpoints from wherever they were saved
# TODO: expose the ability to choose an arbitrary checkpoint
config.load_dir = config.get_checkpoint_dir()
# For shared checkpoints, the checkpoints should be relative to the config file location
config_dir = config_path.parent
expected_checkpoint_dir = config_dir / "nerfstudio_models"
if expected_checkpoint_dir.exists():
config.load_dir = expected_checkpoint_dir
else:
# Fallback to the original behavior
config.load_dir = config.get_checkpoint_dir()

# setup pipeline (which includes the DataManager)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
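
Taken together, this means a shared checkpoint only needs the config and the weights side by side; everything else is mocked. A sketch of the lookup order with illustrative paths (only the nerfstudio_models directory name is assumed by the code above):

from pathlib import Path

config_path = Path("shared_ckpt/config.yml")  # hypothetical unpacked checkpoint

# First preference: checkpoints stored next to the config file.
expected_checkpoint_dir = config_path.parent / "nerfstudio_models"
if expected_checkpoint_dir.exists():
    load_dir = expected_checkpoint_dir
else:
    # Fallback: the directory recorded at training time, i.e. what
    # config.get_checkpoint_dir() would return on the training machine.
    load_dir = Path("outputs/exp/method/ts/nerfstudio_models")

With that layout, viewing a shared checkpoint should reduce to ns-viewer --load-config shared_ckpt/config.yml, since eval_setup now swaps in the mock dataparser whenever the original data path is missing.
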
101 changes: 66 additions & 35 deletions nerfstudio/viewer/viewer.py
@@ -444,45 +444,76 @@ def init_scene(
# draw the training cameras and images
self.camera_handles: Dict[int, viser.CameraFrustumHandle] = {}
self.original_c2w: Dict[int, np.ndarray] = {}
image_indices = self._pick_drawn_image_idxs(len(train_dataset))
for idx in image_indices:
image = train_dataset[idx]["image"]
camera = train_dataset.cameras[idx]
image_uint8 = (image * 255).detach().type(torch.uint8)
image_uint8 = image_uint8.permute(2, 0, 1)

# torchvision can be slow to import, so we do it lazily.
import torchvision

image_uint8 = torchvision.transforms.functional.resize(image_uint8, 100, antialias=None) # type: ignore
image_uint8 = image_uint8.permute(1, 2, 0)
image_uint8 = image_uint8.cpu().numpy()
c2w = camera.camera_to_worlds.cpu().numpy()
R = vtf.SO3.from_matrix(c2w[:3, :3])
R = R @ vtf.SO3.from_x_radians(np.pi)
camera_handle = self.viser_server.scene.add_camera_frustum(
name=f"/cameras/camera_{idx:05d}",
fov=float(2 * np.arctan((camera.cx / camera.fx[0]).cpu())),
scale=self.config.camera_frustum_scale,
aspect=float((camera.cx[0] / camera.cy[0]).cpu()),
image=image_uint8,
wxyz=R.wxyz,
position=c2w[:3, 3] * VISER_NERFSTUDIO_SCALE_RATIO,
)

# Check if we're using mock data - if so, skip image loading to avoid file errors
is_mock_data = (
len(train_dataset) > 0 and
hasattr(train_dataset, '_dataparser_outputs') and
len(train_dataset._dataparser_outputs.image_filenames) > 0 and
str(train_dataset._dataparser_outputs.image_filenames[0]).startswith("mock_image")
)

if is_mock_data:
# For mock data, just draw camera frustums without images
image_indices = self._pick_drawn_image_idxs(len(train_dataset))
for idx in image_indices:
camera = train_dataset.cameras[idx]
c2w = camera.camera_to_worlds.cpu().numpy()
R = vtf.SO3.from_matrix(c2w[:3, :3])
R = R @ vtf.SO3.from_x_radians(np.pi)
camera_handle = self.viser_server.scene.add_camera_frustum(
name=f"/cameras/camera_{idx:05d}",
fov=2 * np.arctan(camera.height / (2 * camera.fy)).item(),
aspect=camera.width / camera.height,
scale=0.1,
color=(255, 255, 255),
wxyz=R.wxyz,
position=c2w[:3, 3],
visible=False,
)
self.camera_handles[idx] = camera_handle
self.original_c2w[idx] = c2w
else:
# Normal image loading for real data
image_indices = self._pick_drawn_image_idxs(len(train_dataset))
for idx in image_indices:
image = train_dataset[idx]["image"]
camera = train_dataset.cameras[idx]
image_uint8 = (image * 255).detach().type(torch.uint8)
image_uint8 = image_uint8.permute(2, 0, 1)

# torchvision can be slow to import, so we do it lazily.
import torchvision

image_uint8 = torchvision.transforms.functional.resize(image_uint8, 100, antialias=None) # type: ignore
image_uint8 = image_uint8.permute(1, 2, 0)
image_uint8 = image_uint8.cpu().numpy()
c2w = camera.camera_to_worlds.cpu().numpy()
R = vtf.SO3.from_matrix(c2w[:3, :3])
R = R @ vtf.SO3.from_x_radians(np.pi)
camera_handle = self.viser_server.scene.add_camera_frustum(
name=f"/cameras/camera_{idx:05d}",
fov=float(2 * np.arctan((camera.cx / camera.fx[0]).cpu())),
scale=self.config.camera_frustum_scale,
aspect=float((camera.cx[0] / camera.cy[0]).cpu()),
image=image_uint8,
wxyz=R.wxyz,
position=c2w[:3, 3] * VISER_NERFSTUDIO_SCALE_RATIO,
)

def create_on_click_callback(capture_idx):
def on_click_callback(event: viser.SceneNodePointerEvent[viser.CameraFrustumHandle]) -> None:
with event.client.atomic():
event.client.camera.position = event.target.position
event.client.camera.wxyz = event.target.wxyz
self.current_camera_idx = capture_idx
def create_on_click_callback(capture_idx):
def on_click_callback(event: viser.SceneNodePointerEvent[viser.CameraFrustumHandle]) -> None:
with event.client.atomic():
event.client.camera.position = event.target.position
event.client.camera.wxyz = event.target.wxyz
self.current_camera_idx = capture_idx

return on_click_callback
return on_click_callback

camera_handle.on_click(create_on_click_callback(idx))
camera_handle.on_click(create_on_click_callback(idx))

self.camera_handles[idx] = camera_handle
self.original_c2w[idx] = c2w
self.camera_handles[idx] = camera_handle
self.original_c2w[idx] = c2w

self.train_state = train_state
self.train_util = 0.9
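
The mock-data detection above keys off the filename prefix that MockDataParser generates. As a standalone predicate it reads roughly like this (a sketch under the same assumption the hasattr guard encodes, namely that the dataset exposes _dataparser_outputs):

def uses_mock_data(dataset) -> bool:
    # True when the dataset's image filenames come from MockDataParser,
    # which names them mock_image_0000.jpg, mock_image_0001.jpg, ...
    outputs = getattr(dataset, "_dataparser_outputs", None)
    return (
        len(dataset) > 0
        and outputs is not None
        and len(outputs.image_filenames) > 0
        and str(outputs.image_filenames[0]).startswith("mock_image")
    )

Keying on a filename prefix is deliberately lightweight; a dedicated flag on DataparserOutputs would be a sturdier long-term signal.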