diff --git a/gradio_app.py b/gradio_app.py
index c172c540..6b9d3fd7 100644
--- a/gradio_app.py
+++ b/gradio_app.py
@@ -26,6 +26,7 @@
 ### training options
 parser.add_argument('--iters', type=int, default=10000, help="training iters")
 parser.add_argument('--lr', type=float, default=1e-3, help="initial learning rate")
+parser.add_argument('--lr2', type=float, default=1e-3, help="initial learning rate for the second (MLP) parameter group of the tensoRF/dnerf backbones")
 parser.add_argument('--ckpt', type=str, default='latest')
 parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
 parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)")
@@ -166,11 +167,20 @@ def submit(text, iters, seed):
     if opt.optim == 'adan':
         from optimizer import Adan
         # Adan usually requires a larger LR
-        optimizer = lambda model: Adan(model.get_params(5 * opt.lr), eps=1e-15)
+        if opt.backbone == 'tensoRF' or opt.backbone == 'dnerf':
+            optimizer = lambda model: Adan(model.get_params(5 * opt.lr, 5 * opt.lr2), eps=1e-15)
+        else:
+            optimizer = lambda model: Adan(model.get_params(5 * opt.lr), eps=1e-15)
     elif opt.optim == 'adamw':
-        optimizer = lambda model: torch.optim.AdamW(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
+        if opt.backbone == 'tensoRF' or opt.backbone == 'dnerf':
+            optimizer = lambda model: torch.optim.AdamW(model.get_params(opt.lr, opt.lr2), betas=(0.9, 0.99), eps=1e-15)
+        else:
+            optimizer = lambda model: torch.optim.AdamW(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
     else: # adam
-        optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
+        if opt.backbone == 'tensoRF' or opt.backbone == 'dnerf':
+            optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr, opt.lr2), betas=(0.9, 0.99), eps=1e-15)
+        else:
+            optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)

     scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 1) # fixed
diff --git a/main.py b/main.py
index 9b93b20e..bd644ade 100644
--- a/main.py
+++ b/main.py
@@ -29,6 +29,7 @@
     ### training options
     parser.add_argument('--iters', type=int, default=10000, help="training iters")
     parser.add_argument('--lr', type=float, default=1e-3, help="max learning rate")
+    parser.add_argument('--lr2', type=float, default=1e-3, help="max learning rate for the second (MLP) parameter group of the tensoRF/dnerf backbones")
     parser.add_argument('--warm_iters', type=int, default=500, help="training iters")
     parser.add_argument('--min_lr', type=float, default=1e-4, help="minimal learning rate")
     parser.add_argument('--ckpt', type=str, default='latest')
@@ -49,7 +50,7 @@
     parser.add_argument('--blob_density', type=float, default=10, help="max (center) density for the density blob")
     parser.add_argument('--blob_radius', type=float, default=0.5, help="control the radius for the density blob")
     # network backbone
-    parser.add_argument('--backbone', type=str, default='grid', choices=['grid', 'vanilla', 'grid_taichi'], help="nerf backbone")
+    parser.add_argument('--backbone', type=str, default='grid', choices=['grid', 'vanilla', 'grid_taichi', 'tensoRF', 'dnerf'], help="nerf backbone")
     parser.add_argument('--optim', type=str, default='adan', choices=['adan', 'adam'], help="optimizer")
     parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
     parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key")
@@ -103,6 +104,11 @@
         from nerf.network import NeRFNetwork
     elif opt.backbone == 'grid':
         from nerf.network_grid import NeRFNetwork
+    elif opt.backbone == 'tensoRF':
+        from nerf.network_tensorf import NeRFNetwork
+    elif opt.backbone == 'dnerf':
+        opt.cuda_ray = False
+        from nerf.network_dnerf import NeRFNetwork
     elif opt.backbone == 'grid_taichi':
         opt.cuda_ray = False
         opt.taichi_ray = True
@@ -151,9 +157,15 @@
     if opt.optim == 'adan':
         from optimizer import Adan
         # Adan usually requires a larger LR
-        optimizer = lambda model: Adan(model.get_params(5 * opt.lr), eps=1e-8, weight_decay=2e-5, max_grad_norm=5.0, foreach=False)
+        if opt.backbone == 'tensoRF' or opt.backbone == 'dnerf':
+            optimizer = lambda model: Adan(model.get_params(5 * opt.lr, 5 * opt.lr2), eps=1e-8, weight_decay=2e-5, max_grad_norm=5.0, foreach=False)
+        else:
+            optimizer = lambda model: Adan(model.get_params(5 * opt.lr), eps=1e-8, weight_decay=2e-5, max_grad_norm=5.0, foreach=False)
     else: # adam
-        optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
+        if opt.backbone == 'tensoRF' or opt.backbone == 'dnerf':
+            optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr, opt.lr2), betas=(0.9, 0.99), eps=1e-15)
+        else:
+            optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)

     if opt.backbone == 'vanilla':
         warm_up_with_cosine_lr = lambda iter: iter / opt.warm_iters if iter <= opt.warm_iters \
diff --git a/nerf/network.py b/nerf/network.py
index 5676fe08..567b164f 100644
--- a/nerf/network.py
+++ b/nerf/network.py
@@ -217,7 +217,7 @@ def density(self, x):
         }

-    def background(self, d):
+    def background(self, d, x):

         h = self.encoder_bg(d) # [N, C]
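Reviewer note: the new backbones split their parameters into two groups (decomposition/encoder tensors vs. MLP heads), which is why `get_params` now takes a second learning rate fed from `--lr2`. A minimal, self-contained sketch of the mechanism; the two `nn.Linear` stand-ins are hypothetical, not modules from this repo:

```python
import torch
import torch.nn as nn

# get_params() returns torch.optim parameter groups, each carrying its own
# 'lr'; the optimizer then steps every group at its group-specific rate.
encoder = nn.Linear(3, 32)   # stands in for a grid/tensor encoder (usually higher LR)
mlp = nn.Linear(32, 4)       # stands in for the decoder MLP (usually lower LR)

def get_params(lr, lr_net):
    return [
        {'params': encoder.parameters(), 'lr': lr},
        {'params': mlp.parameters(), 'lr': lr_net},
    ]

optimizer = torch.optim.Adam(get_params(1e-2, 1e-3), betas=(0.9, 0.99), eps=1e-15)
```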
diff --git a/nerf/network_dnerf.py b/nerf/network_dnerf.py
new file mode 100644
index 00000000..7f14c3a1
--- /dev/null
+++ b/nerf/network_dnerf.py
@@ -0,0 +1,308 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from activation import trunc_exp
+from .renderer import NeRFRenderer
+
+from encoding import get_encoder
+
+class MLP(nn.Module):
+    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
+        super().__init__()
+        self.dim_in = dim_in
+        self.dim_out = dim_out
+        self.dim_hidden = dim_hidden
+        self.num_layers = num_layers
+
+        net = []
+        for l in range(num_layers):
+            net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))
+
+        self.net = nn.ModuleList(net)
+
+    def forward(self, x):
+        for l in range(self.num_layers):
+            x = self.net[l](x)
+            if l != self.num_layers - 1:
+                x = F.relu(x, inplace=True)
+        return x
+
+class NeRFNetwork(NeRFRenderer):
+    def __init__(self,
+                 opt,
+                 encoding="tiledgrid",
+                 encoding_dir="sphere_harmonics",
+                 encoding_time="frequency",
+                 encoding_deform="frequency", # "hashgrid" seems worse
+                 encoding_bg="hashgrid",
+                 num_layers=2,
+                 hidden_dim=64,
+                 geo_feat_dim=15,
+                 num_layers_color=3,
+                 hidden_dim_color=64,
+                 num_layers_bg=2,
+                 hidden_dim_bg=64,
+                 num_layers_deform=5, # a deeper MLP is necessary for good performance
+                 hidden_dim_deform=128
+                 ):
+
+        super().__init__(opt)
+        print('dnerf')
+        self.bound = opt.bound
+
+        # deformation network
+        self.num_layers_deform = num_layers_deform
+        self.hidden_dim_deform = hidden_dim_deform
+        self.encoder_deform, self.in_dim_deform = get_encoder(encoding_deform, multires=10)
+        self.encoder_time, self.in_dim_time = get_encoder(encoding_time, input_dim=1, multires=6)
+        self.density_scale = 1
+        # time stamps for density grid
+        self.time_size = 64
+        self.times = ((torch.arange(self.time_size, dtype=torch.float32) + 0.5) / self.time_size).view(-1, 1, 1) # [T, 1, 1]
+        self.density_grid = torch.zeros(self.time_size, self.cascade, self.grid_size ** 3) # [T, CAS, H * H * H]
+
+        deform_net = []
+        for l in range(num_layers_deform):
+            if l == 0:
+                in_dim = self.in_dim_deform + self.in_dim_time # grid dim + time
+            else:
+                in_dim = hidden_dim_deform
+
+            if l == num_layers_deform - 1:
+                out_dim = 3 # deformation for xyz
+            else:
+                out_dim = hidden_dim_deform
+
+            deform_net.append(nn.Linear(in_dim, out_dim, bias=False))
+
+        self.deform_net = nn.ModuleList(deform_net)
+
+        # sigma network
+        self.num_layers = num_layers
+        self.hidden_dim = hidden_dim
+        self.geo_feat_dim = geo_feat_dim
+        self.encoder, self.in_dim = get_encoder(encoding, desired_resolution=2048 * self.bound)
+
+        sigma_net = []
+        for l in range(num_layers):
+            if l == 0:
+                in_dim = self.in_dim + self.in_dim_time + self.in_dim_deform # concat everything
+            else:
+                in_dim = hidden_dim
+
+            if l == num_layers - 1:
+                out_dim = 1 + self.geo_feat_dim # 1 sigma + features for color
+            else:
+                out_dim = hidden_dim
+
+            sigma_net.append(nn.Linear(in_dim, out_dim, bias=False))
+
+        self.sigma_net = nn.ModuleList(sigma_net)
+
+        # color network
+        self.num_layers_color = num_layers_color
+        self.hidden_dim_color = hidden_dim_color
+        self.encoder_dir, self.in_dim_dir = get_encoder(encoding_dir)
+
+        color_net = []
+        for l in range(num_layers_color):
+            if l == 0:
+                in_dim = self.in_dim_dir + self.geo_feat_dim
+            else:
+                in_dim = hidden_dim_color
+
+            if l == num_layers_color - 1:
+                out_dim = 3 # 3 rgb
+            else:
+                out_dim = hidden_dim_color
+
+            color_net.append(nn.Linear(in_dim, out_dim, bias=False))
+
+        self.color_net = nn.ModuleList(color_net)
+
+        # background network
+        if self.bg_radius > 0:
+            self.num_layers_bg = num_layers_bg
+            self.hidden_dim_bg = hidden_dim_bg
+            self.encoder_bg, self.in_dim_bg = get_encoder(encoding_bg, input_dim=2, num_levels=4, log2_hashmap_size=19, desired_resolution=2048) # much smaller hashgrid
+
+            bg_net = []
+            for l in range(num_layers_bg):
+                if l == 0:
+                    in_dim = self.in_dim_bg + self.in_dim_dir
+                else:
+                    in_dim = hidden_dim_bg
+
+                if l == num_layers_bg - 1:
+                    out_dim = 3 # 3 rgb
+                else:
+                    out_dim = hidden_dim_bg
+
+                bg_net.append(nn.Linear(in_dim, out_dim, bias=False))
+
+            self.bg_net = nn.ModuleList(bg_net) # background() below indexes bg_net[l], so it must stay a ModuleList
+        else:
+            self.bg_net = None
+
+        print('dnerf init done')
+
+    def forward(self, x, d, t=None, l=None, ratio=1, shading='albedo'):
+        print('forward 1')
+        # x: [N, 3], in [-bound, bound]
+        # d: [N, 3], normalized in [-1, 1]
+        # t: [1, 1], in [0, 1]
+
+        # deform
+        enc_ori_x = self.encoder_deform(x, bound=self.bound) # [N, C]
+        enc_t = self.encoder_time(t) # [1, 1] --> [1, C']
+        if enc_t.shape[0] == 1:
+            enc_t = enc_t.repeat(x.shape[0], 1) # [1, C'] --> [N, C']
+
+        deform = torch.cat([enc_ori_x, enc_t], dim=1) # [N, C + C']
+        for l in range(self.num_layers_deform):
+            deform = self.deform_net[l](deform)
+            if l != self.num_layers_deform - 1:
+                deform = F.relu(deform, inplace=True)
+
+        x = x + deform
+
+        # sigma
+        x = self.encoder(x, bound=self.bound)
+        h = torch.cat([x, enc_ori_x, enc_t], dim=1)
+        for l in range(self.num_layers):
+            h = self.sigma_net[l](h)
+            if l != self.num_layers - 1:
+                h = F.relu(h, inplace=True)
+
+        #sigma = F.relu(h[..., 0])
+        sigma = trunc_exp(h[..., 0])
+        geo_feat = h[..., 1:]
+
+        # color
+        d = self.encoder_dir(d)
+        h = torch.cat([d, geo_feat], dim=-1)
+        for l in range(self.num_layers_color):
+            h = self.color_net[l](h)
+            if l != self.num_layers_color - 1:
+                h = F.relu(h, inplace=True)
+
+        # sigmoid activation for rgb
+        rgbs = torch.sigmoid(h)
+
+        print('forward return')
+        return sigma, rgbs, deform
+
+    def density(self, x, t=None):
+        print('density 1')
+        # x: [N, 3], in [-bound, bound]
+        # t: [1, 1], in [0, 1]
+
+        results = {}
+
+        # deformation
+        enc_ori_x = self.encoder_deform(x, bound=self.bound) # [N, C]
+        enc_t = self.encoder_time(t) # [1, 1] --> [1, C']
+        if enc_t.shape[0] == 1:
+            enc_t = enc_t.repeat(x.shape[0], 1) # [1, C'] --> [N, C']
+
+        deform = torch.cat([enc_ori_x, enc_t], dim=1) # [N, C + C']
+        for l in range(self.num_layers_deform):
+            deform = self.deform_net[l](deform)
+            if l != self.num_layers_deform - 1:
+                deform = F.relu(deform, inplace=True)
+
+        x = x + deform
+        results['deform'] = deform
+
+        # sigma
+        x = self.encoder(x, bound=self.bound)
+        h = torch.cat([x, enc_ori_x, enc_t], dim=1)
+        for l in range(self.num_layers):
+            h = self.sigma_net[l](h)
+            if l != self.num_layers - 1:
+                h = F.relu(h, inplace=True)
+
+        #sigma = F.relu(h[..., 0])
+        sigma = trunc_exp(h[..., 0])
+        geo_feat = h[..., 1:]
+
+        results['sigma'] = sigma
+        results['geo_feat'] = geo_feat
+
+        print('density return')
+        return results
+
+    def background(self, d, x):
+        print('background 1')
+        # x: [N, 2], in [-1, 1]
+
+        h = self.encoder_bg(x) # [N, C]
+        d = self.encoder_dir(d)
+
+        h = torch.cat([d, h], dim=-1)
+        for l in range(self.num_layers_bg):
+            h = self.bg_net[l](h)
+            if l != self.num_layers_bg - 1:
+                h = F.relu(h, inplace=True)
+
+        # sigmoid activation for rgb
+        rgbs = torch.sigmoid(h)
+
+        print('background return')
+        return rgbs
+
+    # allow masked inference
+    def color(self, x, d, mask=None, geo_feat=None, **kwargs):
+        print('color 1')
+        # x: [N, 3] in [-bound, bound]
+        # t: [1, 1], in [0, 1]
+        # mask: [N,], bool, indicates where we actually need to compute rgb.
+
+        if mask is not None:
+            rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3]
+            # in case of empty mask
+            if not mask.any():
+                return rgbs
+            x = x[mask]
+            d = d[mask]
+            geo_feat = geo_feat[mask]
+
+        d = self.encoder_dir(d)
+        h = torch.cat([d, geo_feat], dim=-1)
+        for l in range(self.num_layers_color):
+            h = self.color_net[l](h)
+            if l != self.num_layers_color - 1:
+                h = F.relu(h, inplace=True)
+
+        # sigmoid activation for rgb
+        h = torch.sigmoid(h)
+
+        if mask is not None:
+            rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32
+        else:
+            rgbs = h
+
+        print('color return')
+        return rgbs
+
+    # optimizer utils
+    def get_params(self, lr, lr_net):
+        print('get_params 1')
+
+        params = [
+            {'params': self.encoder.parameters(), 'lr': lr},
+            {'params': self.sigma_net.parameters(), 'lr': lr_net},
+            {'params': self.encoder_dir.parameters(), 'lr': lr},
+            {'params': self.color_net.parameters(), 'lr': lr_net},
+            {'params': self.encoder_deform.parameters(), 'lr': lr},
+            {'params': self.encoder_time.parameters(), 'lr': lr},
+            {'params': self.deform_net.parameters(), 'lr': lr_net},
+        ]
+        if self.bg_radius > 0:
+            params.append({'params': self.encoder_bg.parameters(), 'lr': lr})
+            params.append({'params': self.bg_net.parameters(), 'lr': lr_net})
+
+        print('get_params return')
+        return params
\ No newline at end of file
diff --git a/nerf/network_grid.py b/nerf/network_grid.py
index ac34bde1..6fe0d5d0 100644
--- a/nerf/network_grid.py
+++ b/nerf/network_grid.py
@@ -147,7 +147,7 @@ def density(self, x):
         }

-    def background(self, d):
+    def background(self, d, x):

         h = self.encoder_bg(d) # [N, C]
diff --git a/nerf/network_grid_taichi.py b/nerf/network_grid_taichi.py
index 0a9afe58..edd94ab6 100644
--- a/nerf/network_grid_taichi.py
+++ b/nerf/network_grid_taichi.py
@@ -145,7 +145,7 @@ def density(self, x):
         }

-    def background(self, d):
+    def background(self, d, x):

         h = self.encoder_bg(d) # [N, C]
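Reviewer note: the new `nerf/network_tensorf.py` below factorizes the density and color fields with TensoRF-style vector-matrix (VM) decomposition, sampling planes and lines via `F.grid_sample`. A minimal runnable sketch of a single rank-R sigma component; shapes and values are illustrative, not the file's exact code:

```python
import torch
import torch.nn.functional as F

# One VM component: a [H, W] plane over (x, y) and a [D] line over z are
# sampled at N query points; their product, summed over ranks, approximates
# the 3D density feature at those points.
N, R = 1024, 16
plane = torch.randn(1, R, 128, 128)           # [1, R, H, W]
line = torch.randn(1, R, 128, 1)              # [1, R, D, 1], fake 2D so grid_sample works
xyz = torch.rand(N, 3) * 2 - 1                # query points in [-1, 1]

plane_coord = xyz[:, [0, 1]].view(1, -1, 1, 2)                                # sample at (x, y)
line_coord = torch.stack((torch.zeros(N), xyz[:, 2]), dim=-1).view(1, -1, 1, 2)  # (0, z)

plane_feat = F.grid_sample(plane, plane_coord, align_corners=True).view(R, N)
line_feat = F.grid_sample(line, line_coord, align_corners=True).view(R, N)
sigma_feat = (plane_feat * line_feat).sum(0)  # [N], summed over the R ranks
```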
diff --git a/nerf/network_tensorf.py b/nerf/network_tensorf.py
new file mode 100644
index 00000000..9b9af822
--- /dev/null
+++ b/nerf/network_tensorf.py
@@ -0,0 +1,358 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from activation import trunc_exp
+from .renderer import NeRFRenderer
+
+import numpy as np
+from encoding import get_encoder
+import raymarching
+
+class MLP(nn.Module):
+    def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
+        super().__init__()
+        self.dim_in = dim_in
+        self.dim_out = dim_out
+        self.dim_hidden = dim_hidden
+        self.num_layers = num_layers
+
+        net = []
+        for l in range(num_layers):
+            net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))
+
+        self.net = nn.ModuleList(net)
+
+    def forward(self, x):
+        for l in range(self.num_layers):
+            x = self.net[l](x)
+            if l != self.num_layers - 1:
+                x = F.relu(x, inplace=True)
+        return x
+
+class NeRFNetwork(NeRFRenderer):
+    def __init__(self,
+                 opt,
+                 resolution=[128] * 3,
+                 sigma_rank=[16] * 3,
+                 color_rank=[48] * 3,
+                 bg_resolution=[512, 512],
+                 bg_rank=8,
+                 color_feat_dim=27,
+                 num_layers=3,
+                 hidden_dim=128,
+                 num_layers_bg=2,
+                 hidden_dim_bg=64,
+                 ):
+
+        super().__init__(opt)
+        self.resolution = resolution
+        self.bound = opt.bound
+
+        # vector-matrix decomposition
+        self.sigma_rank = sigma_rank
+        self.color_rank = color_rank
+        self.color_feat_dim = color_feat_dim
+
+        self.mat_ids = [[0, 1], [0, 2], [1, 2]]
+        self.vec_ids = [2, 1, 0]
+
+        self.sigma_mat, self.sigma_vec = self.init_one_svd(self.sigma_rank, self.resolution)
+        self.color_mat, self.color_vec = self.init_one_svd(self.color_rank, self.resolution)
+        self.basis_mat = nn.Linear(sum(self.color_rank), self.color_feat_dim, bias=False)
+
+        # render module (default to freq feat + freq dir)
+        self.num_layers = num_layers
+        self.hidden_dim = hidden_dim
+
+        self.encoder, enc_dim = get_encoder('frequency', input_dim=color_feat_dim, multires=2)
+        self.encoder_dir, enc_dim_dir = get_encoder('frequency', input_dim=3, multires=2)
+
+        self.in_dim = enc_dim + enc_dim_dir
+
+        color_net = []
+        for l in range(num_layers):
+            if l == 0:
+                in_dim = self.in_dim
+            else:
+                in_dim = self.hidden_dim
+
+            if l == num_layers - 1:
+                out_dim = 3 # rgb
+            else:
+                out_dim = self.hidden_dim
+
+            color_net.append(nn.Linear(in_dim, out_dim, bias=False))
+
+        self.color_net = nn.ModuleList(color_net)
+
+        # background model
+        if self.bg_radius > 0:
+            self.num_layers_bg = num_layers_bg
+            self.hidden_dim_bg = hidden_dim_bg
+
+            # TODO: just use a matrix to model the background, no need of factorization.
+            #self.encoder_bg, self.in_dim_bg = get_encoder('hashgrid', input_dim=2, num_levels=4, log2_hashmap_size=18) # much smaller hashgrid
+            self.bg_resolution = bg_resolution
+            self.bg_rank = bg_rank
+            self.bg_mat = nn.Parameter(0.1 * torch.randn((1, bg_rank, bg_resolution[0], bg_resolution[1]))) # [1, R, H, W]
+
+            bg_net = []
+            for l in range(num_layers_bg):
+                if l == 0:
+                    in_dim = bg_rank + enc_dim_dir
+                else:
+                    in_dim = hidden_dim_bg
+
+                if l == num_layers_bg - 1:
+                    out_dim = 3 # 3 rgb
+                else:
+                    out_dim = hidden_dim_bg
+
+                bg_net.append(nn.Linear(in_dim, out_dim, bias=False))
+
+            self.bg_net = MLP(enc_dim_dir, 3, hidden_dim_bg, num_layers_bg, bias=True) #nn.ModuleList(bg_net)
+        else:
+            self.bg_net = None
+
+    def init_one_svd(self, n_component, resolution, scale=0.1):
+
+        mat, vec = [], []
+
+        for i in range(len(self.vec_ids)):
+            vec_id = self.vec_ids[i]
+            mat_id_0, mat_id_1 = self.mat_ids[i]
+            mat.append(nn.Parameter(scale * torch.randn((1, n_component[i], resolution[mat_id_1], resolution[mat_id_0])))) # [1, R, H, W]
+            vec.append(nn.Parameter(scale * torch.randn((1, n_component[i], resolution[vec_id], 1)))) # [1, R, D, 1] (fake 2d to use grid_sample)
+
+        return nn.ParameterList(mat), nn.ParameterList(vec)
+
+    def get_sigma_feat(self, x):
+        # x: [N, 3], in [-1, 1] (outliers will be treated as zero due to grid_sample padding mode)
+
+        N = x.shape[0]
+
+        # plane + line basis
+        mat_coord = torch.stack((x[..., self.mat_ids[0]], x[..., self.mat_ids[1]], x[..., self.mat_ids[2]])).view(3, -1, 1, 2) # [3, N, 1, 2]
+        vec_coord = torch.stack((x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]]))
+        vec_coord = torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1).view(3, -1, 1, 2) # [3, N, 1, 2], fake 2d coord
+
+        sigma_feat = torch.zeros([N,], device=x.device)
+
+        for i in range(len(self.sigma_mat)):
+            mat_feat = F.grid_sample(self.sigma_mat[i], mat_coord[[i]], align_corners=True).view(-1, N) # [1, R, N, 1] --> [R, N]
+            vec_feat = F.grid_sample(self.sigma_vec[i], vec_coord[[i]], align_corners=True).view(-1, N) # [R, N]
+            sigma_feat = sigma_feat + torch.sum(mat_feat * vec_feat, dim=0)
+
+        return sigma_feat
+
+    def get_color_feat(self, x):
+        # x: [N, 3], in [-1, 1]
+
+        N = x.shape[0]
+
+        # plane + line basis
+        mat_coord = torch.stack((x[..., self.mat_ids[0]], x[..., self.mat_ids[1]], x[..., self.mat_ids[2]])).view(3, -1, 1, 2) # [3, N, 1, 2]
+        vec_coord = torch.stack((x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]]))
+        vec_coord = torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1).view(3, -1, 1, 2) # [3, N, 1, 2], fake 2d coord
+
+        mat_feat, vec_feat = [], []
+
+        for i in range(len(self.color_mat)):
+            mat_feat.append(F.grid_sample(self.color_mat[i], mat_coord[[i]], align_corners=True).view(-1, N)) # [1, R, N, 1] --> [R, N]
+            vec_feat.append(F.grid_sample(self.color_vec[i], vec_coord[[i]], align_corners=True).view(-1, N)) # [R, N]
+
+        mat_feat = torch.cat(mat_feat, dim=0) # [3 * R, N]
+        vec_feat = torch.cat(vec_feat, dim=0) # [3 * R, N]
+
+        color_feat = self.basis_mat((mat_feat * vec_feat).T) # [N, 3R] --> [N, color_feat_dim]
+
+        return color_feat
+
+    def forward(self, x, d, l=None, ratio=1, shading='albedo'):
+        # x: [N, 3], in [-bound, bound]
+        # d: [N, 3], normalized in [-1, 1]
+
+        # normalize to [-1, 1] inside aabb_train
+        x = 2 * (x - self.aabb_train[:3]) / (self.aabb_train[3:] - self.aabb_train[:3]) - 1
+
+        # sigma
+        sigma_feat = self.get_sigma_feat(x)
+        sigma = trunc_exp(sigma_feat)
+        #sigma = F.softplus(sigma_feat - 3)
+        #sigma = F.relu(sigma_feat)
+
+        # rgb
+        color_feat = self.get_color_feat(x)
+        enc_color_feat = self.encoder(color_feat)
+        enc_d = self.encoder_dir(d)
+
+        h = torch.cat([enc_color_feat, enc_d], dim=-1)
+        for l in range(self.num_layers):
+            h = self.color_net[l](h)
+            if l != self.num_layers - 1:
+                h = F.relu(h, inplace=True)
+
+        # sigmoid activation for rgb
+        rgb = torch.sigmoid(h)
+
+        return sigma, rgb, None
+
+    def density(self, x):
+        # x: [N, 3], in [-bound, bound]
+
+        # normalize to [-1, 1] inside aabb_train
+        x = 2 * (x - self.aabb_train[:3]) / (self.aabb_train[3:] - self.aabb_train[:3]) - 1
+
+        sigma_feat = self.get_sigma_feat(x)
+        sigma = trunc_exp(sigma_feat)
+        #sigma = F.softplus(sigma_feat - 3)
+        #sigma = F.relu(sigma_feat)
+
+        return {
+            'sigma': sigma,
+        }
+
+    def background(self, d, x):
+        # x: [N, 2] in [-1, 1]
+
+        #N = x.shape[0]
+
+        #h = F.grid_sample(self.bg_mat, x.view(1, N, 1, 2), align_corners=True).view(-1, N).T.contiguous() # [R, N] --> [N, R]
+        #d = self.encoder_dir(d)
+
+        #h = torch.cat([d, h], dim=-1)
+        #for l in range(self.num_layers_bg):
+        #    h = self.bg_net[l](h)
+        #    if l != self.num_layers_bg - 1:
+        #        h = F.relu(h, inplace=True)
+
+        h = self.encoder_dir(d)
+        h = self.bg_net(h)
+
+        # sigmoid activation for rgb
+        rgbs = torch.sigmoid(h)
+
+        return rgbs
+
+    # allow masked inference
+    def color(self, x, d, mask=None, **kwargs):
+        # x: [N, 3] in [-bound, bound]
+        # mask: [N,], bool, indicates where we actually need to compute rgb.
+
+        # normalize to [-1, 1] inside aabb_train
+        x = 2 * (x - self.aabb_train[:3]) / (self.aabb_train[3:] - self.aabb_train[:3]) - 1
+
+        if mask is not None:
+            rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3]
+            # in case of empty mask
+            if not mask.any():
+                return rgbs
+            x = x[mask]
+            d = d[mask]
+
+        color_feat = self.get_color_feat(x)
+        color_feat = self.encoder(color_feat)
+        d = self.encoder_dir(d)
+
+        h = torch.cat([color_feat, d], dim=-1)
+        for l in range(self.num_layers):
+            h = self.color_net[l](h)
+            if l != self.num_layers - 1:
+                h = F.relu(h, inplace=True)
+
+        # sigmoid activation for rgb
+        h = torch.sigmoid(h)
+
+        if mask is not None:
+            rgbs[mask] = h.to(rgbs.dtype)
+        else:
+            rgbs = h
+
+        return rgbs
+
+    # L1 penalty for loss
+    def density_loss(self):
+        loss = 0
+        for i in range(len(self.sigma_mat)):
+            loss = loss + torch.mean(torch.abs(self.sigma_mat[i])) + torch.mean(torch.abs(self.sigma_vec[i]))
+        return loss
+
+    # upsample utils
+    @torch.no_grad()
+    def upsample_params(self, mat, vec, resolution):
+
+        for i in range(len(self.vec_ids)):
+            vec_id = self.vec_ids[i]
+            mat_id_0, mat_id_1 = self.mat_ids[i]
+            mat[i] = nn.Parameter(F.interpolate(mat[i].data, size=(resolution[mat_id_1], resolution[mat_id_0]), mode='bilinear', align_corners=True))
+            vec[i] = nn.Parameter(F.interpolate(vec[i].data, size=(resolution[vec_id], 1), mode='bilinear', align_corners=True))
+
+    @torch.no_grad()
+    def upsample_model(self, resolution):
+        self.upsample_params(self.sigma_mat, self.sigma_vec, resolution)
+        self.upsample_params(self.color_mat, self.color_vec, resolution)
+        self.resolution = resolution
+
+    @torch.no_grad()
+    def shrink_model(self):
+        # shrink aabb_train and the model so it only represents the space inside aabb_train.
+
+        half_grid_size = self.bound / self.grid_size
+        thresh = min(self.density_thresh, self.mean_density)
+
+        # get new aabb from the coarsest density grid (TODO: from the finest that covers current aabb?)
+        valid_grid = self.density_grid[self.cascade - 1] > thresh # [N]
+        valid_pos = raymarching.morton3D_invert(torch.nonzero(valid_grid)) # [Nz] --> [Nz, 3], in [0, H - 1]
+        #plot_pointcloud(valid_pos.detach().cpu().numpy()) # lots of noisy outliers in hashnerf...
+        valid_pos = (2 * valid_pos / (self.grid_size - 1) - 1) * (self.bound - half_grid_size) # [Nz, 3], in [-b+hgs, b-hgs]
+        min_pos = valid_pos.amin(0) - half_grid_size # [3]
+        max_pos = valid_pos.amax(0) + half_grid_size # [3]
+
+        # shrink model
+        reso = torch.LongTensor(self.resolution).to(self.aabb_train.device)
+        units = (self.aabb_train[3:] - self.aabb_train[:3]) / reso
+        tl = (min_pos - self.aabb_train[:3]) / units
+        br = (max_pos - self.aabb_train[:3]) / units
+        tl = torch.round(tl).long().clamp(min=0)
+        br = torch.minimum(torch.round(br).long(), reso)
+
+        for i in range(len(self.vec_ids)):
+            vec_id = self.vec_ids[i]
+            mat_id_0, mat_id_1 = self.mat_ids[i]
+
+            self.sigma_vec[i] = nn.Parameter(self.sigma_vec[i].data[..., tl[vec_id]:br[vec_id], :])
+            self.color_vec[i] = nn.Parameter(self.color_vec[i].data[..., tl[vec_id]:br[vec_id], :])
+
+            self.sigma_mat[i] = nn.Parameter(self.sigma_mat[i].data[..., tl[mat_id_1]:br[mat_id_1], tl[mat_id_0]:br[mat_id_0]])
+            self.color_mat[i] = nn.Parameter(self.color_mat[i].data[..., tl[mat_id_1]:br[mat_id_1], tl[mat_id_0]:br[mat_id_0]])
+
+        self.aabb_train = torch.cat([min_pos, max_pos], dim=0) # [6]
+
+        print(f'[INFO] shrink slice: {tl.cpu().numpy().tolist()} - {br.cpu().numpy().tolist()}')
+        print(f'[INFO] new aabb: {self.aabb_train.cpu().numpy().tolist()}')
+
+    # optimizer utils
+    def get_params(self, lr1, lr2):
+        params = [
+            {'params': self.sigma_mat, 'lr': lr1},
+            {'params': self.sigma_vec, 'lr': lr1},
+            {'params': self.color_mat, 'lr': lr1},
+            {'params': self.color_vec, 'lr': lr1},
+            {'params': self.basis_mat.parameters(), 'lr': lr2},
+            {'params': self.color_net.parameters(), 'lr': lr2},
+        ]
+        if self.bg_radius > 0:
+            params.append({'params': self.bg_mat, 'lr': lr1})
+            params.append({'params': self.bg_net.parameters(), 'lr': lr2})
+        return params
\ No newline at end of file
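Reviewer note: `upsample_params` / `upsample_model` above implement the usual TensoRF coarse-to-fine schedule by bilinearly resizing the factor tensors. Because they re-wrap the data in fresh `nn.Parameter`s, any optimizer previously built from `get_params` must be recreated afterwards (the old parameter groups point at stale tensors). A standalone sketch of the resize step, with illustrative sizes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# One factor plane and line, grown from 128 to 256 resolution. This mirrors
# the F.interpolate calls in upsample_params(); values are placeholders.
plane = nn.Parameter(0.1 * torch.randn(1, 16, 128, 128))  # [1, R, H, W]
line = nn.Parameter(0.1 * torch.randn(1, 16, 128, 1))     # [1, R, D, 1]

plane = nn.Parameter(F.interpolate(plane.data, size=(256, 256), mode='bilinear', align_corners=True))
line = nn.Parameter(F.interpolate(line.data, size=(256, 1), mode='bilinear', align_corners=True))
print(plane.shape, line.shape)  # [1, 16, 256, 256], [1, 16, 256, 1]
```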
diff --git a/nerf/renderer.py b/nerf/renderer.py
index 81c63c12..a8c5a5e6 100644
--- a/nerf/renderer.py
+++ b/nerf/renderer.py
@@ -96,6 +96,7 @@ def __init__(self, opt):
         self.opt = opt
         self.bound = opt.bound
         self.cascade = 1 + math.ceil(math.log2(opt.bound))
+        self.time_size = 64
         self.grid_size = 128
         self.cuda_ray = opt.cuda_ray
         self.taichi_ray = opt.taichi_ray
@@ -362,7 +363,7 @@ def _export(v, f, h0=2048, w0=2048, ssaa=1, name=''):

         _export(v, f)

-    def run(self, rays_o, rays_d, num_steps=128, upsample_steps=128, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, **kwargs):
+    def run(self, rays_o, rays_d, time, num_steps=128, upsample_steps=128, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, **kwargs):
         # rays_o, rays_d: [B, N, 3], assumes B == 1
         # bg_color: [BN, 3] in range [0, 1]
         # return: image: [B, N, 3], depth: [B, N]
@@ -410,7 +411,7 @@
         #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())

         # query SDF and RGB
-        density_outputs = self.density(xyzs.reshape(-1, 3))
+        density_outputs = self.density(xyzs.reshape(-1, 3), t=time)

         #sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T]
         for k, v in density_outputs.items():
@@ -435,7 +436,7 @@
                 new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:]) # a manual clip.

                 # only forward new points to save computation
-                new_density_outputs = self.density(new_xyzs.reshape(-1, 3))
+                new_density_outputs = self.density(new_xyzs.reshape(-1, 3), t=time)

                 #new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t]
                 for k, v in new_density_outputs.items():
                     new_density_outputs[k] = v.view(N, upsample_steps, -1)
@@ -461,7 +462,7 @@
         for k, v in density_outputs.items():
             density_outputs[k] = v.view(-1, v.shape[-1])

-        sigmas, rgbs, normals = self(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), light_d, ratio=ambient_ratio, shading=shading)
+        sigmas, rgbs, normals = self(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), light_d, t=time, ratio=ambient_ratio, shading=shading)
         rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3]

         if normals is not None:
@@ -482,7 +483,7 @@
         # mix background color
         if self.bg_radius > 0:
             # use the bg model to calculate bg_color
-            bg_color = self.background(rays_d.reshape(-1, 3)) # [N, 3]
+            bg_color = self.background(rays_d.reshape(-1, 3), rays_o) # [N, 3]
         elif bg_color is None:
             bg_color = 1
@@ -500,7 +501,7 @@

         return results

-    def run_cuda(self, rays_o, rays_d, dt_gamma=0, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
+    def run_cuda(self, rays_o, rays_d, time, dt_gamma=0, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
         # rays_o, rays_d: [B, N, 3], assumes B == 1
         # return: image: [B, N, 3], depth: [B, N]
@@ -529,7 +530,7 @@
             self.local_step += 1

             xyzs, dirs, ts, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, perturb, dt_gamma, max_steps)
             # plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())
-            sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)
+            sigmas, rgbs, normals = self(xyzs, dirs, light_d, t=time, ratio=ambient_ratio, shading=shading)
             weights, weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, ts, rays, T_thresh)

             # normals related regularizations
@@ -569,7 +570,7 @@
                 n_step = max(min(N // n_alive, 8), 1)

                 xyzs, dirs, ts = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, perturb if step == 0 else False, dt_gamma, max_steps)
-                sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)
+                sigmas, rgbs, normals = self(xyzs, dirs, light_d, t=time, ratio=ambient_ratio, shading=shading)
                 raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, ts, weights_sum, depth, image, T_thresh)

                 rays_alive = rays_alive[rays_alive >= 0]
@@ -580,7 +581,7 @@
         # mix background color
         if self.bg_radius > 0:
             # use the bg model to calculate bg_color
-            bg_color = self.background(rays_d) # [N, 3]
+            bg_color = self.background(rays_d, rays_o) # [N, 3]
         elif bg_color is None:
             bg_color = 1
@@ -598,7 +599,7 @@ def run_cuda(self, rays_o, rays_d, dt_gamma=0, light_d=None, ambient_ratio=1.0,

         return results

-    def run_taichi(self, rays_o, rays_d, dt_gamma=0, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
+    def run_taichi(self, rays_o, rays_d, time, dt_gamma=0, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
         # rays_o, rays_d: [B, N, 3], assumes B == 1
         # return: image: [B, N, 3], depth: [B, N]
@@ -633,7 +634,7 @@ def run_taichi(self, rays_o, rays_d, dt_gamma=0, light_d=None, ambient_ratio=1.0
             self.local_step += 1

             rays_a, xyzs, dirs, deltas, ts, _ = self.ray_marching(rays_o, rays_d, hits_t[:, 0], self.density_bitfield, self.cascade, self.bound, exp_step_factor, self.grid_size, MAX_SAMPLES)
             # plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())
-            sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading)
+            sigmas, rgbs, normals = self(xyzs, dirs, light_d, t=time, ratio=ambient_ratio, shading=shading)
             _, weights_sum, depth, image, weights = self.volume_render(sigmas, rgbs, deltas, ts, rays_a, kwargs.get('T_threshold', 1e-4))

             # normals related regularizations
@@ -690,7 +691,7 @@ def run_taichi(self, rays_o, rays_d, dt_gamma=0, light_d=None, ambient_ratio=1.0
                 rgbs = torch.zeros(len(xyzs), 3, device=device)
                 normals = torch.zeros(len(xyzs), 3, device=device)
-                sigmas[valid_mask], _rgbs, normals = self(xyzs[valid_mask], dirs[valid_mask], light_d, ratio=ambient_ratio, shading=shading)
+                sigmas[valid_mask], _rgbs, normals = self(xyzs[valid_mask], dirs[valid_mask], light_d, t=time, ratio=ambient_ratio, shading=shading)
                 rgbs[valid_mask] = _rgbs.float()
                 sigmas = self.rearrange(sigmas, '(n1 n2) -> n1 n2', n2=n_step)
                 rgbs = self.rearrange(rgbs, '(n1 n2) c -> n1 n2 c', n2=n_step)
@@ -708,7 +709,7 @@ def run_taichi(self, rays_o, rays_d, dt_gamma=0, light_d=None, ambient_ratio=1.0
         # mix background color
         if self.bg_radius > 0:
             # use the bg model to calculate bg_color
-            bg_color = self.background(rays_d) # [N, 3]
+            bg_color = self.background(rays_d, rays_o) # [N, 3]
         elif bg_color is None:
             bg_color = 1
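Reviewer note: in the hunk below, `update_extra_state` gains a time-conditioned branch for the dnerf backbone, whose occupancy grid carries a leading time axis (`[T, CAS, H**3]`, cf. `self.times` in `network_dnerf.py`); the EMA update itself is unchanged. A small self-contained sketch of that update pattern, with tiny illustrative sizes and random stand-in densities:

```python
import torch

# Time-indexed occupancy grid: [T, CAS, H**3]. Sizes are kept small here;
# the real grid uses T=64, H=128.
T, CAS, H = 64, 1, 32
density_grid = torch.zeros(T, CAS, H ** 3)
tmp_grid = torch.rand_like(density_grid)  # stands in for freshly queried sigmas

decay = 0.95
valid_mask = density_grid >= 0            # cells marked < 0 are excluded from updates
density_grid[valid_mask] = torch.maximum(density_grid[valid_mask] * decay, tmp_grid[valid_mask])
mean_density = torch.mean(density_grid.clamp(min=0)).item()
```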
@@ -729,6 +730,7 @@

     @torch.no_grad()
     def update_extra_state(self, decay=0.95, S=128):
+        print('update_extra_state 1')
         # call before each epoch to update extra states.
         if not (self.cuda_ray or self.taichi_ray):
@@ -741,28 +743,58 @@
         Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.aabb_train.device).split(S)
         Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.aabb_train.device).split(S)

-        for xs in X:
-            for ys in Y:
-                for zs in Z:
-
-                    # construct points
-                    xx, yy, zz = custom_meshgrid(xs, ys, zs)
-                    coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
-                    indices = raymarching.morton3D(coords).long() # [N]
-                    xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
-
-                    # cascading
-                    for cas in range(self.cascade):
-                        bound = min(2 ** cas, self.bound)
-                        half_grid_size = bound / self.grid_size
-                        # scale to current cascade's resolution
-                        cas_xyzs = xyzs * (bound - half_grid_size)
-                        # add noise in [-hgs, hgs]
-                        cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
-                        # query density
-                        sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
-                        # assign
-                        tmp_grid[cas, indices] = sigmas
+        if hasattr(self, 'times'):
+            for t, time in enumerate(self.times):
+                for xs in X:
+                    for ys in Y:
+                        for zs in Z:
+
+                            # construct points
+                            xx, yy, zz = custom_meshgrid(xs, ys, zs)
+                            coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
+                            indices = raymarching.morton3D(coords).long() # [N]
+                            xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
+
+                            # cascading
+                            for cas in range(self.cascade):
+                                bound = min(2 ** cas, self.bound)
+                                half_grid_size = bound / self.grid_size
+                                half_time_size = 0.5 / self.time_size
+                                # scale to current cascade's resolution
+                                cas_xyzs = xyzs * (bound - half_grid_size)
+                                # add noise in coord [-hgs, hgs]
+                                cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
+                                # add noise in time [-hts, hts]
+                                time_perturb = time + (torch.rand_like(time) * 2 - 1) * half_time_size
+                                # query density
+                                sigmas = self.density(cas_xyzs, t=time_perturb)['sigma'].reshape(-1).detach()
+                                sigmas *= self.density_scale
+                                # assign
+                                tmp_grid[t, cas, indices] = sigmas
+        else:
+            for xs in X:
+                for ys in Y:
+                    for zs in Z:
+
+                        # construct points
+                        xx, yy, zz = custom_meshgrid(xs, ys, zs)
+                        coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128)
+                        indices = raymarching.morton3D(coords).long() # [N]
+                        xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1]
+
+                        # cascading
+                        for cas in range(self.cascade):
+                            bound = min(2 ** cas, self.bound)
+                            half_grid_size = bound / self.grid_size
+                            # scale to current cascade's resolution
+                            cas_xyzs = xyzs * (bound - half_grid_size)
+                            # add noise in [-hgs, hgs]
+                            cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
+                            # query density
+                            sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach()
+                            # assign
+                            tmp_grid[cas, indices] = sigmas

         # ema update
         valid_mask = self.density_grid >= 0
         self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
@@ -785,7 +818,7 @@

         # print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > density_thresh).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}')

-    def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs):
+    def render(self, rays_o, rays_d, time, staged=False, max_ray_batch=4096, **kwargs):
         # rays_o, rays_d: [B, N, 3], assumes B == 1
         # return: pred_rgb: [B, N, 3]

@@ -799,6 +832,7 @@ def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs):
         B, N = rays_o.shape[:2]
         device = rays_o.device

+        # never stage when cuda_ray
         if staged and not (self.cuda_ray or self.taichi_ray):

             depth = torch.empty((B, N), device=device)
@@ -821,6 +855,6 @@ def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs):
             results['weights_sum'] = weights_sum

         else:
-            results = _run(rays_o, rays_d, **kwargs)
+            results = _run(rays_o, rays_d, time, **kwargs)

         return results
diff --git a/nerf/utils.py b/nerf/utils.py
index a0e8e064..007c9bfc 100644
--- a/nerf/utils.py
+++ b/nerf/utils.py
@@ -345,6 +345,9 @@ def train_step(self, data):

         rays_o = data['rays_o'] # [B, N, 3]
         rays_d = data['rays_d'] # [B, N, 3]
+        time = None
+        if 'time' in data:
+            time = data['time'] # [B, 1]

         B, N = rays_o.shape[:2]
         H, W = data['H'], data['W']
@@ -368,7 +371,7 @@ def train_step(self, data):
             bg_color = None

         # bg_color = torch.rand((B * N, 3), device=rays_o.device)
-        outputs = self.model.render(rays_o, rays_d, staged=False, perturb=True, bg_color=bg_color, ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, **vars(self.opt))
+        outputs = self.model.render(rays_o, rays_d, time, staged=False, perturb=True, bg_color=bg_color, ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, **vars(self.opt))
         pred_depth = outputs['depth'].reshape(B, 1, H, W)

         if as_latent:
@@ -421,6 +424,9 @@ def eval_step(self, data):

         rays_o = data['rays_o'] # [B, N, 3]
         rays_d = data['rays_d'] # [B, N, 3]
+        time = None
+        if 'time' in data:
+            time = data['time'] # [B, 1]

         B, N = rays_o.shape[:2]
         H, W = data['H'], data['W']
@@ -429,7 +435,7 @@ def eval_step(self, data):
         ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0
         light_d = data['light_d'] if 'light_d' in data else None

-        outputs = self.model.render(rays_o, rays_d, staged=True, perturb=False, bg_color=None, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, **vars(self.opt))
+        outputs = self.model.render(rays_o, rays_d, time, staged=True, perturb=False, bg_color=None, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, **vars(self.opt))

         pred_rgb = outputs['image'].reshape(B, H, W, 3)
         pred_depth = outputs['depth'].reshape(B, H, W)
@@ -441,6 +447,9 @@ def eval_step(self, data):
     def test_step(self, data, bg_color=None, perturb=False):
         rays_o = data['rays_o'] # [B, N, 3]
         rays_d = data['rays_d'] # [B, N, 3]
+        time = None
+        if 'time' in data:
+            time = data['time'] # [B, 1]

         B, N = rays_o.shape[:2]
         H, W = data['H'], data['W']
@@ -454,7 +463,7 @@ def test_step(self, data, bg_color=None, perturb=False):
         ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0
         light_d = data['light_d'] if 'light_d' in data else None

-        outputs = self.model.render(rays_o, rays_d, staged=True, perturb=perturb, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, bg_color=bg_color, **vars(self.opt))
+        outputs = self.model.render(rays_o, rays_d, time, staged=True, perturb=perturb, light_d=light_d, ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, bg_color=bg_color, **vars(self.opt))

         pred_rgb = outputs['image'].reshape(B, H, W, 3)
         pred_depth = outputs['depth'].reshape(B, H, W)
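Reviewer note: `render()` now takes `time` as a new positional argument, so every call site (train/eval/test steps above) must supply it; for static datasets the dataloader simply omits `'time'` and `None` is passed through. A hypothetical call site, with the actual `model.render` call left commented since it requires a constructed model:

```python
import torch
import torch.nn.functional as F

# Shapes follow the comments in the diff: rays are [B, N, 3], time is [B, 1]
# with values normalized to [0, 1] over the sequence.
rays_o = torch.rand(1, 4096, 3)
rays_d = F.normalize(torch.rand(1, 4096, 3) - 0.5, dim=-1)
time = torch.tensor([[0.25]])  # e.g. frame 16 of 64

# outputs = model.render(rays_o, rays_d, time, staged=False, perturb=True)
# Static backbones ignore t (their density()/forward() default it to None).
```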