Skip to content

Commit 0b4dddf

Browse files
authored
Support 2dgs with the latest viewer (#707)
* Update simple_viewer.py * Update gsplat_viewer.py * Create gsplat_viewer_2dgs.py * Update simple_trainer_2dgs.py * Create simple_viewer_2dgs.py * Update simple_viewer_2dgs.py * should use render_mode="RGB+ED"
1 parent 5b842f0 commit 0b4dddf

File tree

4 files changed

+454
-21
lines changed

4 files changed

+454
-21
lines changed

examples/gsplat_viewer_2dgs.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
import viser
2+
from pathlib import Path
3+
from typing import Literal
4+
from typing import Tuple, Callable
5+
from nerfview import Viewer, RenderTabState
6+
7+
8+
class GsplatRenderTabState(RenderTabState):
9+
# non-controlable parameters
10+
total_gs_count: int = 0
11+
rendered_gs_count: int = 0
12+
13+
# controlable parameters
14+
max_sh_degree: int = 5
15+
near_plane: float = 1e-2
16+
far_plane: float = 1e2
17+
radius_clip: float = 0.0
18+
eps2d: float = 0.3
19+
backgrounds: Tuple[float, float, float] = (0.0, 0.0, 0.0)
20+
render_mode: Literal[
21+
"rgb", "depth(accumulated)", "depth(expected)", "alpha"
22+
] = "rgb"
23+
normalize_nearfar: bool = False
24+
inverse: bool = False
25+
colormap: Literal[
26+
"turbo", "viridis", "magma", "inferno", "cividis", "gray"
27+
] = "turbo"
28+
29+
30+
class GsplatViewer(Viewer):
31+
"""
32+
Viewer for gsplat 2dgs.
33+
"""
34+
35+
def __init__(
36+
self,
37+
server: viser.ViserServer,
38+
render_fn: Callable,
39+
output_dir: Path,
40+
mode: Literal["rendering", "training"] = "rendering",
41+
):
42+
super().__init__(server, render_fn, output_dir, mode)
43+
server.gui.set_panel_label("gsplat 2dgs viewer")
44+
45+
def _init_rendering_tab(self):
46+
self.render_tab_state = GsplatRenderTabState()
47+
self._rendering_tab_handles = {}
48+
self._rendering_folder = self.server.gui.add_folder("Rendering")
49+
50+
def _populate_rendering_tab(self):
51+
server = self.server
52+
with self._rendering_folder:
53+
with server.gui.add_folder("Gsplat"):
54+
total_gs_count_number = server.gui.add_number(
55+
"Total",
56+
initial_value=self.render_tab_state.total_gs_count,
57+
disabled=True,
58+
hint="Total number of splats in the scene.",
59+
)
60+
rendered_gs_count_number = server.gui.add_number(
61+
"Rendered",
62+
initial_value=self.render_tab_state.rendered_gs_count,
63+
disabled=True,
64+
hint="Number of splats rendered.",
65+
)
66+
67+
max_sh_degree_number = server.gui.add_number(
68+
"Max SH",
69+
initial_value=self.render_tab_state.max_sh_degree,
70+
min=0,
71+
max=5,
72+
step=1,
73+
hint="Maximum SH degree used",
74+
)
75+
76+
@max_sh_degree_number.on_update
77+
def _(_) -> None:
78+
self.render_tab_state.max_sh_degree = int(
79+
max_sh_degree_number.value
80+
)
81+
self.rerender(_)
82+
83+
near_far_plane_vec2 = server.gui.add_vector2(
84+
"Near/Far",
85+
initial_value=(
86+
self.render_tab_state.near_plane,
87+
self.render_tab_state.far_plane,
88+
),
89+
min=(1e-3, 1e1),
90+
max=(1e1, 1e3),
91+
step=1e-3,
92+
hint="Near and far plane for rendering.",
93+
)
94+
95+
@near_far_plane_vec2.on_update
96+
def _(_) -> None:
97+
self.render_tab_state.near_plane = near_far_plane_vec2.value[0]
98+
self.render_tab_state.far_plane = near_far_plane_vec2.value[1]
99+
self.rerender(_)
100+
101+
radius_clip_slider = server.gui.add_number(
102+
"Radius Clip",
103+
initial_value=self.render_tab_state.radius_clip,
104+
min=0.0,
105+
max=100.0,
106+
step=1.0,
107+
hint="2D radius clip for rendering.",
108+
)
109+
110+
@radius_clip_slider.on_update
111+
def _(_) -> None:
112+
self.render_tab_state.radius_clip = radius_clip_slider.value
113+
self.rerender(_)
114+
115+
eps2d_slider = server.gui.add_number(
116+
"2D Epsilon",
117+
initial_value=self.render_tab_state.eps2d,
118+
min=0.0,
119+
max=1.0,
120+
step=0.01,
121+
hint="Epsilon added to the egienvalues of projected 2D covariance matrices.",
122+
)
123+
124+
@eps2d_slider.on_update
125+
def _(_) -> None:
126+
self.render_tab_state.eps2d = eps2d_slider.value
127+
self.rerender(_)
128+
129+
backgrounds_slider = server.gui.add_rgb(
130+
"Background",
131+
initial_value=self.render_tab_state.backgrounds,
132+
hint="Background color for rendering.",
133+
)
134+
135+
@backgrounds_slider.on_update
136+
def _(_) -> None:
137+
self.render_tab_state.backgrounds = backgrounds_slider.value
138+
self.rerender(_)
139+
140+
render_mode_dropdown = server.gui.add_dropdown(
141+
"Render Mode",
142+
("rgb", "depth", "normal", "alpha"),
143+
initial_value=self.render_tab_state.render_mode,
144+
hint="Render mode to use.",
145+
)
146+
147+
@render_mode_dropdown.on_update
148+
def _(_) -> None:
149+
if "depth" in render_mode_dropdown.value:
150+
normalize_nearfar_checkbox.disabled = False
151+
inverse_checkbox.disabled = False
152+
else:
153+
normalize_nearfar_checkbox.disabled = True
154+
inverse_checkbox.disabled = True
155+
self.render_tab_state.render_mode = render_mode_dropdown.value
156+
self.rerender(_)
157+
158+
normalize_nearfar_checkbox = server.gui.add_checkbox(
159+
"Normalize Near/Far",
160+
initial_value=self.render_tab_state.normalize_nearfar,
161+
disabled=True,
162+
hint="Normalize depth with near/far plane.",
163+
)
164+
165+
@normalize_nearfar_checkbox.on_update
166+
def _(_) -> None:
167+
self.render_tab_state.normalize_nearfar = (
168+
normalize_nearfar_checkbox.value
169+
)
170+
self.rerender(_)
171+
172+
inverse_checkbox = server.gui.add_checkbox(
173+
"Inverse",
174+
initial_value=self.render_tab_state.inverse,
175+
disabled=True,
176+
hint="Inverse the depth.",
177+
)
178+
179+
@inverse_checkbox.on_update
180+
def _(_) -> None:
181+
self.render_tab_state.inverse = inverse_checkbox.value
182+
self.rerender(_)
183+
184+
colormap_dropdown = server.gui.add_dropdown(
185+
"Colormap",
186+
("turbo", "viridis", "magma", "inferno", "cividis", "gray"),
187+
initial_value=self.render_tab_state.colormap,
188+
hint="Colormap used for rendering depth/alpha.",
189+
)
190+
191+
@colormap_dropdown.on_update
192+
def _(_) -> None:
193+
self.render_tab_state.colormap = colormap_dropdown.value
194+
self.rerender(_)
195+
196+
self._rendering_tab_handles.update(
197+
{
198+
"total_gs_count_number": total_gs_count_number,
199+
"rendered_gs_count_number": rendered_gs_count_number,
200+
"near_far_plane_vec2": near_far_plane_vec2,
201+
"radius_clip_slider": radius_clip_slider,
202+
"eps2d_slider": eps2d_slider,
203+
"backgrounds_slider": backgrounds_slider,
204+
"render_mode_dropdown": render_mode_dropdown,
205+
"normalize_nearfar_checkbox": normalize_nearfar_checkbox,
206+
"inverse_checkbox": inverse_checkbox,
207+
"colormap_dropdown": colormap_dropdown,
208+
}
209+
)
210+
super()._populate_rendering_tab()
211+
212+
def _after_render(self):
213+
# Update the GUI elements with current values
214+
self._rendering_tab_handles[
215+
"total_gs_count_number"
216+
].value = self.render_tab_state.total_gs_count
217+
self._rendering_tab_handles[
218+
"rendered_gs_count_number"
219+
].value = self.render_tab_state.rendered_gs_count

examples/simple_trainer_2dgs.py

Lines changed: 74 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import time
55
from dataclasses import dataclass, field
66
from typing import Dict, List, Literal, Optional, Tuple
7+
from pathlib import Path
78

89
import imageio
910
import nerfview
@@ -28,9 +29,10 @@
2829
rgb_to_sh,
2930
set_random_seed,
3031
)
31-
32+
from gsplat_viewer_2dgs import GsplatViewer, GsplatRenderTabState
3233
from gsplat.rendering import rasterization_2dgs, rasterization_2dgs_inria_wrapper
3334
from gsplat.strategy import DefaultStrategy
35+
from nerfview import CameraState, RenderTabState, apply_float_colormap
3436

3537

3638
@dataclass
@@ -375,9 +377,10 @@ def __init__(self, cfg: Config) -> None:
375377
# Viewer
376378
if not self.cfg.disable_viewer:
377379
self.server = viser.ViserServer(port=cfg.port, verbose=False)
378-
self.viewer = nerfview.Viewer(
380+
self.viewer = GsplatViewer(
379381
server=self.server,
380382
render_fn=self._viewer_render_fn,
383+
output_dir=Path(cfg.result_dir),
381384
mode="training",
382385
)
383386

@@ -507,7 +510,7 @@ def train(self):
507510
pbar = tqdm.tqdm(range(init_step, max_steps))
508511
for step in pbar:
509512
if not cfg.disable_viewer:
510-
while self.viewer.state.status == "paused":
513+
while self.viewer.state == "paused":
511514
time.sleep(0.01)
512515
self.viewer.lock.acquire()
513516
tic = time.time()
@@ -736,7 +739,9 @@ def train(self):
736739
num_train_rays_per_step * num_train_steps_per_sec
737740
)
738741
# Update the viewer state.
739-
self.viewer.state.num_train_rays_per_sec = num_train_rays_per_sec
742+
self.viewer.render_tab_state.num_train_rays_per_sec = (
743+
num_train_rays_per_sec
744+
)
740745
# Update the scene.
741746
self.viewer.update(step, num_train_rays_per_step)
742747

@@ -795,7 +800,7 @@ def eval(self, step: int):
795800
)
796801
# render_median = render_median.detach().cpu().squeeze(0).unsqueeze(-1).repeat(1, 1, 3).numpy()
797802
render_median = (
798-
render_median.detach().cpu().squeeze(0).repeat(1, 1, 3).numpy()
803+
apply_float_colormap(render_median).detach().cpu().squeeze(0).numpy()
799804
)
800805

801806
imageio.imwrite(
@@ -831,13 +836,11 @@ def eval(self, step: int):
831836
dist_min = torch.min(render_dist)
832837
render_dist = (render_dist - dist_min) / (dist_max - dist_min)
833838
render_dist = (
834-
colormap(render_dist.cpu().numpy()[0])
835-
.permute((1, 2, 0))
836-
.numpy()
837-
.astype(np.uint8)
839+
apply_float_colormap(render_dist).detach().cpu().squeeze(0).numpy()
838840
)
839841
imageio.imwrite(
840-
f"{self.render_dir}/val_{i:04d}_distortions_{step}.png", render_dist
842+
f"{self.render_dir}/val_{i:04d}_distortions_{step}.png",
843+
(render_dist * 255).astype(np.uint8),
841844
)
842845

843846
pixels = pixels.permute(0, 3, 1, 2) # [1, 3, H, W]
@@ -930,24 +933,75 @@ def render_traj(self, step: int):
930933

931934
@torch.no_grad()
932935
def _viewer_render_fn(
933-
self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int]
936+
self, camera_state: CameraState, render_tab_state: RenderTabState
934937
):
935-
"""Callable function for the viewer."""
936-
W, H = img_wh
938+
assert isinstance(render_tab_state, GsplatRenderTabState)
939+
if render_tab_state.preview_render:
940+
width = render_tab_state.render_width
941+
height = render_tab_state.render_height
942+
else:
943+
width = render_tab_state.viewer_width
944+
height = render_tab_state.viewer_height
937945
c2w = camera_state.c2w
938-
K = camera_state.get_K(img_wh)
946+
K = camera_state.get_K((width, height))
939947
c2w = torch.from_numpy(c2w).float().to(self.device)
940948
K = torch.from_numpy(K).float().to(self.device)
941949

942-
render_colors, _, _, _, _, _, _ = self.rasterize_splats(
950+
(
951+
render_colors,
952+
render_alphas,
953+
render_normals,
954+
normals_from_depth,
955+
render_distort,
956+
render_median,
957+
info,
958+
) = self.rasterize_splats(
943959
camtoworlds=c2w[None],
944960
Ks=K[None],
945-
width=W,
946-
height=H,
947-
sh_degree=self.cfg.sh_degree, # active all SH degrees
948-
radius_clip=3.0, # skip GSs that have small image radius (in pixels)
961+
width=width,
962+
height=height,
963+
sh_degree=min(render_tab_state.max_sh_degree, self.cfg.sh_degree),
964+
near_plane=render_tab_state.near_plane,
965+
far_plane=render_tab_state.far_plane,
966+
radius_clip=render_tab_state.radius_clip,
967+
eps2d=render_tab_state.eps2d,
968+
render_mode="RGB+ED",
969+
backgrounds=torch.tensor([render_tab_state.backgrounds], device=self.device)
970+
/ 255.0,
949971
) # [1, H, W, 3]
950-
return render_colors[0].cpu().numpy()
972+
render_tab_state.total_gs_count = len(self.splats["means"])
973+
render_tab_state.rendered_gs_count = (info["radii"] > 0).all(-1).sum().item()
974+
975+
if render_tab_state.render_mode == "depth":
976+
# normalize depth to [0, 1]
977+
depth = render_median
978+
if render_tab_state.normalize_nearfar:
979+
near_plane = render_tab_state.near_plane
980+
far_plane = render_tab_state.far_plane
981+
else:
982+
near_plane = depth.min()
983+
far_plane = depth.max()
984+
depth_norm = (depth - near_plane) / (far_plane - near_plane + 1e-10)
985+
depth_norm = torch.clip(depth_norm, 0, 1)
986+
if render_tab_state.inverse:
987+
depth_norm = 1 - depth_norm
988+
renders = (
989+
apply_float_colormap(depth_norm, render_tab_state.colormap)
990+
.cpu()
991+
.numpy()
992+
)
993+
elif render_tab_state.render_mode == "normal":
994+
render_normals = render_normals * 0.5 + 0.5 # normalize to [0, 1]
995+
renders = render_normals.cpu().numpy()
996+
elif render_tab_state.render_mode == "alpha":
997+
alpha = render_alphas[0, ..., 0:1]
998+
renders = (
999+
apply_float_colormap(alpha, render_tab_state.colormap).cpu().numpy()
1000+
)
1001+
else:
1002+
render_colors = render_colors[0, ..., 0:3].clamp(0, 1)
1003+
renders = render_colors.cpu().numpy()
1004+
return renders
9511005

9521006

9531007
def main(cfg: Config):

0 commit comments

Comments
 (0)