Skip to content

Commit 4509817

Browse files
authored
fix 3dgut sample viewer (#734)
* fix viewer * thin_prism_coeffs with shape 4 * fix comment * fix remaining comment
1 parent 1dc7b21 commit 4509817

File tree

9 files changed

+21
-21
lines changed

9 files changed

+21
-21
lines changed

gsplat/cuda/_wrapper.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,7 @@ def rasterize_to_pixels_eval3d(
655655
# distortion
656656
radial_coeffs: Optional[Tensor] = None, # [..., C, 6] or [..., C, 4]
657657
tangential_coeffs: Optional[Tensor] = None, # [..., C, 2]
658-
thin_prism_coeffs: Optional[Tensor] = None, # [..., C, 2]
658+
thin_prism_coeffs: Optional[Tensor] = None, # [..., C, 4]
659659
# rolling shutter
660660
rolling_shutter: RollingShutterType = RollingShutterType.GLOBAL,
661661
viewmats_rs: Optional[Tensor] = None, # [..., C, 4, 4]
@@ -714,7 +714,7 @@ def rasterize_to_pixels_eval3d(
714714
tangential_coeffs = tangential_coeffs.contiguous()
715715

716716
if thin_prism_coeffs is not None:
717-
assert thin_prism_coeffs.shape == batch_dims + (C, 2), thin_prism_coeffs.shape
717+
assert thin_prism_coeffs.shape == batch_dims + (C, 4), thin_prism_coeffs.shape
718718
thin_prism_coeffs = thin_prism_coeffs.contiguous()
719719

720720
if viewmats_rs is not None:
@@ -1129,7 +1129,7 @@ def fully_fused_projection_with_ut(
11291129
# distortion
11301130
radial_coeffs: Optional[Tensor] = None, # [..., C, 6] or [..., C, 4]
11311131
tangential_coeffs: Optional[Tensor] = None, # [..., C, 2]
1132-
thin_prism_coeffs: Optional[Tensor] = None, # [..., C, 2]
1132+
thin_prism_coeffs: Optional[Tensor] = None, # [..., C, 4]
11331133
# rolling shutter
11341134
rolling_shutter: RollingShutterType = RollingShutterType.GLOBAL,
11351135
viewmats_rs: Optional[Tensor] = None, # [..., C, 4, 4]
@@ -1159,7 +1159,7 @@ def fully_fused_projection_with_ut(
11591159
if tangential_coeffs is not None:
11601160
assert tangential_coeffs.shape == batch_dims + (C, 2), tangential_coeffs.shape
11611161
if thin_prism_coeffs is not None:
1162-
assert thin_prism_coeffs.shape == batch_dims + (C, 2), thin_prism_coeffs.shape
1162+
assert thin_prism_coeffs.shape == batch_dims + (C, 4), thin_prism_coeffs.shape
11631163
if viewmats_rs is not None:
11641164
assert viewmats_rs.shape == batch_dims + (C, 4, 4), viewmats_rs.shape
11651165

@@ -1349,7 +1349,7 @@ def forward(
13491349
# distortion
13501350
radial_coeffs: Optional[Tensor] = None, # [..., C, 6] or [..., C, 4]
13511351
tangential_coeffs: Optional[Tensor] = None, # [..., C, 2]
1352-
thin_prism_coeffs: Optional[Tensor] = None, # [..., C, 2]
1352+
thin_prism_coeffs: Optional[Tensor] = None, # [..., C, 4]
13531353
# rolling shutter
13541354
rolling_shutter: RollingShutterType = RollingShutterType.GLOBAL,
13551355
viewmats_rs: Optional[Tensor] = None, # [..., C, 4, 4]

gsplat/cuda/csrc/Projection.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -930,7 +930,7 @@ projection_ut_3dgs_fused(
930930
ShutterType rs_type,
931931
const at::optional<at::Tensor> radial_coeffs, // [..., C, 6] or [..., C, 4] optional
932932
const at::optional<at::Tensor> tangential_coeffs, // [..., C, 2] optional
933-
const at::optional<at::Tensor> thin_prism_coeffs // [..., C, 2] optional
933+
const at::optional<at::Tensor> thin_prism_coeffs // [..., C, 4] optional
934934
) {
935935
DEVICE_GUARD(means);
936936
CHECK_INPUT(means);

gsplat/cuda/csrc/Projection.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ void launch_projection_ut_3dgs_fused_kernel(
270270
ShutterType rs_type,
271271
const at::optional<at::Tensor> radial_coeffs, // [C, 6] or [C, 4] optional
272272
const at::optional<at::Tensor> tangential_coeffs, // [C, 2] optional
273-
const at::optional<at::Tensor> thin_prism_coeffs, // [C, 2] optional
273+
const at::optional<at::Tensor> thin_prism_coeffs, // [C, 4] optional
274274
// outputs
275275
at::Tensor radii, // [C, N, 2]
276276
at::Tensor means2d, // [C, N, 2]

gsplat/cuda/csrc/ProjectionUT3DGSFused.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ __global__ void projection_ut_3dgs_fused_kernel(
3737
const ShutterType rs_type,
3838
const scalar_t *__restrict__ radial_coeffs, // [B, C, 6] or [B, C, 4] optional
3939
const scalar_t *__restrict__ tangential_coeffs, // [B, C, 2] optional
40-
const scalar_t *__restrict__ thin_prism_coeffs, // [B, C, 2] optional
40+
const scalar_t *__restrict__ thin_prism_coeffs, // [B, C, 4] optional
4141
// outputs
4242
int32_t *__restrict__ radii, // [B, C, N, 2]
4343
scalar_t *__restrict__ means2d, // [B, C, N, 2]
@@ -110,7 +110,7 @@ __global__ void projection_ut_3dgs_fused_kernel(
110110
cm_params.tangential_coeffs = make_array<float, 2>(tangential_coeffs + bid * C * 2 + cid * 2);
111111
}
112112
if (thin_prism_coeffs != nullptr) {
113-
cm_params.thin_prism_coeffs = make_array<float, 4>(thin_prism_coeffs + bid * C * 2 + cid * 2);
113+
cm_params.thin_prism_coeffs = make_array<float, 4>(thin_prism_coeffs + bid * C * 4 + cid * 4);
114114
}
115115
OpenCVPinholeCameraModel camera_model(cm_params);
116116
image_gaussian_return =
@@ -226,7 +226,7 @@ void launch_projection_ut_3dgs_fused_kernel(
226226
ShutterType rs_type,
227227
const at::optional<at::Tensor> radial_coeffs, // [..., C, 6] or [..., C, 4] optional
228228
const at::optional<at::Tensor> tangential_coeffs, // [..., C, 2] optional
229-
const at::optional<at::Tensor> thin_prism_coeffs, // [..., C, 2] optional
229+
const at::optional<at::Tensor> thin_prism_coeffs, // [..., C, 4] optional
230230
// outputs
231231
at::Tensor radii, // [..., C, N, 2]
232232
at::Tensor means2d, // [..., C, N, 2]

gsplat/cuda/csrc/Rasterization.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -715,7 +715,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> rasterize_to_pixels_from_world_3d
715715
ShutterType rs_type,
716716
const at::optional<at::Tensor> radial_coeffs, // [..., C, 6] or [..., C, 4] optional
717717
const at::optional<at::Tensor> tangential_coeffs, // [..., C, 2] optional
718-
const at::optional<at::Tensor> thin_prism_coeffs, // [..., C, 2] optional
718+
const at::optional<at::Tensor> thin_prism_coeffs, // [..., C, 4] optional
719719
// intersections
720720
const at::Tensor tile_offsets, // [..., C, tile_height, tile_width]
721721
const at::Tensor flatten_ids // [n_isects]
@@ -840,7 +840,7 @@ rasterize_to_pixels_from_world_3dgs_bwd(
840840
ShutterType rs_type,
841841
const at::optional<at::Tensor> radial_coeffs, // [..., C, 6] or [..., C, 4] optional
842842
const at::optional<at::Tensor> tangential_coeffs, // [..., C, 2] optional
843-
const at::optional<at::Tensor> thin_prism_coeffs, // [..., C, 2] optional
843+
const at::optional<at::Tensor> thin_prism_coeffs, // [..., C, 4] optional
844844
// intersections
845845
const at::Tensor tile_offsets, // [..., C, tile_height, tile_width]
846846
const at::Tensor flatten_ids, // [n_isects]

gsplat/cuda/csrc/Rasterization.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ void launch_rasterize_to_pixels_from_world_3dgs_fwd_kernel(
220220
ShutterType rs_type,
221221
const at::optional<at::Tensor> radial_coeffs, // [..., C, 6] or [..., C, 4] optional
222222
const at::optional<at::Tensor> tangential_coeffs, // [..., C, 2] optional
223-
const at::optional<at::Tensor> thin_prism_coeffs, // [..., C, 2] optional
223+
const at::optional<at::Tensor> thin_prism_coeffs, // [..., C, 4] optional
224224
// intersections
225225
const at::Tensor tile_offsets, // [..., C, tile_height, tile_width]
226226
const at::Tensor flatten_ids, // [n_isects]
@@ -254,7 +254,7 @@ void launch_rasterize_to_pixels_from_world_3dgs_bwd_kernel(
254254
ShutterType rs_type,
255255
const at::optional<at::Tensor> radial_coeffs, // [..., C, 6] or [..., C, 4] optional
256256
const at::optional<at::Tensor> tangential_coeffs, // [..., C, 2] optional
257-
const at::optional<at::Tensor> thin_prism_coeffs, // [..., C, 2] optional
257+
const at::optional<at::Tensor> thin_prism_coeffs, // [..., C, 4] optional
258258
// intersections
259259
const at::Tensor tile_offsets, // [..., C, tile_height, tile_width]
260260
const at::Tensor flatten_ids, // [n_isects]

gsplat/cuda/csrc/RasterizeToPixelsFromWorld3DGSBwd.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ __global__ void rasterize_to_pixels_from_world_3dgs_bwd_kernel(
4343
const ShutterType rs_type,
4444
const scalar_t *__restrict__ radial_coeffs, // [B, C, 6] or [B, C, 4] optional
4545
const scalar_t *__restrict__ tangential_coeffs, // [B, C, 2] optional
46-
const scalar_t *__restrict__ thin_prism_coeffs, // [B, C, 2] optional
46+
const scalar_t *__restrict__ thin_prism_coeffs, // [B, C, 4] optional
4747
// intersections
4848
const int32_t *__restrict__ tile_offsets, // [B, C, tile_height, tile_width]
4949
const int32_t *__restrict__ flatten_ids, // [n_isects]
@@ -408,7 +408,7 @@ void launch_rasterize_to_pixels_from_world_3dgs_bwd_kernel(
408408
ShutterType rs_type,
409409
const at::optional<at::Tensor> radial_coeffs, // [..., C, 6] or [..., C, 4] optional
410410
const at::optional<at::Tensor> tangential_coeffs, // [..., C, 2] optional
411-
const at::optional<at::Tensor> thin_prism_coeffs, // [..., C, 2] optional
411+
const at::optional<at::Tensor> thin_prism_coeffs, // [..., C, 4] optional
412412
// intersections
413413
const at::Tensor tile_offsets, // [..., C, tile_height, tile_width]
414414
const at::Tensor flatten_ids, // [n_isects]

gsplat/cuda/csrc/RasterizeToPixelsFromWorld3DGSFwd.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ __global__ void rasterize_to_pixels_from_world_3dgs_fwd_kernel(
4545
const ShutterType rs_type,
4646
const scalar_t *__restrict__ radial_coeffs, // [B, C, 6] or [B, C, 4] optional
4747
const scalar_t *__restrict__ tangential_coeffs, // [B, C, 2] optional
48-
const scalar_t *__restrict__ thin_prism_coeffs, // [B, C, 2] optional
48+
const scalar_t *__restrict__ thin_prism_coeffs, // [B, C, 4] optional
4949
// intersections
5050
const int32_t *__restrict__ tile_offsets, // [B, C, tile_height, tile_width]
5151
const int32_t *__restrict__ flatten_ids, // [n_isects]
@@ -309,7 +309,7 @@ void launch_rasterize_to_pixels_from_world_3dgs_fwd_kernel(
309309
ShutterType rs_type,
310310
const at::optional<at::Tensor> radial_coeffs, // [..., C, 6] or [..., C, 4] optional
311311
const at::optional<at::Tensor> tangential_coeffs, // [..., C, 2] optional
312-
const at::optional<at::Tensor> thin_prism_coeffs, // [..., C, 2] optional
312+
const at::optional<at::Tensor> thin_prism_coeffs, // [..., C, 4] optional
313313
// intersections
314314
const at::Tensor tile_offsets, // [..., C, tile_height, tile_width]
315315
const at::Tensor flatten_ids, // [n_isects]

gsplat/cuda/include/Ops.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ projection_ut_3dgs_fused(
489489
ShutterType rs_type,
490490
const at::optional<at::Tensor> radial_coeffs, // [..., C, 6] or [..., C, 4] optional
491491
const at::optional<at::Tensor> tangential_coeffs, // [..., C, 2] optional
492-
const at::optional<at::Tensor> thin_prism_coeffs // [..., C, 2] optional
492+
const at::optional<at::Tensor> thin_prism_coeffs // [..., C, 4] optional
493493
);
494494

495495
std::tuple<at::Tensor, at::Tensor, at::Tensor>
@@ -517,7 +517,7 @@ rasterize_to_pixels_from_world_3dgs_fwd(
517517
ShutterType rs_type,
518518
const at::optional<at::Tensor> radial_coeffs, // [..., C, 6] or [..., C, 4] optional
519519
const at::optional<at::Tensor> tangential_coeffs, // [..., C, 2] optional
520-
const at::optional<at::Tensor> thin_prism_coeffs, // [..., C, 2] optional
520+
const at::optional<at::Tensor> thin_prism_coeffs, // [..., C, 4] optional
521521
// intersections
522522
const at::Tensor tile_offsets, // [..., C, tile_height, tile_width]
523523
const at::Tensor flatten_ids // [n_isects]
@@ -548,7 +548,7 @@ rasterize_to_pixels_from_world_3dgs_bwd(
548548
ShutterType rs_type,
549549
const at::optional<at::Tensor> radial_coeffs, // [..., C, 6] or [..., C, 4] optional
550550
const at::optional<at::Tensor> tangential_coeffs, // [..., C, 2] optional
551-
const at::optional<at::Tensor> thin_prism_coeffs, // [..., C, 2] optional
551+
const at::optional<at::Tensor> thin_prism_coeffs, // [..., C, 4] optional
552552
// intersections
553553
const at::Tensor tile_offsets, // [..., C, tile_height, tile_width]
554554
const at::Tensor flatten_ids, // [n_isects]

0 commit comments

Comments
 (0)