Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions configs/base_gs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ resume: ""
out_dir: "./runs"
n_iterations: 30000
with_gui: false
with_viser_gui: false
gui_update_from_device: true
val_frequency: 5000
num_workers: 24
Expand Down
23 changes: 21 additions & 2 deletions threedgrut/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from collections import defaultdict
from pathlib import Path
from typing import Any, Optional, Union

import time
import numpy as np

import torch
Expand All @@ -41,6 +41,7 @@
from threedgrut.render import Renderer
from threedgrut.strategy.base import BaseStrategy
from threedgrut.utils.gui import GUI
from threedgrut.utils.viser_gui_util import ViserGUI
from threedgrut.utils.logger import logger
from threedgrut.utils.timer import CudaTimer
from threedgrut.utils.misc import jet_map, create_summary_writer, check_step_condition
Expand Down Expand Up @@ -287,6 +288,8 @@ def init_gui(
gui = None
if conf.with_gui:
gui = GUI(conf, model, train_dataset, val_dataset, scene_bbox)
elif conf.with_viser_gui:
gui = ViserGUI(conf, model, train_dataset, val_dataset, scene_bbox)
self.gui = gui

def init_metrics(self):
Expand Down Expand Up @@ -689,6 +692,18 @@ def render_gui(self, scene_updated):
if ps.window_requests_close():
logger.warning("Terminating training from GUI window is not supported. Please terminate it from the terminal.")

def render_gui_viser(self, scene_updated):
gui = self.gui
if gui is not None:
if gui.live_update:
# update render view
if scene_updated or self.model.positions.requires_grad:
gui.update_point_cloud()
for client in gui.server.get_clients().values():
gui.update_render_view(client, force=True)
while not gui.viz_do_train:
time.sleep(0.0001)

@torch.cuda.nvtx.range(f"run_train_pass")
def run_train_pass(self, conf: DictConfig):
"""Runs a single train epoch over the dataset."""
Expand Down Expand Up @@ -798,7 +813,11 @@ def run_train_pass(self, conf: DictConfig):
self.save_checkpoint()

with torch.cuda.nvtx.range(f"train_{global_step-1}_update_gui"):
self.render_gui(scene_updated) # Updating the GUI
# self.render_gui(scene_updated) # Updating the GUI
if self.conf.with_viser_gui:
self.render_gui_viser(scene_updated)
elif self.conf.with_gui:
self.render_gui(scene_updated)

self.log_training_pass(metrics)

Expand Down
Loading