Skip to content

Commit e84c7c8

Browse files
committed
fixed bugs for PSNR and env, and remove redundant code
1 parent ac43797 commit e84c7c8

File tree

7 files changed

+8
-88
lines changed

7 files changed

+8
-88
lines changed

internal/metrics/vanilla_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def adapter(pred, gt):
4040

4141

4242
def setup(self, stage: str, pl_module):
43-
self.psnr = PeakSignalNoiseRatio()
43+
self.psnr = PeakSignalNoiseRatio(data_range=1.)
4444
self.no_state_dict_models["lpips"] = LearnedPerceptualImagePatchSimilarity(normalize=True, net_type=self.config.lpips_net_type)
4545

4646
self.lambda_dssim = self.config.lambda_dssim

internal/renderers/gsplat_camera_opt.py

Lines changed: 2 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,8 @@ def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
4242
class ModelConfig:
4343

4444
n_cameras: int = -1
45-
pose_opt_type: Literal["sfm", "mlp", "7dmlp"] = "sfm"
45+
pose_opt_type: Literal["sfm", "mlp"] = "sfm"
4646
cam_scale: float = 1.0
47-
scale: float = 1e-3 # Used for 7dmlp
4847
mlp_width: int = 64
4948
mlp_depth: int = 2
5049

@@ -58,7 +57,6 @@ class OptimizationConfig:
5857
shceduler_type: Literal["step", "cosine", "none"] = "none"
5958
eps: float = 1e-15
6059
max_steps: int = 30_000
61-
opt_test: bool = False # TODO: remove it
6260

6361
class CameraOptModule(nn.Module):
6462
"""Camera pose optimization module."""
@@ -166,80 +164,7 @@ def forward(self, camtoworlds: torch.Tensor, embed_ids: torch.Tensor) -> torch.T
166164
transform[..., :3, 3] = dx * self.cam_scale
167165

168166
return torch.matmul(camtoworlds, transform)
169-
170-
class CameraOptModule7dMLP(torch.nn.Module):
171-
"""Camera pose optimization module using MLP."""
172-
173-
def __init__(self, n: int, mlp_width: int = 256, mlp_depth: int = 2, scale: float = 1e-6):
174-
super().__init__()
175-
# Identity rotation in 6D representation
176-
self.register_buffer("identity", torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]))
177-
178-
# Initial embeddings for each camera
179-
self.num_cams = n
180-
181-
# MLP layers
182-
activation = torch.nn.ELU(inplace=True)
183-
layers = []
184-
layers.append(torch.nn.Linear(7, mlp_width))
185-
layers.append(activation)
186-
for _ in range(mlp_depth - 1):
187-
layers.append(torch.nn.Linear(mlp_width, mlp_width))
188-
layers.append(activation)
189-
# Output layer produces 9D adjustments (3D position + 6D rotation)
190-
layers.append(torch.nn.Linear(mlp_width, 6))
191-
self.mlp = torch.nn.Sequential(*layers)
192-
193-
self.scale = scale
194-
195-
def zero_init(self):
196-
# torch.nn.init.zeros_(self.embeds.weight)
197-
#torch.nn.init.normal_(self.embeds.weight)
198-
# Also initialize the last layer of MLP with small weights
199-
# torch.nn.init.zeros_(self.mlp[-1].weight)
200-
# torch.nn.init.zeros_(self.mlp[-1].bias)
201-
pass
202-
203-
def random_init(self, std: float):
204-
# torch.nn.init.normal_(self.embeds.weight, std=std)
205-
# Initialize the last layer of MLP with small weights
206-
torch.nn.init.normal_(self.mlp[-1].weight, std=std)
207-
torch.nn.init.normal_(self.mlp[-1].bias, std=std)
208-
209-
def forward(self, camtoworlds: torch.Tensor, embed_ids: torch.Tensor) -> torch.Tensor:
210-
"""Adjust camera pose based on MLP outputs with SGLD noise.
211-
212-
Args:
213-
camtoworlds: (..., 4, 4)
214-
embed_ids: (...,)
215-
216-
Returns:
217-
updated camtoworlds: (..., 4, 4)
218-
"""
219-
assert camtoworlds.shape[:-2] == embed_ids.shape
220-
if camtoworlds.ndim == 2:
221-
camtoworlds = camtoworlds.unsqueeze(0)
222-
if embed_ids.ndim == 0:
223-
embed_ids = embed_ids.unsqueeze(0)
224-
batch_shape = camtoworlds.shape[:-2]
225-
226-
# Get embeddings and process through MLP with noise
227-
r_init = rotation_matrix_to_axis_angle(camtoworlds[..., :3, :3])
228-
t_init = camtoworlds[..., :3, 3]
229-
230-
mlp_input = torch.cat((embed_ids[..., None], r_init, t_init), dim=-1) # (..., 7)
231-
232-
out = self.mlp(mlp_input) * self.scale
233-
234-
r = out[..., :3] + r_init
235-
t = out[..., 3:] + t_init
236-
R = axis_angle_to_rotation_matrix(r)
237-
238-
camtoworlds_corrected = torch.eye(4, device=camtoworlds.device).repeat((*batch_shape, 1, 1))
239-
camtoworlds_corrected[..., :3, :3] = R
240-
camtoworlds_corrected[..., :3, 3] = t
241-
242-
return camtoworlds_corrected.squeeze()
167+
243168

244169
@dataclass
245170
class GSplatCameraOptRenderer(GSplatV1Renderer):
@@ -281,13 +206,6 @@ def _setup_model(self, device=None):
281206
mlp_depth=self.config.model.mlp_depth,
282207
cam_scale=self.config.model.cam_scale
283208
)
284-
elif self.config.model.pose_opt_type == "7dmlp":
285-
self.model = CameraOptModule7dMLP(
286-
n=self.config.model.n_cameras,
287-
mlp_width=self.config.model.mlp_width,
288-
mlp_depth=self.config.model.mlp_depth,
289-
scale=self.config.model.scale
290-
)
291209
else:
292210
self.model = CameraOptModule(self.config.model.n_cameras)
293211

notebooks/preprocess.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -864,7 +864,7 @@
864864
],
865865
"source": [
866866
"from torchmetrics import PeakSignalNoiseRatio\n",
867-
"psnr = PeakSignalNoiseRatio().to(rgb.device)\n",
867+
"psnr = PeakSignalNoiseRatio(data_range=1.).to(rgb.device)\n",
868868
"psnr(rgb, results[\"render\"].permute(1, 2, 0))"
869869
],
870870
"metadata": {

notebooks/render.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@
406406
],
407407
"source": [
408408
"from torchmetrics import PeakSignalNoiseRatio\n",
409-
"psnr = PeakSignalNoiseRatio().to(gt.device)\n",
409+
"psnr = PeakSignalNoiseRatio(data_range=1.).to(gt.device)\n",
410410
"psnr(results[\"render\"], gt)"
411411
]
412412
}

notebooks/rotate_shs.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -978,7 +978,7 @@
978978
"from internal.renderers.gsplat_renderer import GSPlatRenderer\n",
979979
"from internal.cameras.cameras import Cameras, CameraType\n",
980980
"from torchmetrics.image.psnr import PeakSignalNoiseRatio\n",
981-
"psnr = PeakSignalNoiseRatio().to(model.get_xyz.device)"
981+
"psnr = PeakSignalNoiseRatio(data_range=1.).to(model.get_xyz.device)"
982982
],
983983
"metadata": {
984984
"collapsed": false,

requirements/gsplat.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
kornia
12
git+https://github.com/yzslab/gsplat.git@58f3772541b6fb55e3219b36cd2b64be0584645c

requirements/lightning23.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
lightning[pytorch-extra]==2.3.*
22
pytorch-lightning==2.3.*
3+
bitsandbytes==0.45.*
34
-r common.txt

0 commit comments

Comments
 (0)