Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions gaussian_renderer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,71 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
}

return out




##################### 추가함수 #####################
def render_reflected_gaussians(viewpoint_camera, pc: GaussianModel, pipe, bg_color: torch.Tensor, mirror_transform: torch.Tensor, gaussians_to_reflect_mask: torch.Tensor, scaling_modifier=1.0):
reflected_attrs = pc.reflect(mirror_transform, gaussians_to_reflect_mask)

if reflected_attrs is None or reflected_attrs["xyz"].shape[0] == 0:
image_height = int(viewpoint_camera.image_height)
image_width = int(viewpoint_camera.image_width)
return {"render": torch.full((3, image_height, image_width), bg_color[0].item(), device="cuda"), "radii": torch.zeros(0, device="cuda")}

tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

# 카메라 시선벡터 반사
orig_campos = viewpoint_camera.camera_center
campos_hom = torch.cat([orig_campos, torch.ones(1, device="cuda")], dim=0)
reflected_campos_hom = campos_hom @ mirror_transform.T
reflected_campos = reflected_campos_hom[:3]

raster_settings = GaussianRasterizationSettings(
image_height=int(viewpoint_camera.image_height),
image_width=int(viewpoint_camera.image_width),
tanfovx=tanfovx,
tanfovy=tanfovy,
bg=bg_color,
scale_modifier=scaling_modifier,
viewmatrix=viewpoint_camera.world_view_transform,
projmatrix=viewpoint_camera.full_proj_transform,
sh_degree=pc.active_sh_degree,
campos=reflected_campos,
prefiltered=False,
debug=pipe.debug,
antialiasing=pipe.antialiasing
)

rasterizer = GaussianRasterizer(raster_settings=raster_settings)

means3D = reflected_attrs["xyz"]
rotations = torch.nn.functional.normalize(reflected_attrs["rotation"])
scales = torch.exp(reflected_attrs["scaling"])
opacity = torch.sigmoid(reflected_attrs["opacity"])
shs = torch.cat((reflected_attrs["features_dc"], reflected_attrs["features_rest"]), dim=1)

screenspace_points = torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device="cuda") + 0
try:
screenspace_points.retain_grad()
except:
pass

# Rasterize
rendered_image, radii, _ = rasterizer(
means3D = means3D,
means2D = screenspace_points,
shs = shs,
colors_precomp = None,
opacities = opacity,
scales = scales,
rotations = rotations,
cov3D_precomp = None
)

return {
"render": rendered_image.clamp(0, 1),
"radii": radii
}
57 changes: 57 additions & 0 deletions propose_mirror_plane.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import numpy as np
import os
import json
from argparse import ArgumentParser
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '.')))

from scene.colmap_loader import read_points3D_binary

def propose_plane_robust(points_bin_path: str, output_json_path: str):
print(f"Analyzing point cloud from: {points_bin_path}")
try:
xyz, _, _ = read_points3D_binary(points_bin_path)
if xyz.shape[0] == 0:
raise ValueError("Point cloud is empty.")

y_coords = xyz[:, 1]

# 모든 가우시안의 y값의 하위 5%를 실질적 바닥으로 둠
y_percentile_5 = np.percentile(y_coords, 5)
y_mean = np.mean(y_coords)

print(f"Point cloud Y-axis 5th percentile: {y_percentile_5:.3f}")
print(f"Point cloud Y-axis mean: {y_mean:.3f}")

# 바닥에서 y값 평균 ~ 바닥 거리 만큼 뺌 -> 바닥 아래에 반사 평면 위치시킴
plane_y_value = y_percentile_5 - (y_mean - y_percentile_5)

# 법선벡터는 (0, 1, 0)으로 고정
plane_params = {
"a": 0.0,
"b": 1.0,
"c": 0.0,
"d": -plane_y_value
}

os.makedirs(os.path.dirname(output_json_path), exist_ok=True)
with open(output_json_path, 'w') as f:
json.dump(plane_params, f, indent=4)

print(f"\nSuccessfully proposed robust mirror plane: y = {plane_y_value:.3f}")
print(f"Plane parameters saved to: {output_json_path}")

except Exception as e:
print(f"Error: Failed to propose mirror plane. {e}")

if __name__ == "__main__":
parser = ArgumentParser(description="Propose a robust mirror plane from a COLMAP sparse point cloud.")
parser.add_argument("-s", "--source_path", required=True, type=str, help="Path to the COLMAP dataset directory")
parser.add_argument("-m", "--model_path", required=True, type=str, help="Path to the output model directory where the plane file will be saved")
args = parser.parse_args()

points_3d_bin = os.path.join(args.source_path, "sparse/0/points3D.bin")
output_json = os.path.join(args.model_path, "mirror_plane.json")

propose_plane_robust(points_3d_bin, output_json)
67 changes: 57 additions & 10 deletions scene/gaussian_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from plyfile import PlyData, PlyElement
from utils.sh_utils import RGB2SH
from simple_knn._C import distCUDA2
from utils.graphics_utils import BasicPointCloud
from utils.general_utils import strip_symmetric, build_scaling_rotation
from utils.graphics_utils import BasicPointCloud, geom_transform_points
from utils.general_utils import strip_symmetric, build_scaling_rotation, quat_from_matrix, build_rotation

try:
from diff_gaussian_rasterization import SparseGaussianAdam
Expand Down Expand Up @@ -318,14 +318,18 @@ def replace_tensor_to_optimizer(self, tensor, name):
for group in self.optimizer.param_groups:
if group["name"] == name:
stored_state = self.optimizer.state.get(group['params'][0], None)
stored_state["exp_avg"] = torch.zeros_like(tensor)
stored_state["exp_avg_sq"] = torch.zeros_like(tensor)

del self.optimizer.state[group['params'][0]]
group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
self.optimizer.state[group['params'][0]] = stored_state

optimizable_tensors[group["name"]] = group["params"][0]
if stored_state is not None:
stored_state["exp_avg"] = torch.zeros_like(tensor)
stored_state["exp_avg_sq"] = torch.zeros_like(tensor)

del self.optimizer.state[group['params'][0]]
group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
self.optimizer.state[group['params'][0]] = stored_state

optimizable_tensors[group["name"]] = group["params"][0]
else:
group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
optimizable_tensors[group["name"]] = group["params"][0]
return optimizable_tensors

def _prune_optimizer(self, mask):
Expand Down Expand Up @@ -471,3 +475,46 @@ def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size, radi
def add_densification_stats(self, viewspace_point_tensor, update_filter):
self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)
self.denom[update_filter] += 1




##################### 추가함수 #####################
@torch.no_grad()
def reflect(self, mirror_transform, reflect_mask):
from utils.graphics_utils import geom_transform_points
from utils.general_utils import build_rotation

if reflect_mask.sum() == 0:
return None

object_rotation_q = self.get_rotation[reflect_mask]
object_xyz = self.get_xyz[reflect_mask]

# 센터 반사
reflected_xyz = geom_transform_points(object_xyz, mirror_transform.T)

# 회전 반사
H_3x3 = mirror_transform[:3, :3]
R = build_rotation(object_rotation_q)
R_reflected = H_3x3 @ R
# 재직교화(SVD) + det(+1) 강제
U, S, Vt = torch.linalg.svd(R_reflected) # (N,3,3)
R_reflected = U @ Vt
det = torch.det(R_reflected)
neg = det < 0
if neg.any():
Vt[neg, -1, :] *= -1.0
R_reflected = U @ Vt

reflected_rotation_q = quat_from_matrix(R_reflected)

# 반사시킨 속성 반환
return {
"xyz": reflected_xyz,
"rotation": torch.nn.functional.normalize(reflected_rotation_q, p=2, dim=1),
"scaling": self._scaling[reflect_mask],
"opacity": self._opacity[reflect_mask],
"features_dc": self._features_dc[reflect_mask],
"features_rest": self._features_rest[reflect_mask]
}
Loading