-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtrain.py
More file actions
1743 lines (1528 loc) · 102 KB
/
train.py
File metadata and controls
1743 lines (1528 loc) · 102 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import array
import json
import os
import sys
from argparse import ArgumentParser
from collections import defaultdict
from random import randint
import Imath
import OpenEXR
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from tqdm import tqdm
from arguments import ModelParams, OptimizationParams, PipelineParams
from gaussian_renderer import render_fn_dict
from scene import Scene, GaussianModel
from scene.direct_light_map import DirectLightMap
from scene.mlp import BaseColorMLP, CombinedMLP, MetallicMLP, NormalMLP, RoughnessMLP
from torchvision.utils import make_grid, save_image
from utils.general_utils import safe_state
from utils.graphics_utils import load_and_transform_diffusion_renderer_image, rgb_to_srgb
from utils.loss_utils import l1_loss, scale_invariant_l1_loss
from utils.system_utils import prepare_output_and_logger
from utils.training.gradient_logging import log_mlp_gradients
from utils.training.intersection_tracing import (
load_intersection_data,
perform_intersection_tracing_for_all_training_images,
)
from utils.training.normal_utils import get_camera_to_world_rotation_matrix, load_normal_image
from utils.training.render_debug import create_timelapse_video, render_timelapse_frame
from utils.training.reporting import eval_render, save_training_vis, training_report
def training(dataset: ModelParams, opt: OptimizationParams, pipe: PipelineParams, is_pbr=False):
first_iter = 0
tb_writer = prepare_output_and_logger(dataset)
"""
Setup Gaussians
"""
gaussians = GaussianModel(dataset.sh_degree, render_type=args.type)
scene = Scene(dataset, gaussians)
if args.checkpoint:
print("Create Gaussians from checkpoint {}".format(args.checkpoint))
first_iter = gaussians.create_from_ckpt(args.checkpoint, restore_optimizer=True)
# Load MLP data from checkpoint if MLPs are enabled
if is_pbr and args.use_mlp:
checkpoint_dir = os.path.dirname(args.checkpoint)
mlp_data_dir = os.path.join(checkpoint_dir, "mlp_data")
if os.path.exists(mlp_data_dir):
try:
# Load MLP metadata
import json
with open(os.path.join(mlp_data_dir, "mlp_metadata.json"), 'r') as f:
mlp_metadata = json.load(f)
if mlp_metadata.get('has_mlps', False):
# Load projected data
projected_base_colors = np.load(os.path.join(mlp_data_dir, "projected_base_colors.npy"))
projected_roughness = np.load(os.path.join(mlp_data_dir, "projected_roughness.npy"))
projected_metallic = np.load(os.path.join(mlp_data_dir, "projected_metallic.npy"))
# Convert to tensors and store (not trainable)
gaussians.projected_base_colors = torch.tensor(projected_base_colors, dtype=torch.float, device="cuda", requires_grad=False)
gaussians.projected_roughness = torch.tensor(projected_roughness, dtype=torch.float, device="cuda", requires_grad=False)
gaussians.projected_metallic = torch.tensor(projected_metallic, dtype=torch.float, device="cuda", requires_grad=False)
# Load normal MLP data if it exists and normals_folder is provided
if mlp_metadata.get('has_normal_mlp', False) and getattr(args, 'normals_folder', None) is not None:
try:
projected_normals = np.load(os.path.join(mlp_data_dir, "projected_normals.npy"))
gaussians.projected_normals = torch.tensor(projected_normals, dtype=torch.float, device="cuda", requires_grad=False)
print(f" - Normal MLP data loaded, shape: {gaussians.projected_normals.shape}")
except Exception as e:
print(f"Warning: Failed to load normal MLP data: {e}")
else:
print(" - Skipping normal MLP data load (no normals_folder provided)")
# Initialize MLPs
num_train_images = mlp_metadata['num_train_images']
net_width = mlp_metadata.get('net_width', 64)
# Check if using combined MLP - prefer args flag, fallback to metadata
use_combined_mlp = getattr(args, 'combined_mlp', False) or mlp_metadata.get('use_combined_mlp', False)
gaussians.use_combined_mlp = use_combined_mlp
if use_combined_mlp:
gaussians.combined_mlp = CombinedMLP(
num_train_images=num_train_images, net_width=net_width,
enable_gradient_logging=args.enable_gradient_logging
).cuda()
gaussians.base_color_mlp = None
gaussians.roughness_mlp = None
gaussians.metallic_mlp = None
else:
gaussians.base_color_mlp = BaseColorMLP(
num_train_images=num_train_images, net_width=net_width,
enable_gradient_logging=args.enable_gradient_logging
).cuda()
gaussians.roughness_mlp = RoughnessMLP(
num_train_images=num_train_images, net_width=net_width,
enable_gradient_logging=args.enable_gradient_logging
).cuda()
gaussians.metallic_mlp = MetallicMLP(
num_train_images=num_train_images, net_width=net_width,
enable_gradient_logging=args.enable_gradient_logging
).cuda()
# Initialize normal MLP if data exists and normals_folder provided
if mlp_metadata.get('has_normal_mlp', False) and getattr(args, 'normals_folder', None) is not None:
gaussians.normal_mlp = NormalMLP(
num_train_images=num_train_images, net_width=net_width,
enable_gradient_logging=args.enable_gradient_logging
).cuda()
# Load MLP weights if they exist
try:
if use_combined_mlp:
gaussians.combined_mlp.load_state_dict(torch.load(os.path.join(mlp_data_dir, "combined_mlp_weights.pth")), strict=False)
else:
gaussians.base_color_mlp.load_state_dict(torch.load(os.path.join(mlp_data_dir, "base_color_mlp_weights.pth")), strict=False)
gaussians.roughness_mlp.load_state_dict(torch.load(os.path.join(mlp_data_dir, "roughness_mlp_weights.pth")), strict=False)
gaussians.metallic_mlp.load_state_dict(torch.load(os.path.join(mlp_data_dir, "metallic_mlp_weights.pth")), strict=False)
if mlp_metadata.get('has_normal_mlp', False) and getattr(args, 'normals_folder', None) is not None:
gaussians.normal_mlp.load_state_dict(torch.load(os.path.join(mlp_data_dir, "normal_mlp_weights.pth")), strict=False)
print("✓ MLP weights loaded successfully")
except Exception as e:
print(f"Warning: Failed to load MLP weights: {e}")
print("MLPs will be initialized with random weights")
print(f"✓ MLP projected data and weights loaded from checkpoint {mlp_data_dir}")
print(f" - Number of training images: {num_train_images}")
print(f" - Network width: {net_width}")
print(f" - Using combined MLP: {use_combined_mlp}")
print(f" - Projected data shapes: {gaussians.projected_base_colors.shape}, {gaussians.projected_roughness.shape}, {gaussians.projected_metallic.shape}")
if mlp_metadata.get('has_normal_mlp', False):
print(f" - Normal MLP enabled")
except Exception as e:
print(f"Warning: Failed to load MLP data from checkpoint {mlp_data_dir}: {e}")
print("Model will be loaded without MLPs")
elif scene.loaded_iter:
gaussians.load_ply(os.path.join(dataset.model_path,
"point_cloud",
"iteration_" + str(scene.loaded_iter),
"point_cloud.ply"))
# Always attach MLPs and projected data if is_pbr and args.use_mlp
if is_pbr and args.use_mlp:
ply_dir = os.path.dirname(os.path.join(dataset.model_path,
"point_cloud",
"iteration_" + str(scene.loaded_iter),
"point_cloud.ply"))
mlp_data_dir = os.path.join(ply_dir, "mlp_data")
if os.path.exists(mlp_data_dir):
try:
import json
with open(os.path.join(mlp_data_dir, "mlp_metadata.json"), 'r') as f:
mlp_metadata = json.load(f)
if mlp_metadata.get('has_mlps', False):
projected_base_colors = np.load(os.path.join(mlp_data_dir, "projected_base_colors.npy"))
projected_roughness = np.load(os.path.join(mlp_data_dir, "projected_roughness.npy"))
projected_metallic = np.load(os.path.join(mlp_data_dir, "projected_metallic.npy"))
gaussians.projected_base_colors = torch.tensor(projected_base_colors, dtype=torch.float, device="cuda", requires_grad=False)
gaussians.projected_roughness = torch.tensor(projected_roughness, dtype=torch.float, device="cuda", requires_grad=False)
gaussians.projected_metallic = torch.tensor(projected_metallic, dtype=torch.float, device="cuda", requires_grad=False)
# Load normal MLP data if it exists and normals_folder is provided
if mlp_metadata.get('has_normal_mlp', False) and getattr(args, 'normals_folder', None) is not None:
try:
projected_normals = np.load(os.path.join(mlp_data_dir, "projected_normals.npy"))
gaussians.projected_normals = torch.tensor(projected_normals, dtype=torch.float, device="cuda", requires_grad=False)
print(f" - Normal MLP data loaded, shape: {gaussians.projected_normals.shape}")
except Exception as e:
print(f"Warning: Failed to load normal MLP data: {e}")
else:
print(" - Skipping normal MLP data load (no normals_folder provided)")
num_train_images = mlp_metadata['num_train_images']
net_width = mlp_metadata.get('net_width', 64)
# Check if using combined MLP - prefer args flag, fallback to metadata
use_combined_mlp = getattr(args, 'combined_mlp', False) or mlp_metadata.get('use_combined_mlp', False)
gaussians.use_combined_mlp = use_combined_mlp
if use_combined_mlp:
gaussians.combined_mlp = CombinedMLP(num_train_images=num_train_images, net_width=net_width, enable_gradient_logging=args.enable_gradient_logging).cuda()
gaussians.base_color_mlp = None
gaussians.roughness_mlp = None
gaussians.metallic_mlp = None
else:
gaussians.base_color_mlp = BaseColorMLP(num_train_images=num_train_images, net_width=net_width, enable_gradient_logging=args.enable_gradient_logging).cuda()
gaussians.roughness_mlp = RoughnessMLP(num_train_images=num_train_images, net_width=net_width, enable_gradient_logging=args.enable_gradient_logging).cuda()
gaussians.metallic_mlp = MetallicMLP(num_train_images=num_train_images, net_width=net_width, enable_gradient_logging=args.enable_gradient_logging).cuda()
# Initialize normal MLP if data exists and normals_folder provided
if mlp_metadata.get('has_normal_mlp', False) and getattr(args, 'normals_folder', None) is not None:
gaussians.normal_mlp = NormalMLP(num_train_images=num_train_images, net_width=net_width, enable_gradient_logging=args.enable_gradient_logging).cuda()
try:
if use_combined_mlp:
gaussians.combined_mlp.load_state_dict(torch.load(os.path.join(mlp_data_dir, "combined_mlp_weights.pth")), strict=False)
else:
gaussians.base_color_mlp.load_state_dict(torch.load(os.path.join(mlp_data_dir, "base_color_mlp_weights.pth")), strict=False)
gaussians.roughness_mlp.load_state_dict(torch.load(os.path.join(mlp_data_dir, "roughness_mlp_weights.pth")), strict=False)
gaussians.metallic_mlp.load_state_dict(torch.load(os.path.join(mlp_data_dir, "metallic_mlp_weights.pth")), strict=False)
if mlp_metadata.get('has_normal_mlp', False) and getattr(args, 'normals_folder', None) is not None:
gaussians.normal_mlp.load_state_dict(torch.load(os.path.join(mlp_data_dir, "normal_mlp_weights.pth")), strict=False)
print("✓ MLP weights loaded successfully (post-ply)")
except Exception as e:
print(f"Warning: Failed to load MLP weights (post-ply): {e}")
print("MLPs will be initialized with random weights (post-ply)")
print(f"✓ MLP projected data and weights loaded from {mlp_data_dir} (post-ply)")
print(f" - Number of training images: {num_train_images}")
print(f" - Network width: {net_width}")
print(f" - Using combined MLP: {use_combined_mlp}")
print(f" - Projected data shapes: {gaussians.projected_base_colors.shape}, {gaussians.projected_roughness.shape}, {gaussians.projected_metallic.shape}")
if mlp_metadata.get('has_normal_mlp', False):
print(f" - Normal MLP enabled")
gaussians.quick_debug_mlp_status()
except Exception as e:
print(f"Warning: Failed to load MLP data from {mlp_data_dir} (post-ply): {e}")
print("Model will be loaded without MLPs (post-ply)")
else:
# Prefer scene-provided point cloud, but fall back to dataset points3d.ply if needed
ply_path = os.path.join(dataset.source_path, "points3d.ply")
if scene.scene_info is not None and scene.scene_info.point_cloud is not None:
gaussians.create_from_pcd(scene.scene_info.point_cloud, scene.cameras_extent)
elif os.path.isfile(ply_path):
print(f"Loading point cloud from {ply_path}")
gaussians.load_ply(ply_path)
else:
raise FileNotFoundError(f"No valid point cloud found; expected {ply_path}")
"""
Setup MLPs for base_color, roughness, and metallic (BEFORE training setup)
"""
if is_pbr and args.use_mlp:
print("Initializing MLPs for base_color, roughness, and metallic...")
# Get number of training images
num_train_images = len(scene.getTrainCameras())
print(f"Number of training images: {num_train_images}")
# Check if using combined MLP or separate MLPs
use_combined_mlp = getattr(args, 'combined_mlp', False)
gaussians.use_combined_mlp = use_combined_mlp
if use_combined_mlp:
print("Using combined MLP for base_color, roughness, and metallic")
combined_mlp = CombinedMLP(num_train_images=num_train_images, net_width=64, enable_gradient_logging=args.enable_gradient_logging).cuda()
# Load combined MLP weights if they exist
if hasattr(gaussians, 'mlp_weights_path'):
try:
combined_mlp.load_state_dict(torch.load(os.path.join(gaussians.mlp_weights_path, "combined_mlp_weights.pth")), strict=False)
print("✓ Combined MLP weights loaded successfully")
except Exception as e:
print(f"Warning: Failed to load combined MLP weights: {e}")
print("Combined MLP will be initialized with random weights")
gaussians.combined_mlp = combined_mlp
# Set separate MLPs to None to indicate they're not used
gaussians.base_color_mlp = None
gaussians.roughness_mlp = None
gaussians.metallic_mlp = None
else:
print("Using separate MLPs for base_color, roughness, and metallic")
# Initialize MLPs with the correct number of training images and enable gradient logging
base_color_mlp = BaseColorMLP(num_train_images=num_train_images, net_width=64, enable_gradient_logging=args.enable_gradient_logging).cuda()
roughness_mlp = RoughnessMLP(num_train_images=num_train_images, net_width=64, enable_gradient_logging=args.enable_gradient_logging).cuda()
metallic_mlp = MetallicMLP(num_train_images=num_train_images, net_width=64, enable_gradient_logging=args.enable_gradient_logging).cuda()
# Load MLP weights if they exist
if hasattr(gaussians, 'mlp_weights_path'):
try:
base_color_mlp.load_state_dict(torch.load(os.path.join(gaussians.mlp_weights_path, "base_color_mlp_weights.pth")), strict=False)
roughness_mlp.load_state_dict(torch.load(os.path.join(gaussians.mlp_weights_path, "roughness_mlp_weights.pth")), strict=False)
metallic_mlp.load_state_dict(torch.load(os.path.join(gaussians.mlp_weights_path, "metallic_mlp_weights.pth")), strict=False)
print("✓ MLP weights loaded successfully")
except Exception as e:
print(f"Warning: Failed to load MLP weights: {e}")
print("MLPs will be initialized with random weights")
gaussians.base_color_mlp = base_color_mlp
gaussians.roughness_mlp = roughness_mlp
gaussians.metallic_mlp = metallic_mlp
# Normal MLP is always separate (not part of combined MLP)
normal_mlp = None
if getattr(args, 'normals_folder', None) is not None:
normal_mlp = NormalMLP(num_train_images=num_train_images, net_width=64, enable_gradient_logging=args.enable_gradient_logging).cuda()
if hasattr(gaussians, 'mlp_weights_path'):
try:
normal_mlp.load_state_dict(torch.load(os.path.join(gaussians.mlp_weights_path, "normal_mlp_weights.pth")), strict=False)
except Exception as e:
print(f"Warning: Failed to load normal MLP weights: {e}")
if normal_mlp is not None:
gaussians.normal_mlp = normal_mlp
# Create projected input data for MLPs (num_train_images values per gaussian)
num_gaussians = gaussians.get_xyz.shape[0]
# Initialize with reasonable values instead of zeros
# Base colors: initialize with white (1,1,1) for each training image
projected_base_colors = torch.ones((num_gaussians, num_train_images, 3), device='cuda', requires_grad=False) # [num_gaussians, num_train_images, 3]
# Roughness: initialize with medium roughness (0.5) for each training image
projected_roughness = torch.full((num_gaussians, num_train_images, 1), 0.5, device='cuda', requires_grad=False) # [num_gaussians, num_train_images, 1]
# Metallic: initialize with non-metallic (0.0) for each training image
projected_metallic = torch.zeros((num_gaussians, num_train_images, 1), device='cuda', requires_grad=False) # [num_gaussians, num_train_images, 1]
# Normals: initialize with up direction (0,1,0) for each training image
projected_normals = torch.tensor([0.0, 1.0, 0.0], device='cuda', requires_grad=False).expand(num_gaussians, num_train_images, 3) # [num_gaussians, num_train_images, 3]
# Store projected data in gaussians for access during training
gaussians.projected_base_colors = projected_base_colors
gaussians.projected_roughness = projected_roughness
gaussians.projected_metallic = projected_metallic
if normal_mlp is not None:
gaussians.projected_normals = projected_normals
# Store relevant flags for model metadata
gaussians.perform_intersection_tracing = args.perform_intersection_tracing
gaussians.load_intersection_data = args.load_intersection_data
print(f"✓ MLPs initialized for {num_gaussians} Gaussians")
if use_combined_mlp:
print(f" - Using combined MLP")
else:
print(f" - Using separate MLPs")
# Debug MLP outputs to check initialization
gaussians.debug_mlp_outputs()
elif is_pbr:
print("MLPs disabled - using direct parameter optimization")
# Store relevant flags for model metadata
gaussians.perform_intersection_tracing = args.perform_intersection_tracing
gaussians.load_intersection_data = args.load_intersection_data
gaussians.training_setup(opt)
"""
Setup PBR components
"""
pbr_kwargs = dict()
if is_pbr:
# first update visibility
gaussians.update_visibility(pipe.sample_num)
pbr_kwargs['sample_num'] = pipe.sample_num
print("Using global incident light for regularization.")
direct_env_light = DirectLightMap(dataset.env_resolution, opt.light_init)
if args.checkpoint:
env_checkpoint = os.path.dirname(args.checkpoint) + "/env_light_" + os.path.basename(args.checkpoint)
print("Trying to load global incident light from ", env_checkpoint)
if os.path.exists(env_checkpoint):
direct_env_light.create_from_ckpt(env_checkpoint, restore_optimizer=True)
print("Successfully loaded!")
else:
print("Failed to load!")
direct_env_light.training_setup(opt)
pbr_kwargs["env_light"] = direct_env_light
# Add MLPs to PBR kwargs for access during training (only if MLPs exist)
if is_pbr and args.use_mlp and hasattr(gaussians, 'base_color_mlp'):
pbr_kwargs["base_color_mlp"] = gaussians.base_color_mlp
pbr_kwargs["roughness_mlp"] = gaussians.roughness_mlp
pbr_kwargs["metallic_mlp"] = gaussians.metallic_mlp
pbr_kwargs["projected_base_colors"] = gaussians.projected_base_colors
pbr_kwargs["projected_roughness"] = gaussians.projected_roughness
pbr_kwargs["projected_metallic"] = gaussians.projected_metallic
""" Prepare render function and bg"""
render_fn = render_fn_dict[args.type]
bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
# Optional: dump all per-view normal GTs at the start using the exact training transformation
if getattr(args, 'dump_normals_at_start', False) and getattr(args, 'normal_gt_folder', None) is not None:
try:
dump_dir = args.dump_normals_dir or os.path.join(args.model_path, "normal_gt_dump")
os.makedirs(dump_dir, exist_ok=True)
training_cameras = scene.getTrainCameras()
print(f"\n=== Dumping {len(training_cameras)} normal GT images to: {dump_dir} ===")
for camera in training_cameras:
try:
# Resolve per-camera normal path (same logic as main loop)
original_basename = camera.image_name
number_part = os.path.splitext(original_basename)[0]
number_parts = [number_part]
if number_part.startswith("r_"):
number_parts.append("r_r_" + number_part[2:])
elif number_part.startswith("r_r_"):
number_parts.append("r_" + number_part[4:])
normal_filename_options = []
for num_part in number_parts:
normal_filename_options.extend([
f"{num_part}_normalcamera0001.exr",
f"{num_part}_normal0001.exr",
f"{num_part}_normal.exr",
f"{num_part}_normal.png"
])
normal_gt_path = None
normal_folder = args.normal_gt_folder
if normal_folder.startswith('/'):
normal_folder = normal_folder[1:]
for fname in normal_filename_options:
p = os.path.join(dataset.source_path, normal_folder, fname)
if os.path.exists(p):
normal_gt_path = p
break
if normal_gt_path is None:
print(f"[DumpNormals] Missing normal GT for {camera.image_name}")
continue
device = background.device
exr_alpha_mask = None
# Load and transform normals to world space if needed
if getattr(args, 'normal_gt_is_camera_space', False):
normal_gt = load_normal_image(normal_gt_path).to(device) # [3,H,W] in [-1,1]
Hn, Wn = normal_gt.shape[1], normal_gt.shape[2]
# EXR alpha read if available
if normal_gt_path.endswith('.exr'):
try:
file = OpenEXR.InputFile(normal_gt_path)
dw = file.header()['dataWindow']
size = (dw.max.x - dw.min.x + 1, dw.max.y - dw.min.y + 1)
FLOAT = Imath.PixelType(Imath.PixelType.FLOAT)
if 'A' in file.header()['channels']:
a = np.array(array.array('f', file.channel('A', FLOAT))).reshape(size[1], size[0])
exr_alpha_mask = torch.from_numpy(a).float().view(1, size[1], size[0]).to(device)
file.close()
except Exception:
exr_alpha_mask = None
camera_rotation_matrix = get_camera_to_world_rotation_matrix(camera)
camera_rotation_tensor = torch.from_numpy(camera_rotation_matrix).float().to(device)
normal_flat = normal_gt.permute(1, 2, 0).reshape(-1, 3)
world_flat = torch.matmul(normal_flat, camera_rotation_tensor)
world_flat = F.normalize(world_flat, dim=-1, eps=1e-6)
normal_world = world_flat.reshape(Hn, Wn, 3).permute(2, 0, 1)
else:
if normal_gt_path.endswith('.exr'):
file = OpenEXR.InputFile(normal_gt_path)
dw = file.header()['dataWindow']
size = (dw.max.x - dw.min.x + 1, dw.max.y - dw.min.y + 1)
FLOAT = Imath.PixelType(Imath.PixelType.FLOAT)
r = np.array(array.array('f', file.channel('R', FLOAT))).reshape(size[1], size[0])
g = np.array(array.array('f', file.channel('G', FLOAT))).reshape(size[1], size[0])
b = np.array(array.array('f', file.channel('B', FLOAT))).reshape(size[1], size[0])
try:
if 'A' in file.header()['channels']:
a = np.array(array.array('f', file.channel('A', FLOAT))).reshape(size[1], size[0])
exr_alpha_mask = torch.from_numpy(a).float().view(1, size[1], size[0]).to(device)
except Exception:
exr_alpha_mask = None
file.close()
normal_world = torch.from_numpy(np.stack([r, g, b], axis=0)).float().to(device)
else:
normal_world = torchvision.io.read_image(normal_gt_path).float().to(device)
if normal_world.shape[0] >= 4:
exr_alpha_mask = normal_world[3:4].clone()
normal_world = normal_world[:3]
if normal_world.max() > 1.0:
normal_world = normal_world / 255.0
if getattr(args, 'normal_gt_range', '-1_1') == '0_1':
normal_world = normal_world * 2.0 - 1.0
normal_world = F.normalize(normal_world, dim=0, eps=1e-6)
# Map to [0,1]
target_rgb = torch.clamp(normal_world * 0.5 + 0.5, 0.0, 1.0)
# Load mask directly from the original image file to ensure we get the correct alpha channel
# This is more reliable than relying on camera.image_mask which might not be set correctly
img_mask = None
# Determine the correct path to load mask from
# If images.__backup__ exists, the original images with masks are there
# (this happens when the script has already swapped images folder with normals)
if hasattr(camera, 'image_path') and camera.image_path is not None:
image_path = camera.image_path
# Check if there's a backup folder with original images
# The backup folder is typically images.__backup__ in the same directory
image_dir = os.path.dirname(image_path)
image_filename = os.path.basename(image_path)
backup_dir = image_dir + ".__backup__"
backup_path = os.path.join(backup_dir, image_filename)
# Prefer backup path if it exists (original images with masks)
if os.path.exists(backup_path):
mask_source_path = backup_path
print(f"[DumpNormals] Loading mask from backup: {backup_path}")
elif os.path.exists(image_path):
mask_source_path = image_path
else:
mask_source_path = None
if mask_source_path is not None:
try:
# Use torchvision to read image (returns [C, H, W] tensor with values 0-255)
orig_img = torchvision.io.read_image(mask_source_path)
if orig_img.shape[0] == 4:
# Extract alpha channel as mask
alpha = orig_img[3:4, :, :].float() / 255.0
img_mask = alpha.to(device)
else:
print(f"[DumpNormals] Image at {mask_source_path} has {orig_img.shape[0]} channels, no alpha")
except Exception as e:
print(f"[DumpNormals] Could not load mask from {mask_source_path}: {e}")
# Second try: use camera.image_mask if available
if img_mask is None and hasattr(camera, 'image_mask') and camera.image_mask is not None:
img_mask = camera.image_mask.to(device)
# Ensure mask has correct shape [1, H, W]
if img_mask.dim() == 2:
img_mask = img_mask.unsqueeze(0)
elif img_mask.dim() == 3 and img_mask.shape[0] != 1:
img_mask = img_mask[:1, ...]
# Fallback: all ones mask
if img_mask is None:
img_mask = torch.ones((1, target_rgb.shape[1], target_rgb.shape[2]), device=device, dtype=target_rgb.dtype)
print(f"[DumpNormals] Warning: No mask found for {camera.image_name}, using all-ones mask")
# Resize mask to match target_rgb if needed
if img_mask.shape[1:] != target_rgb.shape[1:]:
img_mask = F.interpolate(img_mask[None], size=target_rgb.shape[1:], mode='nearest')[0]
# Ensure mask is in [0, 1] range
img_mask = img_mask.clamp(0.0, 1.0)
# Apply mask to target_rgb (multiply RGB channels with mask)
mask_3c = img_mask.repeat(3, 1, 1) if img_mask.shape[0] == 1 else img_mask
target_rgb_masked = target_rgb * mask_3c
# Save only RGBA (RGB masked, A = mask)
rgba = torch.cat([target_rgb_masked.clamp(0,1), img_mask.clamp(0,1)], dim=0) # 4,H,W
save_image(rgba, os.path.join(dump_dir, f"{camera.image_name}.png"))
except Exception as e:
print(f"[DumpNormals] Error for {camera.image_name}: {e}")
print(f"=== Normal GT dump complete ===\n")
if getattr(args, 'dump_and_exit', False):
print("Exiting after normal GT dump as requested (--dump_and_exit).")
return
except Exception as e:
print(f"[DumpNormals] Failed to dump normals at start: {e}")
"""
Perform intersection tracing for all training images
"""
if args.perform_intersection_tracing:
print("\n=== Performing Intersection Tracing ===")
try:
gaussians.intersection_tracing = True
# Get image dimensions from first camera
first_camera = scene.getTrainCameras()[0]
width = first_camera.image_width
height = first_camera.image_height
# Perform intersection tracing
            # Trace rays from every training camera and gather per-Gaussian material
            # samples. Per the indexing below, value tensors are assumed to be
            # [num_gaussians, num_cameras, C] and gaussian_hit_mask is
            # [num_gaussians, num_cameras] (True where that camera hit the Gaussian)
            # -- TODO confirm against perform_intersection_tracing_for_all_training_images.
            all_gaussian_base_colors, all_gaussian_roughness, all_gaussian_metallic, all_gaussian_normals, gaussian_hit_mask = perform_intersection_tracing_for_all_training_images(
                gaussians, scene, dataset.model_path, width, height, args, dataset)
            print("Shape of the gaussian base colors tensor: ", all_gaussian_base_colors.shape)
            print("Shape of the gaussian roughness tensor: ", all_gaussian_roughness.shape)
            print("Shape of the gaussian metallic tensor: ", all_gaussian_metallic.shape)
            # Compute the per-Gaussian mean across camera views, averaging only over
            # the cameras where the Gaussian was actually hit.
            print("\n=== Computing Mean Values ===")
            # Count how many cameras each Gaussian was hit by
            hit_counts_per_gaussian = gaussian_hit_mask.sum(dim=1) # [num_gaussians]
            # Get number of Gaussians from tensor shape
            num_gaussians = all_gaussian_base_colors.shape[0]
            # Zero-initialized accumulators: rows for never-hit Gaussians stay at
            # zero, which also sidesteps division by zero below.
            mean_base_colors = torch.zeros((num_gaussians, 3), device='cuda', dtype=torch.float32)
            mean_roughness = torch.zeros((num_gaussians, 1), device='cuda', dtype=torch.float32)
            mean_metallic = torch.zeros((num_gaussians, 1), device='cuda', dtype=torch.float32)
            mean_normals = torch.zeros((num_gaussians, 3), device='cuda', dtype=torch.float32)
            # Only compute mean for Gaussians that were hit at least once
            valid_gaussians = hit_counts_per_gaussian > 0
            if valid_gaussians.any():
                # Sum over the camera axis, divide by the per-Gaussian hit count.
                # NOTE(review): this assumes non-hit camera slots in the traced
                # tensors are zero-filled -- confirm in the tracing routine.
                mean_base_colors[valid_gaussians] = (
                    all_gaussian_base_colors[valid_gaussians].sum(dim=1) /
                    hit_counts_per_gaussian[valid_gaussians].unsqueeze(-1).float()
                )
                mean_roughness[valid_gaussians] = (
                    all_gaussian_roughness[valid_gaussians].sum(dim=1) /
                    hit_counts_per_gaussian[valid_gaussians].unsqueeze(-1).float()
                )
                mean_metallic[valid_gaussians] = (
                    all_gaussian_metallic[valid_gaussians].sum(dim=1) /
                    hit_counts_per_gaussian[valid_gaussians].unsqueeze(-1).float()
                )
                # NOTE(review): mean of unit normals is generally not unit-length;
                # the non-MLP path below re-normalizes before storing.
                mean_normals[valid_gaussians] = (
                    all_gaussian_normals[valid_gaussians].sum(dim=1) /
                    hit_counts_per_gaussian[valid_gaussians].unsqueeze(-1).float()
                )
            # "Average hits" below is taken over ALL Gaussians, including never-hit ones.
            print(f"Mean values computed for {valid_gaussians.sum().item()}/{num_gaussians} Gaussians")
            print(f"Average hits per Gaussian: {hit_counts_per_gaussian.float().mean().item():.2f}")
            print(f"Mean base colors shape: {mean_base_colors.shape}")
            print(f"Mean roughness shape: {mean_roughness.shape}")
            print(f"Mean metallic shape: {mean_metallic.shape}")
            print(f"Mean normals shape: {mean_normals.shape}")
            # Delete Gaussians that were never hit from any camera; every buffer
            # keyed by Gaussian index must be filtered in lockstep.
            invalid_gaussians = hit_counts_per_gaussian == 0
            if invalid_gaussians.any():
                num_invalid = invalid_gaussians.sum().item()
                # Rebind valid_gaussians so later code can filter with it too.
                valid_gaussians = ~invalid_gaussians
                print(f"Deleting {num_invalid} non-hit Gaussians completely (keeping {valid_gaussians.sum().item()})")
                # Filter mean colors/roughness/metallic to only include valid gaussians
                mean_base_colors = mean_base_colors[valid_gaussians]
                mean_roughness = mean_roughness[valid_gaussians]
                mean_metallic = mean_metallic[valid_gaussians]
                mean_normals = mean_normals[valid_gaussians]
                # Delete invalid gaussians from the model
                gaussians.prune_points(invalid_gaussians)
                # Update projected data to match the new number of gaussians (only if MLPs are being used)
                if hasattr(gaussians, 'base_color_mlp') and hasattr(gaussians, 'projected_base_colors'):
                    gaussians.projected_base_colors = gaussians.projected_base_colors[valid_gaussians]
                    gaussians.projected_roughness = gaussians.projected_roughness[valid_gaussians]
                    gaussians.projected_metallic = gaussians.projected_metallic[valid_gaussians]
                    if hasattr(gaussians, 'projected_normals'):
                        gaussians.projected_normals = gaussians.projected_normals[valid_gaussians]
                    print("✓ Projected data updated after pruning")
                # Recompute visibility after pruning since cached incident directions and areas have wrong dimensions
                print("Recomputing visibility after pruning...")
                gaussians.update_visibility(pipe.sample_num)
                print("✓ Visibility recomputed with correct dimensions")
            # Apply intersection-traced values to the remaining gaussians.
            # Two storage paths: per-camera "projected" buffers when MLP heads
            # exist, otherwise direct assignment into the Gaussian parameters.
            with torch.no_grad():
                if hasattr(gaussians, 'base_color_mlp'):
                    # Vectorized update of projected data with intersection-traced
                    # values: per-camera traced value where hit, per-Gaussian mean
                    # everywhere else.
                    # Traced tensors still have pre-pruning row counts; filter them
                    # only if pruning actually happened above.
                    need_filter = all_gaussian_base_colors.shape[0] != mean_base_colors.shape[0]
                    if need_filter:
                        agb = all_gaussian_base_colors[valid_gaussians]
                        agr = all_gaussian_roughness[valid_gaussians]
                        agm = all_gaussian_metallic[valid_gaussians]
                        agn = all_gaussian_normals[valid_gaussians]
                        ghm = gaussian_hit_mask[valid_gaussians]
                    else:
                        agb = all_gaussian_base_colors
                        agr = all_gaussian_roughness
                        agm = all_gaussian_metallic
                        agn = all_gaussian_normals
                        ghm = gaussian_hit_mask
                    inv_mask = ~ghm # [G, C]; True where that camera did NOT hit the Gaussian
                    gaussians.projected_base_colors = torch.where(inv_mask[..., None], mean_base_colors[:, None, :], agb)
                    gaussians.projected_roughness = torch.where(inv_mask[..., None], mean_roughness[:, None, :], agr)
                    gaussians.projected_metallic = torch.where(inv_mask[..., None], mean_metallic[:, None, :], agm)
                    # Normals are overwritten only when a normals folder was supplied.
                    if args is not None and getattr(args, 'normals_folder', None) is not None and hasattr(gaussians, 'projected_normals'):
                        filled_normals = torch.where(inv_mask[..., None], mean_normals[:, None, :], agn)
                        gaussians.projected_normals = filled_normals
                    print("✓ Projected data updated with intersection-traced values (normals kept if no normals_folder)")
                else:
                    # Store average values directly in gaussian parameters (no projected data needed)
                    gaussians._base_color.data = mean_base_colors
                    gaussians._roughness.data = mean_roughness
                    gaussians._metallic.data = mean_metallic
                    # Ensure normals are normalized before storing
                    gaussians._normal.data = F.normalize(mean_normals, dim=-1, eps=1e-6)
                    # Set flag to indicate intersection tracing has been used (colors are in linear space)
                    gaussians.intersection_tracing = True
                    print("✓ Average intersection-traced values stored directly in gaussian parameters (normals normalized)")
            print("✓ Base colors, roughness, metallic and normal values assigned from intersection tracing")
            print(f"✓ Final model has {gaussians.get_xyz.shape[0]} gaussians after pruning")
            print("✓ Intersection tracing completed successfully")
        except Exception as e:
            # Intersection tracing is best-effort: log the full traceback and
            # continue training without traced material data.
            print(f"✗ Error during intersection tracing: {e}")
            import traceback
            traceback.print_exc()
            print("Continuing with training without intersection data...")
    elif args.load_intersection_data:
        print("\n=== Loading Existing Intersection Data ===")
        try:
            # Load previously computed intersection data from the model directory.
            intersection_data = load_intersection_data(dataset.model_path)
            if intersection_data is not None:
                # Store intersection data in gaussians for potential use during training
                gaussians.intersection_data = intersection_data
                print("✓ Intersection data loaded successfully")
            else:
                print("✗ No intersection data found to load")
        except Exception as e:
            # Loading is best-effort; fall back to training without the data.
            print(f"✗ Error loading intersection data: {e}")
            import traceback
            traceback.print_exc()
            print("Continuing with training without intersection data...")
""" Training """
viewpoint_stack = None
ema_dict_for_log = defaultdict(int)
# Choose a single camera for timelapse at the beginning
timelapse_camera = None
if pipe.enable_timelapse:
training_cameras = scene.getTrainCameras()
if len(training_cameras) > 0:
# Use specified camera index, or first camera if index is out of range
camera_index = min(pipe.timelapse_camera_index, len(training_cameras) - 1)
timelapse_camera = training_cameras[camera_index]
print(f"✓ Selected camera '{timelapse_camera.image_name}' (index {camera_index}) for timelapse")
# Print timelapse configuration
if pipe.use_dynamic_timelapse:
print(f"📹 Dynamic timelapse intervals:")
print(f" Iterations 1-{pipe.timelapse_breakpoint_1}: every {pipe.timelapse_interval_early} iterations")
print(f" Iterations {pipe.timelapse_breakpoint_1+1}-{pipe.timelapse_breakpoint_2}: every {pipe.timelapse_interval_mid} iterations")
print(f" Iterations {pipe.timelapse_breakpoint_2+1}-{pipe.timelapse_breakpoint_3}: every {pipe.timelapse_interval_late} iterations")
print(f" Iterations {pipe.timelapse_breakpoint_3+1}+: every {pipe.timelapse_interval_final} iterations")
else:
print(f"📹 Fixed timelapse interval: every {pipe.timelapse_interval} iterations")
else:
print("Warning: No training cameras found for timelapse!")
elif not pipe.enable_timelapse:
print("ℹ️ Timelapse functionality disabled")
    # Resume-aware progress bar: starts at first_iter when continuing from a checkpoint.
    progress_bar = tqdm(range(first_iter + 1, opt.iterations + 1), desc="Training progress",
                        initial=first_iter, total=opt.iterations)
    for iteration in progress_bar:
        gaussians.update_learning_rate(iteration)
        # Clear combined MLP cache at the start of each iteration
        if hasattr(gaussians, 'clear_combined_mlp_cache'):
            gaussians.clear_combined_mlp_cache()
        # Every 1000 its we increase the levels of SH up to a maximum degree
        if iteration % 1000 == 0:
            gaussians.oneupSHdegree()
        # Pick a random Camera: refill the stack once it is exhausted so every
        # camera is visited before any repeats.
        if not viewpoint_stack:
            viewpoint_stack = scene.getTrainCameras().copy()
        # Initialize loss as a Tensor to avoid .backward() errors when skipping photo loss
        # Use a known device (background tensor) since viewpoint_cam isn't set yet
        init_device = background.device if isinstance(background, torch.Tensor) else ("cuda" if torch.cuda.is_available() else "cpu")
        loss = torch.tensor(0.0, device=init_device)
        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))
        # Render
        if (iteration - 1) == args.debug_from:
            pipe.debug = True
        # Renderer sees iterations relative to the resume point.
        pbr_kwargs["iteration"] = iteration - first_iter
        # Provide deferred renderer with conservative pixel and sample chunking
        total_samples = int(pbr_kwargs.get("sample_num", getattr(pipe, "sample_num", 32)))
        # Stream samples across batches (at most 8 per batch) to keep peak memory low
        samples_per_batch = max(1, min(8, total_samples))
        # Ceiling division so all samples are covered.
        num_batches = max(1, (total_samples + samples_per_batch - 1) // samples_per_batch)
        # Pixel chunking (ensure we don't process all pixels at once)
        deferred_options = {
            "max_pixels_per_pass": 32768,
            "samples_per_batch": samples_per_batch,
            "num_batches": num_batches,
        }
        render_pkg = render_fn(viewpoint_cam, gaussians, pipe, background,
                               opt=opt, is_training=True, dict_params=pbr_kwargs,
                               iteration=iteration, deferred_options=deferred_options)
        viewspace_point_tensor, visibility_filter, radii = \
            render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
        # Loss
        tb_dict = render_pkg["tb_dict"]
        # Skip the renderer's photometric loss when normals are the supervision
        # target; those modes contribute their own loss terms instead.
        if not getattr(args, 'supervise_with_normals_only', False) and not getattr(args, 'use_normals_as_rgb_gt', False):
            loss += render_pkg["loss"]
        # If using normals as RGB GT, compute and add replacement image loss here (works for both PBR and base 3DGS):
        # ground-truth surface normals are mapped to [0,1] RGB and used as the
        # photometric target instead of the captured image.
        if getattr(args, 'use_normals_as_rgb_gt', False) and getattr(args, 'normal_gt_folder', None) is not None:
            try:
                # Build candidate normal filenames, covering the "r_" / "r_r_"
                # image-name prefix variants used by the dataset.
                original_basename = viewpoint_cam.image_name
                number_part = os.path.splitext(original_basename)[0]
                number_parts = [number_part]
                if number_part.startswith("r_"):
                    number_parts.append("r_r_" + number_part[2:])
                elif number_part.startswith("r_r_"):
                    number_parts.append("r_" + number_part[4:])
                normal_filename_options = []
                for num_part in number_parts:
                    normal_filename_options.extend([
                        f"{num_part}_normalcamera0001.exr",
                        f"{num_part}_normal0001.exr",
                        f"{num_part}_normal.exr",
                        f"{num_part}_normal.png"
                    ])
                normal_gt_path = None
                normal_folder = args.normal_gt_folder
                # A leading '/' is treated as relative to source_path, not filesystem root.
                if normal_folder.startswith('/'):
                    normal_folder = normal_folder[1:]
                # First existing candidate wins.
                for fname in normal_filename_options:
                    p = os.path.join(dataset.source_path, normal_folder, fname)
                    if os.path.exists(p):
                        normal_gt_path = p
                        break
                if normal_gt_path is not None:
                    device = viewpoint_cam.original_image.device
                    exr_alpha_mask = None # optional alpha from EXR/PNG
                    # Load and transform normals to world space if needed
                    if getattr(args, 'normal_gt_is_camera_space', False):
                        normal_gt = load_normal_image(normal_gt_path).to(device) # [3,H,W] in [-1,1], normalized
                        Hn, Wn = normal_gt.shape[1], normal_gt.shape[2]
                        # If EXR, try to read alpha channel as mask
                        if normal_gt_path.endswith('.exr'):
                            try:
                                # NOTE(review): if a header/channel read raises,
                                # the EXR handle is never closed -- consider try/finally.
                                file = OpenEXR.InputFile(normal_gt_path)
                                dw = file.header()['dataWindow']
                                size = (dw.max.x - dw.min.x + 1, dw.max.y - dw.min.y + 1)
                                FLOAT = Imath.PixelType(Imath.PixelType.FLOAT)
                                if 'A' in file.header()['channels']:
                                    a = np.array(array.array('f', file.channel('A', FLOAT))).reshape(size[1], size[0])
                                    exr_alpha_mask = torch.from_numpy(a).float().view(1, size[1], size[0]).to(device)
                                file.close()
                            except Exception:
                                exr_alpha_mask = None
                        # Rotate camera-space normals into world space, then re-normalize.
                        camera_rotation_matrix = get_camera_to_world_rotation_matrix(viewpoint_cam)
                        camera_rotation_tensor = torch.from_numpy(camera_rotation_matrix).float().to(device)
                        normal_flat = normal_gt.permute(1, 2, 0).reshape(-1, 3)
                        world_flat = torch.matmul(normal_flat, camera_rotation_tensor)
                        world_flat = F.normalize(world_flat, dim=-1, eps=1e-6)
                        normal_world = world_flat.reshape(Hn, Wn, 3).permute(2, 0, 1)
                    else:
                        # Load as raw tensor and map range
                        if normal_gt_path.endswith('.exr'):
                            file = OpenEXR.InputFile(normal_gt_path)
                            dw = file.header()['dataWindow']
                            size = (dw.max.x - dw.min.x + 1, dw.max.y - dw.min.y + 1)
                            FLOAT = Imath.PixelType(Imath.PixelType.FLOAT)
                            r = np.array(array.array('f', file.channel('R', FLOAT))).reshape(size[1], size[0])
                            g = np.array(array.array('f', file.channel('G', FLOAT))).reshape(size[1], size[0])
                            b = np.array(array.array('f', file.channel('B', FLOAT))).reshape(size[1], size[0])
                            # optional alpha
                            try:
                                if 'A' in file.header()['channels']:
                                    a = np.array(array.array('f', file.channel('A', FLOAT))).reshape(size[1], size[0])
                                    exr_alpha_mask = torch.from_numpy(a).float().view(1, size[1], size[0]).to(device)
                            except Exception:
                                exr_alpha_mask = None
                            file.close()
                            normal_world = torch.from_numpy(np.stack([r, g, b], axis=0)).float().to(device)
                        else:
                            normal_world = torchvision.io.read_image(normal_gt_path).float().to(device)
                            # RGBA PNG: use the 4th channel as the mask.
                            if normal_world.shape[0] >= 4:
                                exr_alpha_mask = normal_world[3:4].clone()
                                normal_world = normal_world[:3]
                            # Heuristic: values above 1.0 imply 8-bit [0,255] encoding.
                            if normal_world.max() > 1.0:
                                normal_world = normal_world / 255.0
                        if getattr(args, 'normal_gt_range', '-1_1') == '0_1':
                            normal_world = normal_world * 2.0 - 1.0
                        normal_world = F.normalize(normal_world, dim=0, eps=1e-6)
                    # Map to [0,1] RGB
                    target_rgb = torch.clamp(normal_world * 0.5 + 0.5, 0.0, 1.0)
                    # Predicted image
                    pred_image = render_pkg["pbr"] if is_pbr else render_pkg["render"]
                    if pred_image.shape[1:] != target_rgb.shape[1:]:
                        target_rgb = F.interpolate(target_rgb[None], size=pred_image.shape[1:], mode='bilinear', align_corners=True)[0]
                    # Extract mask from alpha channel of ground truth image.
                    # NOTE(review): exr_alpha_mask loaded above is not used here --
                    # the GT image alpha takes precedence; confirm that is intended.
                    gt_image_full = viewpoint_cam.original_image.to(device)
                    # If GT image has 4 channels (RGBA), extract alpha as mask
                    if gt_image_full.dim() == 3 and gt_image_full.shape[0] == 4:
                        # Extract alpha channel as mask
                        img_mask = gt_image_full[3:4, ...].clamp(0.0, 1.0)
                    else:
                        # If no alpha channel, use all ones as mask
                        img_mask = torch.ones((1, gt_image_full.shape[1], gt_image_full.shape[2]), device=device, dtype=gt_image_full.dtype)
                    # Resize mask to match pred_image if needed
                    if img_mask.shape[1:] != pred_image.shape[1:]:
                        img_mask = F.interpolate(img_mask[None], size=pred_image.shape[1:], mode='nearest')[0]
                    mask3 = img_mask.repeat(3, 1, 1)
                    # Apply mask to target_rgb (multiply RGB channels with mask)
                    target_rgb = target_rgb * mask3
                    # Masked L1 loss on BOTH pred and target, normalized by mask coverage
                    diff = (pred_image - target_rgb).abs() * mask3
                    denom = torch.clamp(mask3.sum(), min=1e-6)
                    loss += diff.sum() / denom
                    # Debug print/save every 50 iterations (and on the first resumed iteration)
                    if iteration % 50 == 0 or iteration == first_iter + 1:
                        try:
                            tmin = [float(target_rgb[c].min().item()) for c in range(3)]
                            tmax = [float(target_rgb[c].max().item()) for c in range(3)]
                            tmean = [float(target_rgb[c].mean().item()) for c in range(3)]
                            cov = float(img_mask.mean().item())
                            print(f"[Normals-as-RGB GT] Using: {normal_gt_path}")
                            print(f"[Normals-as-RGB GT] stats min={tmin} max={tmax} mean={tmean} mask_coverage={cov:.4f}")
                            if getattr(args, 'save_training_vis', False):
                                vis_dir_rgb = os.path.join(args.model_path, "visualize", "gt_normals_rgb")
                                vis_dir_rgba = os.path.join(args.model_path, "visualize", "gt_normals_rgba")
                                os.makedirs(vis_dir_rgb, exist_ok=True)
                                os.makedirs(vis_dir_rgba, exist_ok=True)
                                # Save unclamped RGB target
                                save_image(target_rgb.clamp(0,1), os.path.join(vis_dir_rgb, f"{iteration:06d}_{viewpoint_cam.image_name}.png"))
                                # Save the exact tensor used for loss with alpha channel
                                alpha = img_mask[0:1].clamp(0,1) # 1,H,W
                                normal_gt_rgba = torch.cat([target_rgb.clamp(0,1), alpha], dim=0) # 4,H,W
                                save_image(normal_gt_rgba, os.path.join(vis_dir_rgba, f"{iteration:06d}_{viewpoint_cam.image_name}.png"))
                        except Exception:
                            # Visualization is best-effort; never break training on it.
                            pass
            except Exception as e:
                print(f"[ERROR] Normals-as-RGB GT (main loss) - Error processing {viewpoint_cam.image_name}: {e}")
        # Extract mask once from the ground truth render for use with all buffers
        gt_image_full = viewpoint_cam.original_image.cuda()
        # Check if ground truth has alpha channel (4 channels)
        if gt_image_full.shape[0] == 4:
            # Extract RGB channels
            gt_image = gt_image_full[:3, ...]
            # Extract alpha channel as mask
            gt_mask = gt_image_full[3:4, ...].clamp(0.0, 1.0)
        else:
            # If no alpha, create a mask of all ones
            gt_image = gt_image_full
            # Defensive: drop any channels beyond RGB (e.g. >4-channel inputs).
            if gt_image.dim() == 3 and gt_image.shape[0] > 3:
                gt_image = gt_image[:3, ...]
            gt_mask = torch.ones((1, gt_image.shape[1], gt_image.shape[2])).to("cuda")
        # Apply mask to RGB channels of ground truth (background pixels become black)
        gt_mask_3c = gt_mask.repeat(3, 1, 1) if gt_mask.shape[0] == 1 else gt_mask
        gt_image = gt_image * gt_mask_3c
        # Optional normal GT supervision comparing predicted normals (skip when using normals as RGB GT)
        if getattr(args, 'normal_gt_folder', None) is not None and not getattr(args, 'use_normals_as_rgb_gt', False):
            try:
                # Build candidate filenames from camera name, covering the
                # "r_" / "r_r_" prefix variants used by the dataset.
                original_basename = viewpoint_cam.image_name
                number_part = os.path.splitext(original_basename)[0]
                number_parts = [number_part]
                if number_part.startswith("r_"):
                    number_parts.append("r_r_" + number_part[2:])
                elif number_part.startswith("r_r_"):
                    number_parts.append("r_" + number_part[4:])
                normal_filename_options = []
                for num_part in number_parts:
                    # Common patterns
                    normal_filename_options.extend([
                        f"{num_part}_normalcamera0001.exr",
                        f"{num_part}_normal0001.exr",
                        f"{num_part}_normal.exr",
                        f"{num_part}_normal.png"
                    ])
                normal_gt_path = None
                normal_folder = args.normal_gt_folder
                if normal_folder.startswith('/'):