1212from pathlib import Path
1313from gsplat ._helper import load_test_data
1414from gsplat .distributed import cli
15- from gsplat .rendering import rasterization
15+ from gsplat .rendering import rasterization , rasterization_3dcs
1616
1717from nerfview import CameraState , RenderTabState , apply_float_colormap
1818from gsplat_viewer import GsplatViewer , GsplatRenderTabState
@@ -101,55 +101,81 @@ def main(local_rank: int, world_rank, world_size: int, args):
101101 )
102102 else :
103103 means , quats , scales , opacities , sh0 , shN = [], [], [], [], [], []
104- for ckpt_path in args .ckpt :
105- ckpt = torch .load (ckpt_path , map_location = device )["splats" ]
106- means .append (ckpt ["means" ])
107- quats .append (F .normalize (ckpt ["quats" ], p = 2 , dim = - 1 ))
108- scales .append (torch .exp (ckpt ["scales" ]))
109- opacities .append (torch .sigmoid (ckpt ["opacities" ]))
110- sh0 .append (ckpt ["sh0" ])
111- shN .append (ckpt ["shN" ])
112- means = torch .cat (means , dim = 0 )
113- quats = torch .cat (quats , dim = 0 )
114- scales = torch .cat (scales , dim = 0 )
115- opacities = torch .cat (opacities , dim = 0 )
116- sh0 = torch .cat (sh0 , dim = 0 )
117- shN = torch .cat (shN , dim = 0 )
118- colors = torch .cat ([sh0 , shN ], dim = - 2 )
119- sh_degree = int (math .sqrt (colors .shape [- 2 ]) - 1 )
120104
121- # # crop
122- # aabb = torch.tensor((-1.0, -1.0, -1.0, 1.0, 1.0, 0.7), device=device)
123- # edges = aabb[3:] - aabb[:3]
124- # sel = ((means >= aabb[:3]) & (means <= aabb[3:])).all(dim=-1)
125- # sel = torch.where(sel)[0]
126- # means, quats, scales, colors, opacities = (
127- # means[sel],
128- # quats[sel],
129- # scales[sel],
130- # colors[sel],
131- # opacities[sel],
132- # )
105+ convex_points , delta , sigma , num_points_per_convex , cumsum_of_points_per_convex = [], [], [], [], []
106+ if args .backend == "3dcs" :
107+ for ckpt_path in args .ckpt :
108+ hyperparam = torch .load (os .path .join (ckpt_path , "hyperparameters.pt" ), map_location = device , weights_only = False )
109+ pc = torch .load (os .path .join (ckpt_path , "point_cloud_state_dict.pt" ), map_location = device , weights_only = False )
110+ convex_points .append (pc ['convex_points' ])
111+ delta .append (torch .exp (pc ['delta' ]))
112+ sigma .append (torch .exp (pc ['sigma' ]))
113+ opacities .append (torch .sigmoid (pc ["opacity" ]).squeeze ())
114+ num_points_per_convex .append (torch .tensor ([6 ]))
115+ cumsum_of_points_per_convex .append (hyperparam ["cumsum_of_points_per_convex" ])
116+ sh0 .append (pc ["features_dc" ])
117+ shN .append (pc ["features_rest" ])
118+ convex_points = torch .cat (convex_points , dim = 0 )
119+ delta = torch .cat (delta , dim = 0 )
120+ sigma = torch .cat (sigma , dim = 0 )
121+ num_points_per_convex = torch .cat (num_points_per_convex , dim = 0 )
122+ cumsum_of_points_per_convex = torch .cat (cumsum_of_points_per_convex , dim = 0 )
123+ opacities = torch .cat (opacities , dim = 0 )
124+ sh0 = torch .cat (sh0 , dim = 0 )
125+ shN = torch .cat (shN , dim = 0 )
126+ colors = torch .cat ([sh0 , shN ], dim = - 2 )
127+ sh_degree = int (pc ["active_sh_degree" ])
128+ print ("Number of 3D convexes:" , convex_points .shape [0 ]* convex_points .shape [1 ])
129+ else :
130+ for ckpt_path in args .ckpt :
131+ ckpt = torch .load (ckpt_path , map_location = device )["splats" ]
132+ means .append (ckpt ["means" ])
133+ quats .append (F .normalize (ckpt ["quats" ], p = 2 , dim = - 1 ))
134+ scales .append (torch .exp (ckpt ["scales" ]))
135+ opacities .append (torch .sigmoid (ckpt ["opacities" ]))
136+ sh0 .append (ckpt ["sh0" ])
137+ shN .append (ckpt ["shN" ])
138+ means = torch .cat (means , dim = 0 )
139+ quats = torch .cat (quats , dim = 0 )
140+ scales = torch .cat (scales , dim = 0 )
141+ opacities = torch .cat (opacities , dim = 0 )
142+ sh0 = torch .cat (sh0 , dim = 0 )
143+ shN = torch .cat (shN , dim = 0 )
144+ colors = torch .cat ([sh0 , shN ], dim = - 2 )
145+ sh_degree = int (math .sqrt (colors .shape [- 2 ]) - 1 )
146+
147+ # # crop
148+ # aabb = torch.tensor((-1.0, -1.0, -1.0, 1.0, 1.0, 0.7), device=device)
149+ # edges = aabb[3:] - aabb[:3]
150+ # sel = ((means >= aabb[:3]) & (means <= aabb[3:])).all(dim=-1)
151+ # sel = torch.where(sel)[0]
152+ # means, quats, scales, colors, opacities = (
153+ # means[sel],
154+ # quats[sel],
155+ # scales[sel],
156+ # colors[sel],
157+ # opacities[sel],
158+ # )
133159
134- # # repeat the scene into a grid (to mimic a large-scale setting)
135- # repeats = args.scene_grid
136- # gridx, gridy = torch.meshgrid(
137- # [
138- # torch.arange(-(repeats // 2), repeats // 2 + 1, device=device),
139- # torch.arange(-(repeats // 2), repeats // 2 + 1, device=device),
140- # ],
141- # indexing="ij",
142- # )
143- # grid = torch.stack([gridx, gridy, torch.zeros_like(gridx)], dim=-1).reshape(
144- # -1, 3
145- # )
146- # means = means[None, :, :] + grid[:, None, :] * edges[None, None, :]
147- # means = means.reshape(-1, 3)
148- # quats = quats.repeat(repeats**2, 1)
149- # scales = scales.repeat(repeats**2, 1)
150- # colors = colors.repeat(repeats**2, 1, 1)
151- # opacities = opacities.repeat(repeats**2)
152- print ("Number of Gaussians:" , len (means ))
160+ # # repeat the scene into a grid (to mimic a large-scale setting)
161+ # repeats = args.scene_grid
162+ # gridx, gridy = torch.meshgrid(
163+ # [
164+ # torch.arange(-(repeats // 2), repeats // 2 + 1, device=device),
165+ # torch.arange(-(repeats // 2), repeats // 2 + 1, device=device),
166+ # ],
167+ # indexing="ij",
168+ # )
169+ # grid = torch.stack([gridx, gridy, torch.zeros_like(gridx)], dim=-1).reshape(
170+ # -1, 3
171+ # )
172+ # means = means[None, :, :] + grid[:, None, :] * edges[None, None, :]
173+ # means = means.reshape(-1, 3)
174+ # quats = quats.repeat(repeats**2, 1)
175+ # scales = scales.repeat(repeats**2, 1)
176+ # colors = colors.repeat(repeats**2, 1, 1)
177+ # opacities = opacities.repeat(repeats**2)
178+ print ("Number of Gaussians:" , len (means ))
153179
154180 # register and open viewer
155181 @torch .no_grad ()
@@ -174,34 +200,74 @@ def viewer_render_fn(camera_state: CameraState, render_tab_state: RenderTabState
174200 "alpha" : "RGB" ,
175201 }
176202
177- render_colors , render_alphas , info = rasterization (
178- means , # [N, 3]
179- quats , # [N, 4]
180- scales , # [N, 3]
181- opacities , # [N]
182- colors , # [N, S, 3]
183- viewmat [None ], # [1, 4, 4]
184- K [None ], # [1, 3, 3]
185- width ,
186- height ,
187- sh_degree = (
188- min (render_tab_state .max_sh_degree , sh_degree )
189- if sh_degree is not None
190- else None
191- ),
192- near_plane = render_tab_state .near_plane ,
193- far_plane = render_tab_state .far_plane ,
194- radius_clip = render_tab_state .radius_clip ,
195- eps2d = render_tab_state .eps2d ,
196- backgrounds = torch .tensor ([render_tab_state .backgrounds ], device = device )
197- / 255.0 ,
198- render_mode = RENDER_MODE_MAP [render_tab_state .render_mode ],
199- rasterize_mode = render_tab_state .rasterize_mode ,
200- camera_model = render_tab_state .camera_model ,
201- packed = False ,
202- )
203- render_tab_state .total_gs_count = len (means )
204- render_tab_state .rendered_gs_count = (info ["radii" ] > 0 ).all (- 1 ).sum ().item ()
203+ if args .backend == "gsplat" :
204+ rasterization_fn = rasterization
205+ elif args .backend == "3dcs" :
206+ rasterization_fn = rasterization_3dcs
207+ elif args .backend == "inria" :
208+ from gsplat import rasterization_inria_wrapper
209+
210+ rasterization_fn = rasterization_inria_wrapper
211+ else :
212+ raise ValueError
213+
214+ if args .backend == "3dcs" :
215+ render_colors , render_alphas , info = rasterization_fn (
216+ convex_points ,
217+ delta ,
218+ sigma ,
219+ num_points_per_convex ,
220+ cumsum_of_points_per_convex ,
221+ opacities ,
222+ colors ,
223+ viewmat [None ], # [1, 4, 4]
224+ K [None ], # [1, 3, 3]
225+ width ,
226+ height ,
227+ packed = False ,
228+ sh_degree = (
229+ min (render_tab_state .max_sh_degree , sh_degree )
230+ if sh_degree is not None
231+ else None
232+ ),
233+ near_plane = render_tab_state .near_plane ,
234+ far_plane = render_tab_state .far_plane ,
235+ radius_clip = render_tab_state .radius_clip ,
236+ eps2d = render_tab_state .eps2d ,
237+ backgrounds = torch .tensor ([render_tab_state .backgrounds ], device = device )
238+ / 255.0 ,
239+ render_mode = RENDER_MODE_MAP [render_tab_state .render_mode ],
240+ rasterize_mode = render_tab_state .rasterize_mode ,
241+ camera_model = render_tab_state .camera_model ,
242+ )
243+ else :
244+ render_colors , render_alphas , info = rasterization (
245+ means , # [N, 3]
246+ quats , # [N, 4]
247+ scales , # [N, 3]
248+ opacities , # [N]
249+ colors , # [N, S, 3]
250+ viewmat [None ], # [1, 4, 4]
251+ K [None ], # [1, 3, 3]
252+ width ,
253+ height ,
254+ sh_degree = (
255+ min (render_tab_state .max_sh_degree , sh_degree )
256+ if sh_degree is not None
257+ else None
258+ ),
259+ near_plane = render_tab_state .near_plane ,
260+ far_plane = render_tab_state .far_plane ,
261+ radius_clip = render_tab_state .radius_clip ,
262+ eps2d = render_tab_state .eps2d ,
263+ backgrounds = torch .tensor ([render_tab_state .backgrounds ], device = device )
264+ / 255.0 ,
265+ render_mode = RENDER_MODE_MAP [render_tab_state .render_mode ],
266+ rasterize_mode = render_tab_state .rasterize_mode ,
267+ camera_model = render_tab_state .camera_model ,
268+ )
269+ render_tab_state .total_gs_count = len (means )
270+ render_tab_state .rendered_gs_count = (info ["radii" ] > 0 ).all (- 1 ).sum ().item ()
205271
206272 if render_tab_state .render_mode == "rgb" :
207273 # colors represented with sh are not guranteed to be in [0, 1]
@@ -267,6 +333,12 @@ def viewer_render_fn(camera_state: CameraState, render_tab_state: RenderTabState
267333 parser .add_argument (
268334 "--ckpt" , type = str , nargs = "+" , default = None , help = "path to the .pt file"
269335 )
336+ parser .add_argument (
337+ "--ply" , type = str , nargs = "+" , default = None , help = "path to the .ply file"
338+ )
339+ parser .add_argument (
340+ "--backend" , type = str , default = "gsplat" , choices = ["gsplat" , "3dcs" , "inria" ], help = "backend to use for rendering" ,
341+ )
270342 parser .add_argument (
271343 "--port" , type = int , default = 8080 , help = "port for the viewer server"
272344 )
0 commit comments