Skip to content

Commit 13c6602

Browse files
committed
Add depth visualization
1 parent 9f45a4f commit 13c6602

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

gaussian_renderer/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
8181
colors_precomp = override_color
8282

8383
# Rasterize visible Gaussians to image, obtain their radii (on screen).
84-
rendered_image, radii = rasterizer(
84+
rendered_image, radii, depth = rasterizer(
8585
means3D = means3D,
8686
means2D = means2D,
8787
shs = shs,
@@ -96,4 +96,5 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
9696
return {"render": rendered_image,
9797
"viewspace_points": screenspace_points,
9898
"visibility_filter" : radii > 0,
99-
"radii": radii}
99+
"radii": radii,
100+
"depth": depth}

render.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,22 @@
2424
def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
2525
render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
2626
gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")
27+
depth_path = os.path.join(model_path, name, "ours_{}".format(iteration), "depth")
2728

2829
makedirs(render_path, exist_ok=True)
2930
makedirs(gts_path, exist_ok=True)
31+
makedirs(depth_path, exist_ok=True)
3032

3133
for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
32-
rendering = render(view, gaussians, pipeline, background)["render"]
34+
results = render(view, gaussians, pipeline, background)
35+
rendering = results["render"]
36+
depth = results["depth"]
37+
depth = depth / (depth.max() + 1e-5)
38+
3339
gt = view.original_image[0:3, :, :]
3440
torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
3541
torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
42+
torchvision.utils.save_image(depth, os.path.join(depth_path, '{0:05d}'.format(idx) + ".png"))
3643

3744
def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool):
3845
with torch.no_grad():

0 commit comments

Comments
 (0)