Skip to content

Commit bd5c5d2

Browse files
authored
Merge pull request #1601 from ziyi-zhang/main
Radiance Field Loss / Many worlds
2 parents 0af094d + 12ce409 commit bd5c5d2

File tree

9 files changed

+299
-54
lines changed

9 files changed

+299
-54
lines changed

include/neural-graphics-primitives/common.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,19 @@ using namespace tcnn;
3838

3939
namespace ngp {
4040

41+
// Training modes.
42+
// - NeRF: Standard volumetric reconstruction approach
43+
// - RFL (Radiance Field Loss): Promotes surface-like representations
44+
// - RFL-Relaxed: Hybrid approach that maintains NeRF-like volumetric properties while
45+
// encouraging surface formation, resulting in faster rendering
46+
// For technical details, see: https://rgl.epfl.ch/publications/Zhang2025Radiance
47+
enum class ETrainMode : int {
48+
Nerf,
49+
Rfl,
50+
RflRelax,
51+
};
52+
static constexpr const char* TrainModeStr = "Nerf\0Rfl\0RflRelax\0\0";
53+
4154
enum class EMeshRenderMode : int {
4255
Off,
4356
VertexColors,

include/neural-graphics-primitives/fused_kernels/render_nerf.cuh

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ __global__ void render_nerf(
5353
ENerfActivation density_activation,
5454
ENerfActivation rgb_activation,
5555
float min_transmittance,
56-
bool train_in_linear_colors
56+
bool train_in_linear_colors,
57+
bool surface_rendering,
58+
float surface_rendering_threshold
5759
) {
5860
uint32_t x = threadIdx.x + blockDim.x * blockIdx.x;
5961
uint32_t y = threadIdx.y + blockDim.y * blockIdx.y;
@@ -142,14 +144,21 @@ __global__ void render_nerf(
142144
// Composit color
143145
float alpha = 1.f - __expf(-network_to_density(nerf_out.w, density_activation) * dt);
144146
float weight = alpha * (1.0f - color.a);
145-
color += vec4(network_to_rgb_vec(nerf_out.xyz(), rgb_activation) * weight, weight);
147+
vec3 rgb = network_to_rgb_vec(nerf_out.xyz(), rgb_activation);
148+
color += vec4(rgb * weight, weight);
146149

147150
if (weight > max_weight) {
148151
max_weight = weight;
149152
best_depth_candidate = lens.is_360() ? distance(pos, cam_pos) : dot(cam_fwd, pos - cam_pos);
150153
}
151154

152-
if (color.a > (1.0f - min_transmittance)) {
155+
if (surface_rendering && alpha >= surface_rendering_threshold) {
156+
// Surface rendering: return the first surface point that has a sufficient occupancy
157+
color.rgb() = rgb;
158+
color.a = 1.0f;
159+
best_depth_candidate = lens.is_360() ? distance(pos, cam_pos) : dot(cam_fwd, pos - cam_pos);
160+
alive = false;
161+
} else if (color.a > (1.0f - min_transmittance)) {
153162
color /= color.a;
154163
alive = false;
155164
}

include/neural-graphics-primitives/fused_kernels/train_nerf.cuh

Lines changed: 80 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,11 @@ __global__ void train_nerf(
7676
const vec3* __restrict__ exposure,
7777
vec3* __restrict__ exposure_gradient,
7878
float depth_supervision_lambda,
79-
float near_distance
79+
float near_distance,
80+
81+
uint32_t training_step,
82+
ETrainMode training_mode,
83+
uint32_t rfl_warmup_steps
8084
) {
8185
const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
8286

@@ -127,6 +131,39 @@ __global__ void train_nerf(
127131
const float startt = advance_n_steps(tminmax.x, cone_angle, random_val(rng));
128132
vec3 idir = vec3(1.0f) / ray.d;
129133

134+
if (train_with_random_bg_color) {
135+
background_color = random_val_3d(rng);
136+
}
137+
138+
vec3 pre_envmap_background_color = background_color = srgb_to_linear(background_color);
139+
140+
// Composit background behind envmap
141+
vec4 envmap_value;
142+
if (envmap) {
143+
envmap_value = read_envmap(envmap, ray.d);
144+
background_color = envmap_value.rgb() + background_color * (1.0f - envmap_value.a);
145+
}
146+
147+
vec3 exposure_scale = exp(0.6931471805599453f * exposure[img]);
148+
149+
// Need rgbtarget before the first pass in RFL training mode
150+
vec3 rgbtarget;
151+
if (train_in_linear_colors || color_space == EColorSpace::Linear) {
152+
rgbtarget = exposure_scale * texsamp.rgb() + (1.0f - texsamp.a) * background_color;
153+
154+
if (!train_in_linear_colors) {
155+
rgbtarget = linear_to_srgb(rgbtarget);
156+
background_color = linear_to_srgb(background_color);
157+
}
158+
} else if (color_space == EColorSpace::SRGB) {
159+
background_color = linear_to_srgb(background_color);
160+
if (texsamp.a > 0) {
161+
rgbtarget = linear_to_srgb(exposure_scale * texsamp.rgb() / texsamp.a) * texsamp.a + (1.0f - texsamp.a) * background_color;
162+
} else {
163+
rgbtarget = background_color;
164+
}
165+
}
166+
130167
// first pass to compute an accurate number of steps
131168
uint32_t j = 0;
132169
float t = startt;
@@ -135,6 +172,7 @@ __global__ void train_nerf(
135172

136173
vec4 color = vec4(0.0f);
137174
vec3 hitpoint = vec3(0.0f);
175+
vec3 loss_bg = vec3(0.0f);
138176

139177
bool alive = valid;
140178

@@ -180,8 +218,10 @@ __global__ void train_nerf(
180218
// Composit color
181219
float alpha = 1.f - __expf(-network_to_density(nerf_out.w, density_activation) * dt);
182220
float weight = alpha * (1.0f - color.a);
183-
color += vec4(network_to_rgb_vec(nerf_out.xyz(), rgb_activation) * weight, weight);
221+
vec3 rgb = network_to_rgb_vec(nerf_out.rgb(), rgb_activation);
222+
color += vec4(rgb * weight, weight);
184223

224+
loss_bg += weight * loss_and_gradient(rgbtarget, rgb, loss_type).loss;
185225
hitpoint += weight * pos;
186226

187227
if (1.0f - color.a < EPSILON || j >= NERF_STEPS()) {
@@ -209,40 +249,9 @@ __global__ void train_nerf(
209249
numsteps_out[ray_idx*2+1] = base;
210250
}
211251

212-
if (train_with_random_bg_color) {
213-
background_color = random_val_3d(rng);
214-
}
215-
216-
vec3 pre_envmap_background_color = background_color = srgb_to_linear(background_color);
217-
218-
// Composit background behind envmap
219-
vec4 envmap_value;
220-
if (envmap) {
221-
envmap_value = read_envmap(envmap, ray.d);
222-
background_color = envmap_value.rgb() + background_color * (1.0f - envmap_value.a);
223-
}
224-
225-
vec3 exposure_scale = exp(0.6931471805599453f * exposure[img]);
226-
227-
vec3 rgbtarget;
228-
if (train_in_linear_colors || color_space == EColorSpace::Linear) {
229-
rgbtarget = exposure_scale * texsamp.rgb() + (1.0f - texsamp.a) * background_color;
230-
231-
if (!train_in_linear_colors) {
232-
rgbtarget = linear_to_srgb(rgbtarget);
233-
background_color = linear_to_srgb(background_color);
234-
}
235-
} else if (color_space == EColorSpace::SRGB) {
236-
background_color = linear_to_srgb(background_color);
237-
if (texsamp.a > 0) {
238-
rgbtarget = linear_to_srgb(exposure_scale * texsamp.rgb() / texsamp.a) * texsamp.a + (1.0f - texsamp.a) * background_color;
239-
} else {
240-
rgbtarget = background_color;
241-
}
242-
}
243-
244252
if (1.0f - color.a >= EPSILON) {
245253
color.rgb() += (1.0f - color.a) * background_color;
254+
loss_bg += (1.0f - color.a) * loss_and_gradient(rgbtarget, background_color, loss_type).loss;
246255
}
247256

248257
// Step again, this time computing loss
@@ -299,6 +308,7 @@ __global__ void train_nerf(
299308

300309
// now do it again computing gradients
301310
vec4 color2 = vec4(0.0f);
311+
vec3 loss_bg2 = vec3(0.0f);
302312
float depth2 = 0.0f;
303313
t = startt;
304314
j = 0;
@@ -366,7 +376,42 @@ __global__ void train_nerf(
366376

367377
// we know the suffix of this ray compared to where we are up to. note the suffix depends on this step's alpha as suffix = (1-alpha)*(somecolor), so dsuffix/dalpha = -somecolor = -suffix/(1-alpha)
368378
const vec3 suffix = color.rgb() - color2.rgb();
369-
const vec3 dloss_by_drgb = weight * lg.gradient;
379+
380+
float density_derivative = network_to_density_derivative(float(local_network_output[3]), density_activation);
381+
const float depth_suffix = depth - depth2;
382+
const float depth_supervision = depth_loss_gradient * (T * local_depth - depth_suffix);
383+
384+
vec3 dloss_by_drgb;
385+
float dloss_by_dmlp;
386+
if (training_mode == ETrainMode::Rfl && training_step < rfl_warmup_steps) {
387+
training_mode = ETrainMode::Nerf; // Warm up training
388+
}
389+
if (training_mode == ETrainMode::Rfl) {
390+
// Radiance field loss
391+
LossAndGradient local_lg = loss_and_gradient(rgbtarget, rgb, loss_type);
392+
loss_bg2 += weight * local_lg.loss;
393+
dloss_by_drgb = weight * local_lg.gradient;
394+
dloss_by_dmlp = density_derivative * (
395+
dt * sum(T * local_lg.loss - (loss_bg - loss_bg2) + depth_supervision)
396+
);
397+
} else if (training_mode == ETrainMode::RflRelax) {
398+
// In-between volume reconstruction and surface reconstruction.
399+
// This is different from the relaxation in the paper, but is much simpler and also promotes surfaces.
400+
const vec3 rgb_bg = suffix / fmaxf(1e-6f, T);
401+
const vec3 rgb_lerp = (1 - alpha) * rgb_bg + alpha * rgb;
402+
LossAndGradient local_lg = loss_and_gradient(rgbtarget, rgb_lerp, loss_type);
403+
404+
dloss_by_drgb = weight * local_lg.gradient;
405+
dloss_by_dmlp = density_derivative * (
406+
dt * (dot(local_lg.gradient, T * rgb - suffix) + depth_supervision)
407+
);
408+
} else {
409+
// The original NeRF loss
410+
dloss_by_drgb = weight * lg.gradient;
411+
dloss_by_dmlp = density_derivative * (
412+
dt * (dot(lg.gradient, T * rgb - suffix) + depth_supervision)
413+
);
414+
}
370415

371416
tvec<network_precision_t, 4> local_dL_doutput;
372417

@@ -375,14 +420,6 @@ __global__ void train_nerf(
375420
local_dL_doutput[1] = loss_scale * (dloss_by_drgb.y * network_to_rgb_derivative(local_network_output[1], rgb_activation) + fmaxf(0.0f, output_l2_reg * (float)local_network_output[1]));
376421
local_dL_doutput[2] = loss_scale * (dloss_by_drgb.z * network_to_rgb_derivative(local_network_output[2], rgb_activation) + fmaxf(0.0f, output_l2_reg * (float)local_network_output[2]));
377422

378-
float density_derivative = network_to_density_derivative(float(local_network_output[3]), density_activation);
379-
const float depth_suffix = depth - depth2;
380-
const float depth_supervision = depth_loss_gradient * (T * local_depth - depth_suffix);
381-
382-
float dloss_by_dmlp = density_derivative * (
383-
dt * (dot(lg.gradient, T * rgb - suffix) + depth_supervision)
384-
);
385-
386423
//static constexpr float mask_supervision_strength = 1.f; // we are already 'leaking' mask information into the nerf via the random bg colors; setting this to eg between 1 and 100 encourages density towards 0 in such regions.
387424
//dloss_by_dmlp += (texsamp.a<0.001f) ? mask_supervision_strength * weight : 0.f;
388425

include/neural-graphics-primitives/testbed.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,9 @@ class Testbed {
819819
default_rng_t density_grid_rng;
820820
int view = 0;
821821

822+
ETrainMode train_mode = ETrainMode::RflRelax;
823+
int rfl_warmup_steps = 1000;
824+
822825
float depth_supervision_lambda = 0.f;
823826

824827
GPUMemory<float> sharpness_grid;
@@ -880,6 +883,9 @@ class Testbed {
880883

881884
float cone_angle_constant = 1.f / 256.f;
882885

886+
bool surface_rendering = false;
887+
float surface_rendering_threshold = 0.5f;
888+
883889
bool visualize_cameras = false;
884890

885891
float render_min_transmittance = 0.01f;

scripts/run.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ def parse_args():
4040
parser.add_argument("--test_transforms", default="", help="Path to a nerf style transforms json from which we will compute PSNR.")
4141
parser.add_argument("--near_distance", default=-1, type=float, help="Set the distance from the camera at which training rays start for nerf. <0 means use ngp default")
4242
parser.add_argument("--exposure", default=0.0, type=float, help="Controls the brightness of the image. Positive numbers increase brightness, negative numbers decrease it.")
43+
parser.add_argument("--train_mode", default="", type=str, help="The training mode to use. Can be 'nerf', 'rfl', 'rfl_relax'. If not specified, the default mode will be used.")
44+
parser.add_argument("--rfl_warmup_steps", type=int, default=1000, help="Number of steps to train in NeRF mode before switching to RFL mode. Default is 1000. Only used if --train_mode is set to 'rfl'.")
45+
parser.add_argument("--no_rflrelax_training_schedule", action="store_true", help="Disable RFL training schedule for RflRelax mode (active between steps 15k-30k).")
4346

4447
parser.add_argument("--screenshot_transforms", default="", help="Path to a nerf style transforms.json from which to save screenshots.")
4548
parser.add_argument("--screenshot_frames", nargs="*", help="Which frame(s) to take screenshots of.")
@@ -146,6 +149,17 @@ def get_scene(scene):
146149
print("NeRF training ray near_distance ", args.near_distance)
147150
testbed.nerf.training.near_distance = args.near_distance
148151

152+
if args.train_mode:
153+
if args.train_mode.lower() == "nerf":
154+
testbed.nerf.training.train_mode = ngp.TrainMode.Nerf
155+
elif args.train_mode.lower() == "rfl":
156+
testbed.nerf.training.train_mode = ngp.TrainMode.Rfl
157+
elif args.train_mode.lower() == "rfl_relax" or args.train_mode.lower() == "rflrelax":
158+
testbed.nerf.training.train_mode = ngp.TrainMode.RflRelax
159+
else:
160+
raise ValueError(f"Unknown train mode: {args.train_mode}")
161+
testbed.nerf.training.rfl_warmup_steps = args.rfl_warmup_steps
162+
149163
if args.nerf_compatibility:
150164
print(f"NeRF compatibility mode enabled")
151165

@@ -167,6 +181,9 @@ def get_scene(scene):
167181
# Match nerf paper behaviour and train on a fixed bg.
168182
testbed.nerf.training.random_bg_color = False
169183

184+
# Ensure that the training mode is set to NeRF.
185+
testbed.nerf.training.train_mode = ngp.TrainMode.Nerf
186+
170187
old_training_step = 0
171188
n_steps = args.n_steps
172189

@@ -176,6 +193,7 @@ def get_scene(scene):
176193
if n_steps < 0 and (not args.load_snapshot or args.gui):
177194
n_steps = 35000
178195

196+
original_train_mode = ngp.TrainMode(testbed.nerf.training.train_mode)
179197
tqdm_last_update = 0
180198
if n_steps > 0:
181199
with tqdm(desc="Training", total=n_steps, unit="steps") as t:
@@ -194,6 +212,16 @@ def get_scene(scene):
194212
old_training_step = 0
195213
t.reset()
196214

215+
# Rfl-relax training schedule
216+
progress_fraction = float(testbed.training_step) / n_steps
217+
if (original_train_mode == ngp.TrainMode.RflRelax and
218+
not args.no_rflrelax_training_schedule):
219+
# By default only enable RflRelax mode between 15k and 30k steps
220+
if 3/7 <= progress_fraction < 6/7:
221+
testbed.nerf.training.train_mode = ngp.TrainMode.RflRelax
222+
else:
223+
testbed.nerf.training.train_mode = ngp.TrainMode.Nerf
224+
197225
now = time.monotonic()
198226
if now - tqdm_last_update > 0.1:
199227
t.update(testbed.training_step - old_training_step)

src/nerf_loader.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,10 @@ NerfDataset load_nerf(const std::vector<fs::path>& jsonpaths, float sharpen_amou
441441
result.from_mitsuba = true;
442442
}
443443

444+
if (json.contains("from_mitsuba")) {
445+
result.from_mitsuba = bool(json["from_mitsuba"]);
446+
}
447+
444448
if (json.contains("fix_premult")) {
445449
fix_premult = (bool)json["fix_premult"];
446450
}

src/python_api.cu

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,12 @@ PYBIND11_MODULE(pyngp, m) {
308308

309309
m.def("free_temporary_memory", &free_all_gpu_memory_arenas);
310310

311+
py::enum_<ETrainMode>(m, "TrainMode")
312+
.value("Nerf", ETrainMode::Nerf)
313+
.value("Rfl", ETrainMode::Rfl)
314+
.value("RflRelax", ETrainMode::RflRelax)
315+
.export_values();
316+
311317
py::enum_<ETestbedMode>(m, "TestbedMode")
312318
.value("Nerf", ETestbedMode::Nerf)
313319
.value("Sdf", ETestbedMode::Sdf)
@@ -793,6 +799,8 @@ PYBIND11_MODULE(pyngp, m) {
793799
//.def_readonly("focal_lengths", &Testbed::Nerf::Training::focal_lengths) // use training.dataset.metadata instead
794800
.def_readwrite("near_distance", &Testbed::Nerf::Training::near_distance)
795801
.def_readwrite("density_grid_decay", &Testbed::Nerf::Training::density_grid_decay)
802+
.def_readwrite("train_mode", &Testbed::Nerf::Training::train_mode)
803+
.def_readwrite("rfl_warmup_steps", &Testbed::Nerf::Training::rfl_warmup_steps)
796804
.def_readwrite("extrinsic_l2_reg", &Testbed::Nerf::Training::extrinsic_l2_reg)
797805
.def_readwrite("extrinsic_learning_rate", &Testbed::Nerf::Training::extrinsic_learning_rate)
798806
.def_readwrite("intrinsic_l2_reg", &Testbed::Nerf::Training::intrinsic_l2_reg)

0 commit comments

Comments
 (0)