-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain_ours_adaptive_lambda.py
More file actions
961 lines (777 loc) · 41 KB
/
Copy pathmain_ours_adaptive_lambda.py
File metadata and controls
961 lines (777 loc) · 41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
# main_ours.py
import os
import cv2
import time
import gc
import tqdm
import numpy as np
import random
import dearpygui.dearpygui as dpg
import torch
import torch.nn.functional as F
from cam_utils import orbit_camera, OrbitCamera
from gs_renderer import Renderer, MiniCam
from grid_put import mipmap_linear_grid_put_2d
from mesh import Mesh, safe_normalize
from guidance.sd_utils import StableDiffusion
from visualizer import GaussianVisualizer
from metrics import MetricsCalculator
from feature_extractor import DINOv2MultiLayerFeatureExtractor
from kernels import rbf_kernel_and_grad, cosine_kernel_and_grad
class GUI:
    """Multi-particle Gaussian-splatting trainer (only the no-GUI ``train()`` path is used).

    Holds ``opt.num_particles`` independent Gaussian-splatting renderers ("particles").
    Each training step renders all particles from one shared random camera and combines
    a latent-space "attraction" term driven by Stable Diffusion score gradients with an
    optional DINOv2 feature-space kernel "repulsion" term (``opt.repulsion_type`` in
    ``rlsd``/``svgd``; ``wo`` disables it). ``opt.lambda_repulsion`` can be adapted
    online so the repulsion/attraction ratio tracks a target percentage.
    """

    def __init__(self, opt):
        """Create renderers, seed each particle and build the adaptive-lambda controller.

        Args:
            opt: config object; it is stored by reference (shared with the trainer) so
                rendering parameters such as ``lambda_repulsion`` can be modified in place.
        """
        self.opt = opt  # shared with the trainer's opt to support in-place modification of rendering parameters.
        self.W = opt.W
        self.H = opt.H
        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
        self.seed = opt.seed

        self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32)
        self.need_update = True  # update buffer_image

        # models
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        # Set CUDA device if specified in options
        if hasattr(opt, 'gpu_id') and opt.gpu_id is not None and torch.cuda.is_available():
            torch.cuda.set_device(opt.gpu_id)
            self.device = torch.device(f"cuda:{opt.gpu_id}")
            print(f"[INFO] Using GPU {opt.gpu_id}: {torch.cuda.get_device_name(opt.gpu_id)}")

        self.bg_remover = None
        self.guidance_sd = None
        self.enable_sd = False

        # One renderer per particle.
        self.renderers = [Renderer(sh_degree=self.opt.sh_degree) for _ in range(self.opt.num_particles)]
        self.gaussain_scale_factor = 1

        # input image
        self.input_img = None
        self.input_mask = None
        self.input_img_torch = None
        self.input_mask_torch = None
        self.overlay_input_img = False
        self.overlay_input_img_ratio = 0.5

        # input text
        self.prompt = ""
        self.negative_prompt = ""

        # training stuff
        self.training = False
        self.optimizer = None
        self.step = 0
        self.train_steps = 1  # steps per rendering loop

        # override prompt from cmdline
        if self.opt.prompt is not None:
            self.prompt = self.opt.prompt
        if self.opt.negative_prompt is not None:
            self.negative_prompt = self.opt.negative_prompt

        # Initialise each particle under its own seed so the initial blobs differ.
        self.seeds = []
        for i in range(self.opt.num_particles):
            init_seed = self.seed + i * 1000000
            self.seeds.append(init_seed)
            self.seed_everything(init_seed)
            if self.opt.load is not None:
                # TODO: load from different checkpoints for each particle
                self.renderers[i].initialize(self.opt.load)
            else:
                # initialize gaussians to a blob
                self.renderers[i].initialize(num_pts=self.opt.num_pts)

        # visualizer (also needed for metrics, which render multi-view images)
        if self.opt.visualize or self.opt.metrics:
            self.visualizer = GaussianVisualizer(opt=self.opt, renderers=self.renderers, cam=self.cam)
        else:
            self.visualizer = None

        self.repctl = self._build_repctl(opt)
        self._rep_ratio_ema = None  # EMA of the repulsion/attraction ratio in %

    @staticmethod
    def _build_repctl(opt):
        """Return the adaptive-lambda controller settings read from ``opt`` with defaults.

        Args:
            opt: any object supporting ``getattr``; missing attributes fall back to defaults.

        Returns:
            dict with keys target, low, high, ema, interval, warmup, k, min, max,
            step_cap and denom_floor.
        """
        # Resolve the target first so the low/high defaults never touch a missing key.
        target = float(getattr(opt, "rep_ratio_target", 40.0))
        return {
            "target": target,  # desired ratio in % (centre of the dead-band)
            "low": float(getattr(opt, "rep_ratio_low", target - 2.0)),  # dead-band lower bound (%)
            "high": float(getattr(opt, "rep_ratio_high", target + 2.0)),  # dead-band upper bound (%)
            "ema": float(getattr(opt, "rep_ratio_ema", 0.9)),  # EMA coefficient (0.9-0.98 recommended)
            "interval": int(getattr(opt, "rep_update_interval", 10)),  # update every N steps
            "warmup": int(getattr(opt, "rep_warmup_steps", 50)),  # no updates before this step
            "k": float(getattr(opt, "rep_ratio_k", 0.12)),  # gain (smaller = gentler)
            "min": float(getattr(opt, "lambda_repulsion_min", 1e-6)),
            "max": float(getattr(opt, "lambda_repulsion_max", 4000.0)),
            "step_cap": float(getattr(opt, "rep_step_mult_cap", 1.25)),  # max multiplier per update
            "denom_floor": float(getattr(opt, "rep_ratio_denom_floor", 1e-6)),
        }

    def __del__(self):
        pass

    def cleanup_gpu_resources(self):
        """Drop references to models/renderers/optimizers and release cached CUDA memory."""
        print("[INFO] Cleaning up GPU resources...")

        renderers = getattr(self, 'renderers', None)
        if renderers is not None:
            for renderer in renderers:
                if hasattr(renderer, 'gaussians'):
                    del renderer.gaussians
            renderers.clear()

        # visualizer, feature extractor, guidance models and metrics calculator
        for attr in ('visualizer', 'feature_extractor', 'guidance_sd', 'metrics_calculator'):
            if hasattr(self, attr):
                setattr(self, attr, None)

        optimizers = getattr(self, 'optimizers', None)
        if optimizers is not None:
            optimizers.clear()

        # Force garbage collection and CUDA cleanup
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        print("[INFO] GPU resources cleaned up successfully.")

    def set_gpu_device(self, gpu_id: int):
        """Select CUDA device ``gpu_id``; fall back to CPU when it does not exist.

        Returns:
            True if the CUDA device was selected, False if CPU is used instead.
        """
        if torch.cuda.is_available() and 0 <= gpu_id < torch.cuda.device_count():
            torch.cuda.set_device(gpu_id)
            self.device = torch.device(f"cuda:{gpu_id}")
            print(f"[INFO] Switched to GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}")
            return True
        print(f"[WARNING] GPU {gpu_id} not available, using CPU")
        self.device = torch.device("cpu")
        return False

    def seed_everything(self, seed):
        """Seed Python, NumPy and torch RNGs and enable deterministic kernels.

        Args:
            seed: anything ``int()`` accepts; otherwise a random seed in [0, 1e6) is used.
        """
        try:
            seed = int(seed)
        except (TypeError, ValueError):
            seed = np.random.randint(0, 1000000)

        os.environ["PYTHONHASHSEED"] = str(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # multi-GPU

        # cuBLAS reads this before deterministic kernels are needed, so set it first.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(True)
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False

    def adaptive_lambda_repulsion(self, ratio_pct: float):
        """Nudge ``opt.lambda_repulsion`` so the ratio EMA tracks ``repctl['target']``.

        Args:
            ratio_pct: 100 * |scaled repulsion loss| / max(|scaled attraction loss|, eps).

        The EMA is updated on every call. An adjustment is made only after the warm-up,
        every ``interval`` steps, and only when the EMA is outside the [low, high]
        dead-band. The update is multiplicative, ``exp(k * (target - ema) / target)``,
        clamped to [1/step_cap, step_cap], and the result is clamped to [min, max].
        Note: if ``min`` is 0 and lambda reaches 0, the multiplicative update stays at 0.
        """
        # ---- EMA update ----
        a = float(self.repctl.get("ema", 0.9))
        if self._rep_ratio_ema is None:
            self._rep_ratio_ema = float(ratio_pct)
        else:
            self._rep_ratio_ema = a * float(self._rep_ratio_ema) + (1.0 - a) * float(ratio_pct)

        # ---- warm-up / update period ----
        warmup = int(self.repctl.get("warmup", 50))
        interval = int(self.repctl.get("interval", 10))
        if self.step < warmup or (self.step % interval != 0):
            return

        # ---- target / dead-band ----
        target = float(self.repctl.get("target", 40.0))
        if target <= 0:
            return
        low = float(self.repctl.get("low", target - 2.0))
        high = float(self.repctl.get("high", target + 2.0))
        ema_ratio = float(self._rep_ratio_ema)
        if low <= ema_ratio <= high:
            return

        # ---- multiplicative update: >1 when below target, <1 when above ----
        k = float(self.repctl.get("k", 0.12))
        raw_mult = np.exp(k * (target - ema_ratio) / max(target, 1e-6))
        step_cap = float(self.repctl.get("step_cap", 1.25))
        mult = float(np.clip(raw_mult, 1.0 / step_cap, step_cap))

        lam_min = float(self.repctl.get("min", 1e-6))
        lam_max = float(self.repctl.get("max", 4000.0))
        new_lambda = float(np.clip(self.opt.lambda_repulsion * mult, lam_min, lam_max))

        # Apply (ignore negligible changes)
        if abs(new_lambda - self.opt.lambda_repulsion) > 1e-12:
            self.opt.lambda_repulsion = new_lambda
            print(f"[CTRL] step {self.step}: ratio_ema={ema_ratio:.2f}% "
                  f"target={target:.2f}% mult={mult:.3f} -> lambda_repulsion={self.opt.lambda_repulsion:.6f}")

    @staticmethod
    def _resolve_feature_layer(feature_layer, n_layers):
        """Map a feature-layer spec to an integer layer index.

        Args:
            feature_layer: ``'early'`` (25% depth), ``'mid'`` (50% depth),
                ``'last'``/``'final'``/``'late'`` (final layer), a numeric string, or
                a non-string value, which is returned unchanged.
            n_layers: number of encoder layers in the feature extractor.

        Raises:
            ValueError: for an unrecognised string.
        """
        if not isinstance(feature_layer, str):
            return feature_layer
        name = feature_layer.strip().lower()
        if name == 'early':
            return max(0, int(0.25 * n_layers) - 1)  # 12 layers -> 2
        if name == 'mid':
            return max(0, int(0.50 * n_layers) - 1)  # 12 layers -> 5
        if name in ('last', 'final', 'late'):
            return n_layers - 1  # 12 layers -> 11
        try:
            return int(name)
        except ValueError:
            raise ValueError(
                f"Unknown feature_layer '{feature_layer}' — use early|mid|last or an integer"
            ) from None

    def prepare_train(self):
        """Reset the step counter and set up optimizers, feature extractor, metrics and SD."""
        self.step = 0

        self.optimizers = []
        for renderer in self.renderers[:self.opt.num_particles]:
            renderer.gaussians.training_setup(self.opt)
            renderer.gaussians.active_sh_degree = renderer.gaussians.max_sh_degree
            self.optimizers.append(renderer.gaussians.optimizer)

        # Frozen DINOv2 feature extractor, only needed when repulsion is enabled.
        if self.opt.repulsion_type != 'wo':
            print(f"[INFO] Using DINOv2 features from '{self.opt.feature_layer}' layer")
            self.feature_extractor = DINOv2MultiLayerFeatureExtractor(
                model_name=self.opt.feature_extractor_model_name,
                device=self.device,
            )
            self.feature_extractor.model.eval()
            for param in self.feature_extractor.model.parameters():
                param.requires_grad = False

            n_layers = len(self.feature_extractor.model.encoder.layer)
            self.opt.feature_layer = self._resolve_feature_layer(self.opt.feature_layer, n_layers)

        # metrics
        if self.opt.metrics:
            self.metrics_calculator = MetricsCalculator(opt=self.opt, prompt=self.prompt, device=self.device.type)
        else:
            self.metrics_calculator = None

        print(f"[INFO] loading SD...")
        self.guidance_sd = StableDiffusion(self.device)
        # Freeze all model weights
        for module in [self.guidance_sd.vae, self.guidance_sd.text_encoder, self.guidance_sd.unet]:
            module.eval()
            for param in module.parameters():
                param.requires_grad = False
        print(f"[INFO] loaded SD!")

        # prepare embeddings
        with torch.no_grad():
            self.guidance_sd.get_text_embeds([self.prompt], [self.negative_prompt])

    def train_step(self):
        """Run one optimisation step for all particles.

        Renders every particle from one shared random camera. It then computes the
        optional feature-space repulsion term and the SD attraction term, optionally
        adapts ``opt.lambda_repulsion``, back-propagates, steps the optimizers, and runs
        densification. Finally it logs metrics and visualisations. If the total loss is
        non-finite, gradients are cleared and the step returns early.
        """
        self.step += 1

        # Parenthesised so timing only starts when efficiency stats will be logged.
        track_efficiency = bool(
            self.opt.metrics
            and self.metrics_calculator is not None
            and (self.step == 1 or self.step % self.opt.efficiency_interval == 0)
        )
        starter = ender = None
        if track_efficiency and torch.cuda.is_available():
            starter = torch.cuda.Event(enable_timing=True)
            ender = torch.cuda.Event(enable_timing=True)
            starter.record()

        #########################################################
        # Forward pass
        #########################################################
        step_ratio = min(1, self.step / self.opt.schedule_iters)
        # Coarse-to-fine resolution for initial acceleration.
        render_resolution = 128 if step_ratio < 0.3 else (256 if step_ratio < 0.6 else 512)

        # avoid too large elevation (> 80 or < -80), and make sure it always cover [min_ver, max_ver]
        min_ver = max(min(self.opt.min_ver, self.opt.min_ver - self.opt.elevation), -80 - self.opt.elevation)
        max_ver = min(max(self.opt.max_ver, self.opt.max_ver - self.opt.elevation), 80 - self.opt.elevation)

        # One random view shared by all particles; reseeded per step for reproducibility.
        self.seed_everything(self.seed + self.step)
        ver = np.random.randint(min_ver, max_ver)
        hor = np.random.randint(-180, 180)
        radius = 0
        pose = orbit_camera(self.opt.elevation + ver, hor, self.opt.radius + radius)
        cur_cam = MiniCam(pose, render_resolution, render_resolution, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far)
        bg_color = torch.tensor(
            [1, 1, 1] if np.random.rand() > self.opt.invert_bg_prob else [0, 0, 0],
            dtype=torch.float32,
            device=self.device,
        )

        images = []
        outputs = []
        for j in range(self.opt.num_particles):
            self.renderers[j].gaussians.update_learning_rate(self.step)
            out = self.renderers[j].render(cur_cam, bg_color=bg_color)
            images.append(out["image"].unsqueeze(0))  # [1, 3, H, W] in [0, 1]
            # Keep only the keys densification needs; release the rest of the output.
            outputs.append({
                "viewspace_points": out["viewspace_points"],
                "visibility_filter": out["visibility_filter"],
                "radii": out["radii"],
            })
            del out
        images = torch.cat(images, dim=0)  # [N, 3, H, W]

        # ---- repulsion loss (feature space) ----
        repulsion_loss = torch.tensor(0.0, device=self.device, dtype=torch.float32)
        if self.opt.repulsion_type in ['rlsd', 'svgd']:
            features = self.feature_extractor.extract_cls_from_layer(self.opt.feature_layer, images).to(self.device).to(torch.float32)  # [N, D_feature]
            if self.opt.kernel_type == 'rbf':
                kernel, kernel_grad = rbf_kernel_and_grad(
                    features,
                    repulsion_type=self.opt.repulsion_type,
                    beta=self.opt.rbf_beta,
                )  # kernel:[N,N], kernel_grad:[N, D_feature]
            elif self.opt.kernel_type == 'cosine':
                kernel, kernel_grad = cosine_kernel_and_grad(
                    features,
                    repulsion_type=self.opt.repulsion_type,
                    beta=self.opt.cosine_beta,
                    eps_shift=self.opt.cosine_eps_shift,
                )  # kernel:[N,N], kernel_grad:[N, D_feature]
            else:
                raise ValueError(f"Invalid kernel type: {self.opt.kernel_type}")
            kernel = kernel.detach().to(torch.float32)
            kernel_grad = kernel_grad.detach().to(torch.float32)
            # Gradient flows only through `features`; the kernel gradient is a constant.
            repulsion_loss = (kernel_grad * features).sum(dim=1).mean()

        # ---- attraction loss (latent space) ----
        gen = torch.Generator(device=self.device).manual_seed(self.seed + self.step + 12345)  # SD-only seed offset
        score_gradients, latents = self.guidance_sd.train_step_gradient(
            images, step_ratio=step_ratio if self.opt.anneal_timestep else None,
            guidance_scale=self.opt.guidance_scale,
            force_same_t=self.opt.force_same_t,
            force_same_noise=self.opt.force_same_noise,
            generator=gen,
        )  # score_gradients: [N, 4, 64, 64], latents: [N, 4, 64, 64]
        score_gradients = score_gradients.to(torch.float32)
        latents = latents.to(torch.float32)

        if self.opt.repulsion_type == 'svgd':
            # SVGD: kernel-weighted average of the score gradients across particles.
            v = torch.einsum('ij,jchw->ichw', kernel.detach(), score_gradients)  # [N,4,64,64]
            target = (latents - v).detach()
        elif self.opt.repulsion_type in ['rlsd', 'wo']:
            target = (latents - score_gradients).detach()
        else:
            raise ValueError(f"Invalid repulsion type: {self.opt.repulsion_type}")

        attraction_loss = 0.5 * F.mse_loss(latents, target, reduction='none').view(self.opt.num_particles, -1).sum(dim=1)  # [N]
        attraction_loss = attraction_loss.mean()

        ### optimize step ###
        scaled_attraction_loss = self.opt.lambda_sd * attraction_loss

        # Ratio (%) of |scaled repulsion| to |scaled attraction| with the current lambda,
        # computed without a graph; it is logged whether or not lambda adapts.
        with torch.no_grad():
            scaled_repulsion_for_ratio = (self.opt.lambda_repulsion * repulsion_loss).detach()
            denom_val = max(float(abs(scaled_attraction_loss.detach().item())), 1e-8)
            ratio_pct = 100.0 * float(abs(scaled_repulsion_for_ratio.item())) / denom_val

        if (
            self.opt.adaptive_lambda
            and self.opt.repulsion_type != 'wo'
            and abs(repulsion_loss.detach().item()) > 1e-12
            and torch.isfinite(repulsion_loss).item()
            and torch.isfinite(scaled_attraction_loss).item()
            and denom_val > self.repctl.get("denom_floor", 1e-6)
        ):
            self.adaptive_lambda_repulsion(ratio_pct)

        # Re-scale with the (possibly updated) lambda, this time keeping the graph.
        scaled_repulsion_loss = self.opt.lambda_repulsion * repulsion_loss
        total_loss = scaled_attraction_loss + scaled_repulsion_loss

        # Safety guard: skip the update on a non-finite loss.
        if not torch.isfinite(total_loss):
            print(f"[WARN] non-finite total_loss at step {self.step}: "
                  f"attract={scaled_attraction_loss.item():.4e}, "
                  f"repulse={scaled_repulsion_loss.item():.4e}")
            for optimizer in self.optimizers[:self.opt.num_particles]:
                optimizer.zero_grad(set_to_none=True)
            return

        total_loss.backward()
        for j in range(self.opt.num_particles):
            self.optimizers[j].step()

        # densify and prune (after backward so viewspace gradients are available)
        if self.opt.density_start_iter <= self.step <= self.opt.density_end_iter:
            for j in range(self.opt.num_particles):
                gaussians = self.renderers[j].gaussians
                viewspace_point_tensor = outputs[j]["viewspace_points"]
                visibility_filter = outputs[j]["visibility_filter"]
                radii = outputs[j]["radii"]
                gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
                gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
                if self.step % self.opt.densification_interval == 0:
                    gaussians.densify_and_prune(self.opt.densify_grad_threshold, min_opacity=0.01, extent=4, max_screen_size=1)

        # Zero gradients (prepare for next iteration)
        for j in range(self.opt.num_particles):
            self.optimizers[j].zero_grad()

        #########################################################
        # Log metrics and visualize
        #########################################################
        with torch.no_grad():
            if self.opt.metrics and self.metrics_calculator is not None:
                # time / memory
                if track_efficiency:
                    if starter is not None:
                        ender.record()
                        torch.cuda.synchronize()
                        t = starter.elapsed_time(ender)
                        memory_allocated = torch.cuda.memory_allocated() / (1024 ** 2)  # MB
                        max_memory_allocated = torch.cuda.max_memory_allocated() / (1024 ** 2)  # MB
                    else:
                        t = memory_allocated = max_memory_allocated = None
                    self.metrics_calculator.log_efficiency(
                        step=self.step,
                        efficiency={
                            "time": t,
                            "memory_allocated_mb": memory_allocated,
                            "max_memory_allocated_mb": max_memory_allocated,
                        },
                    )

                # losses
                if self.step == 1 or self.step % self.opt.losses_interval == 0:
                    attraction_loss_val = attraction_loss.item()
                    repulsion_loss_val = repulsion_loss.item()
                    scaled_attraction_loss_val = self.opt.lambda_sd * attraction_loss_val
                    scaled_repulsion_loss_val = self.opt.lambda_repulsion * repulsion_loss_val
                    total_loss_val = total_loss.item()
                    if self.opt.adaptive_lambda:
                        ratio_pct_after = 100.0 * float(abs(scaled_repulsion_loss_val)) / max(float(abs(scaled_attraction_loss_val)), 1e-8)
                    else:
                        ratio_pct_after = ratio_pct
                    self.metrics_calculator.log_losses(
                        step=self.step,
                        losses={
                            "attraction_loss": attraction_loss_val,
                            "repulsion_loss": repulsion_loss_val,
                            "scaled_attraction_loss": scaled_attraction_loss_val,
                            "scaled_repulsion_loss": scaled_repulsion_loss_val,
                            "total_loss": total_loss_val,
                            "scaled_repulsion_loss_ratio_before": ratio_pct,
                            "scaled_repulsion_loss_ratio_after": ratio_pct_after,
                            "lambda_repulsion": float(self.opt.lambda_repulsion),
                            "rep_ratio_pct_ema": float(self._rep_ratio_ema) if self._rep_ratio_ema is not None else None,
                        },
                    )

                # quantitative metrics (parenthesised: the visualizer is required at step 1 too)
                if (self.step == 1 or self.step % self.opt.quantitative_metrics_interval == 0) and self.visualizer is not None:
                    self.visualizer.update_renderers(self.renderers)
                    multi_view_images = self.visualizer.visualize_all_particles_in_multi_viewpoints(self.step, visualize_multi_viewpoints=self.opt.visualize_multi_viewpoints, save_iid=self.opt.save_iid)  # [V, N, 3, H, W]
                    # Clean up visualizer renderers to free memory
                    self.visualizer.cleanup_renderers()

                    fidelity_mean, fidelity_std = self.metrics_calculator.compute_clip_fidelity_in_multi_viewpoints_stats(multi_view_images)
                    inter_particle_diversity_mean, inter_particle_diversity_std = self.metrics_calculator.compute_inter_particle_diversity_in_multi_viewpoints_stats(multi_view_images)
                    cross_view_consistency_mean, cross_view_consistency_std = self.metrics_calculator.compute_cross_view_consistency_stats(multi_view_images)

                    lpips_inter_mean = lpips_inter_std = lpips_consistency_mean = lpips_consistency_std = None
                    if self.opt.enable_lpips:
                        # LPIPS (inter-sample and cross-view consistency)
                        lpips_inter_mean, lpips_inter_std, lpips_consistency_mean, lpips_consistency_std = self.metrics_calculator.compute_lpips_inter_and_consistency(multi_view_images)

                    self.metrics_calculator.log_quantitative_metrics(
                        step=self.step,
                        metrics={
                            "fidelity_mean": fidelity_mean,
                            "fidelity_std": fidelity_std,
                            "inter_particle_diversity_mean": inter_particle_diversity_mean,
                            "inter_particle_diversity_std": inter_particle_diversity_std,
                            "cross_view_consistency_mean": cross_view_consistency_mean,
                            "cross_view_consistency_std": cross_view_consistency_std,
                            "lpips_inter_mean": lpips_inter_mean,
                            "lpips_inter_std": lpips_inter_std,
                            "lpips_consistency_mean": lpips_consistency_mean,
                            "lpips_consistency_std": lpips_consistency_std,
                        },
                    )

            # visualize
            if self.opt.visualize and self.visualizer is not None:
                if self.opt.save_rendered_images and (self.step == 1 or self.step % self.opt.save_rendered_images_interval == 0):
                    self.visualizer.update_renderers(self.renderers)
                    self.visualizer.save_rendered_images(self.step, images)
                    self.visualizer.cleanup_renderers()
                if self.opt.visualize_fixed_viewpoint and (self.step == 1 or self.step % self.opt.visualize_fixed_viewpoint_interval == 0):
                    self.visualizer.update_renderers(self.renderers)
                    self.visualizer.visualize_fixed_viewpoint(self.step)
                    self.visualizer.cleanup_renderers()

        # Periodic GPU memory cleanup: drop large tensors before emptying the CUDA cache.
        if self.step % self.opt.efficiency_interval == 0:
            del images, outputs, score_gradients, latents
            del scaled_attraction_loss, scaled_repulsion_loss, total_loss
            if self.opt.repulsion_type in ['rlsd', 'svgd']:
                del features, kernel, kernel_grad
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    @torch.no_grad()
    def save_model(self, mode='geo', texture_size=1024, particle_id=0, step=None):
        """Save particle ``particle_id`` under ``<outdir>/saved_models``.

        Args:
            mode: ``'geo'`` writes an extracted mesh as PLY; ``'geo+tex'`` also bakes a
                ``texture_size``² albedo texture by back-projecting 26 fixed views and
                writes ``opt.mesh_format``; anything else saves the Gaussians as PLY.
            texture_size: texture resolution for ``'geo+tex'``.
            particle_id: index into ``self.renderers``.
            step: only used in the output file name.
        """
        path = os.path.join(self.opt.outdir, f'saved_models')
        os.makedirs(path, exist_ok=True)

        if mode == 'geo':
            path = os.path.join(path, f'step_{step}_particle_{particle_id}_mesh.ply')
            mesh = self.renderers[particle_id].gaussians.extract_mesh(path, self.opt.density_thresh)
            mesh.write_ply(path)
            del mesh
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        elif mode == 'geo+tex':
            path = os.path.join(path, f'step_{step}_particle_{particle_id}_mesh.' + self.opt.mesh_format)
            mesh = self.renderers[particle_id].gaussians.extract_mesh(path, self.opt.density_thresh)

            # perform texture extraction
            print(f"[INFO] particle {particle_id} unwrap uv...")
            h = w = texture_size
            mesh.auto_uv()
            mesh.auto_normal()

            albedo = torch.zeros((h, w, 3), device=self.device, dtype=torch.float32)
            cnt = torch.zeros((h, w, 1), device=self.device, dtype=torch.float32)

            vers = [0] * 8 + [-45] * 8 + [45] * 8 + [-89.9, 89.9]
            hors = [0, 45, -45, 90, -90, 135, -135, 180] * 3 + [0, 0]
            render_resolution = 512

            import nvdiffrast.torch as dr

            if not self.opt.force_cuda_rast and (not self.opt.gui or os.name == 'nt'):
                glctx = dr.RasterizeGLContext()
            else:
                glctx = dr.RasterizeCudaContext()

            for ver, hor in zip(vers, hors):
                # render image
                pose = orbit_camera(ver, hor, self.cam.radius)
                cur_cam = MiniCam(
                    pose,
                    render_resolution,
                    render_resolution,
                    self.cam.fovy,
                    self.cam.fovx,
                    self.cam.near,
                    self.cam.far,
                )
                cur_out = self.renderers[particle_id].render(cur_cam)
                rgbs = cur_out["image"].unsqueeze(0)  # [1, 3, H, W] in [0, 1]

                # get coordinate in texture image
                pose = torch.from_numpy(pose.astype(np.float32)).to(self.device)
                proj = torch.from_numpy(self.cam.perspective.astype(np.float32)).to(self.device)

                v_cam = torch.matmul(F.pad(mesh.v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
                v_clip = v_cam @ proj.T
                rast, rast_db = dr.rasterize(glctx, v_clip, mesh.f, (render_resolution, render_resolution))

                depth, _ = dr.interpolate(-v_cam[..., [2]], rast, mesh.f)  # [1, H, W, 1]
                depth = depth.squeeze(0)  # [H, W, 1]

                alpha = (rast[0, ..., 3:] > 0).float()

                uvs, _ = dr.interpolate(mesh.vt.unsqueeze(0), rast, mesh.ft)  # [1, 512, 512, 2] in [0, 1]

                # use normal to produce a back-project mask
                normal, _ = dr.interpolate(mesh.vn.unsqueeze(0).contiguous(), rast, mesh.fn)
                normal = safe_normalize(normal[0])

                # rotated normal (where [0, 0, 1] always faces camera)
                rot_normal = normal @ pose[:3, :3]
                viewcos = rot_normal[..., [2]]

                mask = (alpha > 0) & (viewcos > 0.5)  # [H, W, 1]
                mask = mask.view(-1)

                uvs = uvs.view(-1, 2).clamp(0, 1)[mask]
                rgbs = rgbs.view(3, -1).permute(1, 0)[mask].contiguous()

                # update texture image
                cur_albedo, cur_cnt = mipmap_linear_grid_put_2d(
                    h, w,
                    uvs[..., [1, 0]] * 2 - 1,
                    rgbs,
                    min_resolution=256,
                    return_count=True,
                )

                # Only fill texels no earlier view has covered yet.
                mask = cnt.squeeze(-1) < 0.1
                albedo[mask] += cur_albedo[mask]
                cnt[mask] += cur_cnt[mask]

            mask = cnt.squeeze(-1) > 0
            albedo[mask] = albedo[mask] / cnt[mask].repeat(1, 3)

            mask = mask.view(h, w)

            albedo = albedo.detach().cpu().numpy()
            mask = mask.detach().cpu().numpy()

            # dilate texture
            from sklearn.neighbors import NearestNeighbors
            from scipy.ndimage import binary_dilation, binary_erosion

            inpaint_region = binary_dilation(mask, iterations=32)
            inpaint_region[mask] = 0

            search_region = mask.copy()
            not_search_region = binary_erosion(search_region, iterations=3)
            search_region[not_search_region] = 0

            search_coords = np.stack(np.nonzero(search_region), axis=-1)
            inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)

            # NearestNeighbors cannot fit or query an empty array; skip dilation then.
            if len(search_coords) > 0 and len(inpaint_coords) > 0:
                knn = NearestNeighbors(n_neighbors=1, algorithm="kd_tree").fit(
                    search_coords
                )
                _, indices = knn.kneighbors(inpaint_coords)
                albedo[tuple(inpaint_coords.T)] = albedo[tuple(search_coords[indices[:, 0]].T)]

            mesh.albedo = torch.from_numpy(albedo).to(self.device)
            mesh.write(path)

            # Cleanup heavy objects and CUDA cache
            del albedo, cnt, mesh
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        else:
            path = os.path.join(path, f'step_{step}_particle_{particle_id}_model.ply')
            self.renderers[particle_id].gaussians.save_ply(path)

        print(f"[INFO] particle {particle_id} save model to {path}.")

        # Cleanup CUDA cache after saving
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # no gui mode
    def train(self, iters=500):
        """Train for ``iters`` steps, then optionally render a 360° video and save models.

        A final prune runs only when training actually happened (``iters > 0``).
        """
        if iters > 0:
            self.prepare_train()
            for _ in tqdm.trange(iters):
                self.train_step()
            # do a last prune
            for j in range(self.opt.num_particles):
                self.renderers[j].gaussians.prune(min_opacity=0.01, extent=1, max_screen_size=1)

        # Multi-viewpoints for 30 fps video (save at the end of training)
        if self.opt.video_snapshot:
            if self.visualizer is None:
                # The visualizer only exists when opt.visualize or opt.metrics is set.
                print("[WARNING] video_snapshot requested but no visualizer is available; skipping.")
            else:
                self.visualizer.update_renderers(self.renderers)
                self.visualizer.visualize_all_particles_in_multi_viewpoints(self.step, num_views=120, visualize_multi_viewpoints=True, save_iid=True)  # 360 / 120 for 30 fps
                # Clean up visualizer renderers to free memory
                self.visualizer.cleanup_renderers()

        # save model
        if self.opt.save_model:
            for j in range(self.opt.num_particles):
                self.save_model(mode='model', particle_id=j, step=self.step)
                self.save_model(mode='geo+tex', particle_id=j, step=self.step)
if __name__ == "__main__":
    # Script entry point: load the YAML config, overlay CLI key=value overrides, train.
    import argparse
    from omegaconf import OmegaConf

    cli = argparse.ArgumentParser()
    cli.add_argument("--config", required=True, help="path to the yaml config file")
    known_args, cli_overrides = cli.parse_known_args()

    # Values passed on the command line take precedence over the YAML file.
    file_cfg = OmegaConf.load(known_args.config)
    opt = OmegaConf.merge(file_cfg, OmegaConf.from_cli(cli_overrides))

    GUI(opt).train(opt.iters)