Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 43 additions & 15 deletions gaussian_renderer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from scene.gaussian_model import GaussianModel
from utils.sh_utils import eval_sh

def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, separate_sh = False, override_color = None, use_trained_exp=False):
def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, separate_sh = False, override_color = None, use_trained_exp=False, orthographic=True):
"""
Render the scene.

Expand All @@ -33,21 +33,48 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

# raster_settings = GaussianRasterizationSettings(
# image_height=int(viewpoint_camera.image_height),
# image_width=int(viewpoint_camera.image_width),
# tanfovx=tanfovx,
# tanfovy=tanfovy,
# bg=bg_color,
# scale_modifier=scaling_modifier,
# viewmatrix=viewpoint_camera.world_view_transform,
# projmatrix=viewpoint_camera.full_proj_transform,
# sh_degree=pc.active_sh_degree,
# campos=viewpoint_camera.camera_center,
# prefiltered=False,
# debug=pipe.debug,
# antialiasing=pipe.antialiasing
# )

# Set up rasterization configuration
if not orthographic:
tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
full_proj_transform = viewpoint_camera.get_full_proj_transform(orthographic)
else:
tanfovx, tanfovy, full_proj_transform = viewpoint_camera.get_full_proj_transform(orthographic)


raster_settings = GaussianRasterizationSettings(
image_height=int(viewpoint_camera.image_height),
image_width=int(viewpoint_camera.image_width),
tanfovx=tanfovx,
tanfovy=tanfovy,
bg=bg_color,
scale_modifier=scaling_modifier,
viewmatrix=viewpoint_camera.world_view_transform,
projmatrix=viewpoint_camera.full_proj_transform,
sh_degree=pc.active_sh_degree,
campos=viewpoint_camera.camera_center,
prefiltered=False,
debug=pipe.debug,
antialiasing=pipe.antialiasing
)
image_height=int(viewpoint_camera.image_height),
image_width=int(viewpoint_camera.image_width),
tanfovx=tanfovx,
tanfovy=tanfovy,
bg=bg_color,
scale_modifier=scaling_modifier,
viewmatrix=viewpoint_camera.world_view_transform,
projmatrix=full_proj_transform,
sh_degree=pc.active_sh_degree,
campos=viewpoint_camera.camera_center,
prefiltered=False,
debug=pipe.debug,
antialiasing=pipe.antialiasing,
orthographic=orthographic
)


rasterizer = GaussianRasterizer(raster_settings=raster_settings)

Expand All @@ -71,6 +98,7 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
# from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
shs = None
colors_precomp = None
dc = None
if override_color is None:
if pipe.convert_SHs_python:
shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
Expand Down
16 changes: 16 additions & 0 deletions scene/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ def __init__(self, resolution, colmap_id, R, T, FoVx, FoVy, depth_params, image,
self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
self.camera_center = self.world_view_transform.inverse()[3, :3]

def get_full_proj_transform(self, orthographic=False):
    """Return the world-to-clip transform for this camera.

    For perspective rendering (``orthographic=False``) this returns the
    precomputed ``self.full_proj_transform`` unchanged.  For orthographic
    rendering it rebuilds the projection via ``getProjectionMatrix`` and
    returns a ``(tanfovx, tanfovy, full_proj_transform)`` tuple instead.

    NOTE(review): the return type differs between the two modes (tensor vs.
    3-tuple); callers must branch on the flag accordingly.
    """
    if orthographic:
        tanfovx, tanfovy, proj = getProjectionMatrix(
            znear=self.znear, zfar=self.zfar,
            fovX=self.FoVx, fovY=self.FoVy, orthographic=True)
        world_view = self.world_view_transform.unsqueeze(0)
        clip = proj.transpose(0, 1).cuda().unsqueeze(0)
        return tanfovx, tanfovy, world_view.bmm(clip).squeeze(0)
    return self.full_proj_transform

class MiniCam:
def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
Expand All @@ -101,3 +109,11 @@ def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform,
view_inv = torch.inverse(self.world_view_transform)
self.camera_center = view_inv[3][:3]

def get_full_proj_transform(self, orthographic=False):
    """Return the world-to-clip transform for this mini camera.

    Mirrors ``Camera.get_full_proj_transform``: perspective mode returns the
    cached ``self.full_proj_transform``; orthographic mode recomputes the
    projection and returns ``(tanfovx, tanfovy, full_proj_transform)``.

    NOTE(review): return type differs between the two modes — callers must
    unpack the tuple only when ``orthographic`` is True.
    """
    if not orthographic:
        return self.full_proj_transform
    tanfovx, tanfovy, projection = getProjectionMatrix(
        znear=self.znear, zfar=self.zfar,
        fovX=self.FoVx, fovY=self.FoVy, orthographic=True)
    combined = self.world_view_transform.unsqueeze(0).bmm(
        projection.transpose(0, 1).cuda().unsqueeze(0))
    return tanfovx, tanfovy, combined.squeeze(0)

62 changes: 41 additions & 21 deletions utils/graphics_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,27 +48,47 @@ def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
Rt = np.linalg.inv(C2W)
return np.float32(Rt)

def getProjectionMatrix(znear, zfar, fovX, fovY):
tanHalfFovY = math.tan((fovY / 2))
tanHalfFovX = math.tan((fovX / 2))

top = tanHalfFovY * znear
bottom = -top
right = tanHalfFovX * znear
left = -right

P = torch.zeros(4, 4)

z_sign = 1.0

P[0, 0] = 2.0 * znear / (right - left)
P[1, 1] = 2.0 * znear / (top - bottom)
P[0, 2] = (right + left) / (right - left)
P[1, 2] = (top + bottom) / (top - bottom)
P[3, 2] = z_sign
P[2, 2] = z_sign * zfar / (zfar - znear)
P[2, 3] = -(zfar * znear) / (zfar - znear)
return P
def getProjectionMatrix(znear, zfar, fovX, fovY, orthographic=False):
    """Build a 4x4 camera projection matrix.

    Args:
        znear: near clipping distance.
        zfar: far clipping distance.
        fovX, fovY: in the perspective branch these are field-of-view angles
            in radians; in the orthographic branch they are used directly as
            the half-extents of the view volume in X and Y.
        orthographic: select an orthographic instead of perspective
            projection.

    Returns:
        Perspective mode: the 4x4 matrix ``P`` (maps view-space depth in
        ``[znear, zfar]`` to ``[0, 1]`` with +z forward).
        Orthographic mode: a tuple ``(tanfovx, tanfovy, P)`` where the first
        two entries are the half-extents of the view volume.

    NOTE(review): the two modes return different types, and the orthographic
    branch maps depth with the OpenGL ``[-1, 1]`` convention while the
    perspective branch maps to ``[0, 1]`` — confirm the rasterizer expects
    this asymmetry before relying on orthographic depths.
    """
    P = torch.zeros(4, 4)

    if orthographic:
        # fovX / fovY act as half-widths here, not angles.
        left, right = -fovX, fovX
        bottom, top = -fovY, fovY

        P[0, 0] = 2.0 / (right - left)
        P[0, 3] = -(right + left) / (right - left)
        P[1, 1] = 2.0 / (top - bottom)
        P[1, 3] = -(top + bottom) / (top - bottom)
        P[2, 2] = -2.0 / (zfar - znear)
        P[2, 3] = -(zfar + znear) / (zfar - znear)
        P[3, 3] = 1.0

        # The half-extents double as tanfovx / tanfovy for the rasterizer.
        return (right - left) / 2, (top - bottom) / 2, P

    half_tan_x = math.tan(0.5 * fovX)
    half_tan_y = math.tan(0.5 * fovY)
    right = half_tan_x * znear
    top = half_tan_y * znear
    left, bottom = -right, -top

    P[0, 0] = 2.0 * znear / (right - left)
    P[1, 1] = 2.0 * znear / (top - bottom)
    P[0, 2] = (right + left) / (right - left)
    P[1, 2] = (top + bottom) / (top - bottom)
    P[3, 2] = 1.0  # +z forward (the original's z_sign)
    P[2, 2] = zfar / (zfar - znear)
    P[2, 3] = -(zfar * znear) / (zfar - znear)
    return P

def fov2focal(fov, pixels):
    """Convert a field-of-view angle (radians) to a focal length in pixels."""
    half_tan = math.tan(0.5 * fov)
    return pixels / (2.0 * half_tan)
Expand Down