@@ -18,6 +18,7 @@ __global__ void nd_rasterize_backward_kernel(
1818 const float * __restrict__ final_Ts,
1919 const int * __restrict__ final_index,
2020 const float * __restrict__ v_output,
21+ const float * __restrict__ v_output_alpha,
2122 float2 * __restrict__ v_xy,
2223 float3 * __restrict__ v_conic,
2324 float * __restrict__ v_rgb,
@@ -45,6 +46,7 @@ __global__ void nd_rasterize_backward_kernel(
4546 int2 range = tile_bins[tile_id];
4647 // df/d_out for this pixel
4748 const float *v_out = &(v_output[channels * pix_id]);
49+ const float v_out_alpha = v_output_alpha[pix_id];
4850 // this is the T AFTER the last gaussian in this pixel
4951 float T_final = final_Ts[pix_id];
5052 float T = T_final;
@@ -97,7 +99,7 @@ __global__ void nd_rasterize_backward_kernel(
9799 // update the running sum
98100 S[c] += rgbs[channels * g + c] * fac;
99101 }
100-
102+ v_alpha += T_final * ra * v_out_alpha;
101103 // update v_opacity for this gaussian
102104 atomicAdd (&(v_opacity[g]), vis * v_alpha);
103105
@@ -146,6 +148,7 @@ __global__ void rasterize_backward_kernel(
146148 const float * __restrict__ final_Ts,
147149 const int * __restrict__ final_index,
148150 const float3 * __restrict__ v_output,
151+ const float * __restrict__ v_output_alpha,
149152 float2 * __restrict__ v_xy,
150153 float3 * __restrict__ v_conic,
151154 float3 * __restrict__ v_rgb,
@@ -188,6 +191,7 @@ __global__ void rasterize_backward_kernel(
188191
189192 // df/d_out for this pixel
190193 const float3 v_out = v_output[pix_id];
194+ const float v_out_alpha = v_output_alpha[pix_id];
191195
192196 // collect and process batches of gaussians
193197 // each thread loads one gaussian at a time before rasterizing
@@ -265,6 +269,8 @@ __global__ void rasterize_backward_kernel(
265269 v_alpha += (rgb.x * T - buffer.x * ra) * v_out.x ;
266270 v_alpha += (rgb.y * T - buffer.y * ra) * v_out.y ;
267271 v_alpha += (rgb.z * T - buffer.z * ra) * v_out.z ;
272+
273+ v_alpha += T_final * ra * v_out_alpha;
268274 // contribution from background pixel
269275 v_alpha += -T_final * ra * background.x * v_out.x ;
270276 v_alpha += -T_final * ra * background.y * v_out.y ;
0 commit comments