Skip to content

Commit 0c305ab

Browse files
authored
Update forward & backward for rendered alpha image (#70)
* finish alpha forward & backward * black format * fix some merging issues * remove unnecessary .cuda() * add return_alpha keyword * add some notes * black reformat
1 parent e8696bd commit 0c305ab

File tree

6 files changed

+37
-9
lines changed

6 files changed

+37
-9
lines changed

examples/test_rasterize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def train(self, iterations: int = 1000, lr: float = 0.01, save_imgs: bool = True
155155
frames = []
156156
for i in range(iterations):
157157
optimizer.zero_grad()
158-
slow_out = self.forward_slow()
158+
slow_out, _ = self.forward_slow()
159159

160160
loss = mse_loss(slow_out, self.gt_image)
161161
loss.backward()
@@ -168,7 +168,7 @@ def train(self, iterations: int = 1000, lr: float = 0.01, save_imgs: bool = True
168168
]
169169

170170
optimizer.zero_grad()
171-
new_out = self.forward_new()
171+
new_out, _ = self.forward_new()
172172
loss = mse_loss(new_out, self.gt_image)
173173
loss.backward()
174174

gsplat/cuda/csrc/backward.cu

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ __global__ void nd_rasterize_backward_kernel(
1818
const float* __restrict__ final_Ts,
1919
const int* __restrict__ final_index,
2020
const float* __restrict__ v_output,
21+
const float* __restrict__ v_output_alpha,
2122
float2* __restrict__ v_xy,
2223
float3* __restrict__ v_conic,
2324
float* __restrict__ v_rgb,
@@ -45,6 +46,7 @@ __global__ void nd_rasterize_backward_kernel(
4546
int2 range = tile_bins[tile_id];
4647
// df/d_out for this pixel
4748
const float *v_out = &(v_output[channels * pix_id]);
49+
const float v_out_alpha = v_output_alpha[pix_id];
4850
// this is the T AFTER the last gaussian in this pixel
4951
float T_final = final_Ts[pix_id];
5052
float T = T_final;
@@ -97,7 +99,7 @@ __global__ void nd_rasterize_backward_kernel(
9799
// update the running sum
98100
S[c] += rgbs[channels * g + c] * fac;
99101
}
100-
102+
v_alpha += T_final * ra * v_out_alpha;
101103
// update v_opacity for this gaussian
102104
atomicAdd(&(v_opacity[g]), vis * v_alpha);
103105

@@ -146,6 +148,7 @@ __global__ void rasterize_backward_kernel(
146148
const float* __restrict__ final_Ts,
147149
const int* __restrict__ final_index,
148150
const float3* __restrict__ v_output,
151+
const float* __restrict__ v_output_alpha,
149152
float2* __restrict__ v_xy,
150153
float3* __restrict__ v_conic,
151154
float3* __restrict__ v_rgb,
@@ -188,6 +191,7 @@ __global__ void rasterize_backward_kernel(
188191

189192
// df/d_out for this pixel
190193
const float3 v_out = v_output[pix_id];
194+
const float v_out_alpha = v_output_alpha[pix_id];
191195

192196
// collect and process batches of gaussians
193197
// each thread loads one gaussian at a time before rasterizing
@@ -265,6 +269,8 @@ __global__ void rasterize_backward_kernel(
265269
v_alpha += (rgb.x * T - buffer.x * ra) * v_out.x;
266270
v_alpha += (rgb.y * T - buffer.y * ra) * v_out.y;
267271
v_alpha += (rgb.z * T - buffer.z * ra) * v_out.z;
272+
273+
v_alpha += T_final * ra * v_out_alpha;
268274
// contribution from background pixel
269275
v_alpha += -T_final * ra * background.x * v_out.x;
270276
v_alpha += -T_final * ra * background.y * v_out.y;

gsplat/cuda/csrc/backward.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ __global__ void nd_rasterize_backward_kernel(
4343
const float* __restrict__ final_Ts,
4444
const int* __restrict__ final_index,
4545
const float* __restrict__ v_output,
46+
const float* __restrict__ v_output_alpha,
4647
float2* __restrict__ v_xy,
4748
float3* __restrict__ v_conic,
4849
float* __restrict__ v_rgb,
@@ -63,6 +64,7 @@ __global__ void rasterize_backward_kernel(
6364
const float* __restrict__ final_Ts,
6465
const int* __restrict__ final_index,
6566
const float3* __restrict__ v_output,
67+
const float* __restrict__ v_output_alpha,
6668
float2* __restrict__ v_xy,
6769
float3* __restrict__ v_conic,
6870
float3* __restrict__ v_rgb,

gsplat/cuda/csrc/bindings.cu

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,8 @@ std::
486486
const torch::Tensor &background,
487487
const torch::Tensor &final_Ts,
488488
const torch::Tensor &final_idx,
489-
const torch::Tensor &v_output // dL_dout_color
489+
const torch::Tensor &v_output, // dL_dout_color
490+
const torch::Tensor &v_output_alpha // dL_dout_alpha
490491
) {
491492

492493
CHECK_INPUT(xys);
@@ -540,6 +541,7 @@ std::
540541
final_Ts.contiguous().data_ptr<float>(),
541542
final_idx.contiguous().data_ptr<int>(),
542543
v_output.contiguous().data_ptr<float>(),
544+
v_output_alpha.contiguous().data_ptr<float>(),
543545
(float2 *)v_xy.contiguous().data_ptr<float>(),
544546
(float3 *)v_conic.contiguous().data_ptr<float>(),
545547
v_colors.contiguous().data_ptr<float>(),
@@ -569,7 +571,8 @@ std::
569571
const torch::Tensor &background,
570572
const torch::Tensor &final_Ts,
571573
const torch::Tensor &final_idx,
572-
const torch::Tensor &v_output // dL_dout_color
574+
const torch::Tensor &v_output, // dL_dout_color
575+
const torch::Tensor &v_output_alpha // dL_dout_alpha
573576
) {
574577

575578
CHECK_INPUT(xys);
@@ -612,6 +615,7 @@ std::
612615
final_Ts.contiguous().data_ptr<float>(),
613616
final_idx.contiguous().data_ptr<int>(),
614617
(float3 *)v_output.contiguous().data_ptr<float>(),
618+
v_output_alpha.contiguous().data_ptr<float>(),
615619
(float2 *)v_xy.contiguous().data_ptr<float>(),
616620
(float3 *)v_conic.contiguous().data_ptr<float>(),
617621
(float3 *)v_colors.contiguous().data_ptr<float>(),

gsplat/cuda/csrc/bindings.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,8 @@ std::
157157
const torch::Tensor &background,
158158
const torch::Tensor &final_Ts,
159159
const torch::Tensor &final_idx,
160-
const torch::Tensor &v_output // dL_dout_color
160+
const torch::Tensor &v_output, // dL_dout_color
161+
const torch::Tensor &v_output_alpha
161162
);
162163

163164
std::
@@ -179,5 +180,6 @@ std::
179180
const torch::Tensor &background,
180181
const torch::Tensor &final_Ts,
181182
const torch::Tensor &final_idx,
182-
const torch::Tensor &v_output // dL_dout_color
183+
const torch::Tensor &v_output, // dL_dout_color
184+
const torch::Tensor &v_output_alpha
183185
);

gsplat/rasterize.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def rasterize_gaussians(
2222
img_height: int,
2323
img_width: int,
2424
background: Optional[Float[Tensor, "channels"]] = None,
25+
return_alpha: Optional[bool] = False,
2526
) -> Tensor:
2627
"""Rasterizes 2D gaussians by sorting and binning gaussian intersections for each tile and returns an N-dimensional output using alpha-compositing.
2728
@@ -39,11 +40,13 @@ def rasterize_gaussians(
3940
img_height (int): height of the rendered image.
4041
img_width (int): width of the rendered image.
4142
background (Tensor): background color
43+
return_alpha (bool): whether to return alpha channel
4244
4345
Returns:
4446
A Tensor:
4547
4648
- **out_img** (Tensor): N-dimensional rendered output image.
49+
- **out_alpha** (Optional[Tensor]): Alpha channel of the rendered output image.
4750
"""
4851
if colors.dtype == torch.uint8:
4952
# make sure colors are float [0,1]
@@ -75,6 +78,7 @@ def rasterize_gaussians(
7578
img_height,
7679
img_width,
7780
background.contiguous(),
81+
return_alpha,
7882
)
7983

8084

@@ -94,6 +98,7 @@ def forward(
9498
img_height: int,
9599
img_width: int,
96100
background: Optional[Float[Tensor, "channels"]] = None,
101+
return_alpha: Optional[bool] = False,
97102
) -> Tensor:
98103
num_points = xys.size(0)
99104
BLOCK_X, BLOCK_Y = 16, 16
@@ -148,13 +153,20 @@ def forward(
148153
final_idx,
149154
)
150155

151-
return out_img
156+
if return_alpha:
157+
out_alpha = 1 - final_Ts
158+
return out_img, out_alpha
159+
else:
160+
return out_img
152161

153162
@staticmethod
154-
def backward(ctx, v_out_img):
163+
def backward(ctx, v_out_img, v_out_alpha=None):
155164
img_height = ctx.img_height
156165
img_width = ctx.img_width
157166

167+
if v_out_alpha is None:
168+
v_out_alpha = torch.zeros_like(v_out_img[..., 0])
169+
158170
(
159171
gaussian_ids_sorted,
160172
tile_bins,
@@ -184,6 +196,7 @@ def backward(ctx, v_out_img):
184196
final_Ts,
185197
final_idx,
186198
v_out_img,
199+
v_out_alpha,
187200
)
188201

189202
return (
@@ -197,4 +210,5 @@ def backward(ctx, v_out_img):
197210
None, # img_height
198211
None, # img_width
199212
None, # background
213+
None, # return_alpha
200214
)

0 commit comments

Comments
 (0)