16 changes: 13 additions & 3 deletions gradio_app.py
@@ -26,6 +26,7 @@
 ### training options
 parser.add_argument('--iters', type=int, default=10000, help="training iters")
 parser.add_argument('--lr', type=float, default=1e-3, help="initial learning rate")
+parser.add_argument('--lr2', type=float, default=1e-3, help="initial learning rate for the network MLPs (tensoRF/dnerf backbones; encoders use --lr)")
 parser.add_argument('--ckpt', type=str, default='latest')
 parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
 parser.add_argument('--max_steps', type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)")
@@ -166,11 +167,20 @@ def submit(text, iters, seed):
     if opt.optim == 'adan':
         from optimizer import Adan
         # Adan usually requires a larger LR
-        optimizer = lambda model: Adan(model.get_params(5 * opt.lr), eps=1e-15)
+        if opt.backbone in ('tensoRF', 'dnerf'):
+            optimizer = lambda model: Adan(model.get_params(5 * opt.lr, 5 * opt.lr2), eps=1e-15)
+        else:
+            optimizer = lambda model: Adan(model.get_params(5 * opt.lr), eps=1e-15)
     elif opt.optim == 'adamw':
-        optimizer = lambda model: torch.optim.AdamW(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
+        if opt.backbone in ('tensoRF', 'dnerf'):
+            optimizer = lambda model: torch.optim.AdamW(model.get_params(opt.lr, opt.lr2), betas=(0.9, 0.99), eps=1e-15)
+        else:
+            optimizer = lambda model: torch.optim.AdamW(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
     else: # adam
-        optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
+        if opt.backbone in ('tensoRF', 'dnerf'):
+            optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr, opt.lr2), betas=(0.9, 0.99), eps=1e-15)
+        else:
+            optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)

     scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 1) # fixed
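All three branches above assume get_params accepts an optional second learning rate. A minimal sketch of that contract (TinyBackbone is hypothetical; the real tensoRF/dnerf implementations appear later in this PR): encoder-style parameters train at lr, MLP parameters at lr_net.

import torch
import torch.nn as nn

class TinyBackbone(nn.Module):
    # Hypothetical stand-in for a tensoRF/dnerf-style backbone: encoder and
    # MLP parameters are exposed as separate optimizer param groups.
    def __init__(self):
        super().__init__()
        self.encoder = nn.Embedding(4096, 32)   # stands in for a feature grid
        self.sigma_net = nn.Linear(32, 16)      # stands in for an MLP head

    def get_params(self, lr, lr_net=None):
        # Two-LR convention used by this PR: encoders at lr, networks at lr_net.
        if lr_net is None:
            lr_net = lr
        return [
            {'params': self.encoder.parameters(), 'lr': lr},
            {'params': self.sigma_net.parameters(), 'lr': lr_net},
        ]

model = TinyBackbone()
# Same shape as the Adam branch above: per-parameter-group learning rates.
optimizer = torch.optim.Adam(model.get_params(1e-3, 1e-4), betas=(0.9, 0.99), eps=1e-15)

Because per-group learning rates are standard torch.optim behavior, only get_params needs to know about --lr2; the optimizer construction stays one line per branch.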
18 changes: 15 additions & 3 deletions main.py
@@ -29,6 +29,7 @@
 ### training options
 parser.add_argument('--iters', type=int, default=10000, help="training iters")
 parser.add_argument('--lr', type=float, default=1e-3, help="max learning rate")
+parser.add_argument('--lr2', type=float, default=1e-3, help="max learning rate for the network MLPs (tensoRF/dnerf backbones; encoders use --lr)")
 parser.add_argument('--warm_iters', type=int, default=500, help="warmup iters")
 parser.add_argument('--min_lr', type=float, default=1e-4, help="minimal learning rate")
 parser.add_argument('--ckpt', type=str, default='latest')
@@ -49,7 +50,7 @@
 parser.add_argument('--blob_density', type=float, default=10, help="max (center) density for the density blob")
 parser.add_argument('--blob_radius', type=float, default=0.5, help="control the radius for the density blob")
 # network backbone
-parser.add_argument('--backbone', type=str, default='grid', choices=['grid', 'vanilla', 'grid_taichi'], help="nerf backbone")
+parser.add_argument('--backbone', type=str, default='grid', choices=['grid', 'vanilla', 'grid_taichi', 'tensoRF', 'dnerf'], help="nerf backbone")
 parser.add_argument('--optim', type=str, default='adan', choices=['adan', 'adam'], help="optimizer")
 parser.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
 parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key")
@@ -103,6 +104,11 @@
     from nerf.network import NeRFNetwork
 elif opt.backbone == 'grid':
     from nerf.network_grid import NeRFNetwork
+elif opt.backbone == 'tensoRF':
+    from nerf.network_tensorf import NeRFNetwork
+elif opt.backbone == 'dnerf':
+    opt.cuda_ray = False
+    from nerf.network_dnerf import NeRFNetwork
 elif opt.backbone == 'grid_taichi':
     opt.cuda_ray = False
     opt.taichi_ray = True
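Note that the dnerf backbone forces opt.cuda_ray = False, so it always takes the PyTorch raymarching path (the diff gives no rationale; presumably the CUDA raymarching extension has no notion of the added time dimension). Hypothetical invocations exercising the new backbones (prompt text is illustrative; flag names come from the argparse changes above):

# Sketch only; not part of the diff.
#
#   python main.py --text "a hamburger" --backbone tensoRF --lr 1e-3 --lr2 1e-3
#   python main.py --text "a hamburger" --backbone dnerf --lr 1e-3 --lr2 1e-3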
@@ -151,9 +157,15 @@
 if opt.optim == 'adan':
     from optimizer import Adan
     # Adan usually requires a larger LR
-    optimizer = lambda model: Adan(model.get_params(5 * opt.lr), eps=1e-8, weight_decay=2e-5, max_grad_norm=5.0, foreach=False)
+    if opt.backbone in ('tensoRF', 'dnerf'):
+        optimizer = lambda model: Adan(model.get_params(5 * opt.lr, 5 * opt.lr2), eps=1e-8, weight_decay=2e-5, max_grad_norm=5.0, foreach=False)
+    else:
+        optimizer = lambda model: Adan(model.get_params(5 * opt.lr), eps=1e-8, weight_decay=2e-5, max_grad_norm=5.0, foreach=False)
 else: # adam
-    optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)
+    if opt.backbone in ('tensoRF', 'dnerf'):
+        optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr, opt.lr2), betas=(0.9, 0.99), eps=1e-15)
+    else:
+        optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)

 if opt.backbone == 'vanilla':
     warm_up_with_cosine_lr = lambda iter: iter / opt.warm_iters if iter <= opt.warm_iters \
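The hunk ends mid-expression (the lambda continues past the fold). For context, a warm-up-then-cosine schedule of the kind this line begins looks roughly like the sketch below; this is an illustration built from the option defaults above, not necessarily the file's exact continuation.

import math
from types import SimpleNamespace

opt = SimpleNamespace(iters=10000, warm_iters=500, lr=1e-3, min_lr=1e-4)  # defaults above

# Sketch: linear warm-up for warm_iters steps, then cosine decay, floored at
# min_lr / lr. LambdaLR multiplies the optimizer's base LR by this factor.
warm_up_with_cosine_lr = lambda it: (
    it / opt.warm_iters if it <= opt.warm_iters
    else max(
        0.5 * (math.cos((it - opt.warm_iters) / (opt.iters - opt.warm_iters) * math.pi) + 1.0),
        opt.min_lr / opt.lr,
    )
)

print(warm_up_with_cosine_lr(250), warm_up_with_cosine_lr(5000))  # mid-warm-up vs mid-decay
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, warm_up_with_cosine_lr)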
2 changes: 1 addition & 1 deletion nerf/network.py
@@ -217,7 +217,7 @@ def density(self, x):
         }


-    def background(self, d):
+    def background(self, d, x):

         h = self.encoder_bg(d) # [N, C]
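The vanilla backbone never uses the new x argument; widening the signature keeps every backbone call-compatible with a renderer that passes both the view direction and a background coordinate (network_dnerf.py below actually consumes x). A sketch of the assumed uniform call site (names hypothetical):

# Sketch: one background signature for every backbone.
def render_background(model, rays_d, bg_coords):
    # d: [N, 3] view directions; x: [N, 2] background coords.
    # The vanilla network ignores bg_coords; network_dnerf encodes it.
    return model.background(rays_d, bg_coords)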
308 changes: 308 additions & 0 deletions nerf/network_dnerf.py
@@ -0,0 +1,308 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from activation import trunc_exp
from .renderer import NeRFRenderer

from encoding import get_encoder

class MLP(nn.Module):
def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True):
super().__init__()
self.dim_in = dim_in
self.dim_out = dim_out
self.dim_hidden = dim_hidden
self.num_layers = num_layers

net = []
for l in range(num_layers):
net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias))

self.net = nn.ModuleList(net)

def forward(self, x):
for l in range(self.num_layers):
x = self.net[l](x)
if l != self.num_layers - 1:
x = F.relu(x, inplace=True)
return x

class NeRFNetwork(NeRFRenderer):
def __init__(self,
opt,
encoding="tiledgrid",
encoding_dir="sphere_harmonics",
encoding_time="frequency",
encoding_deform="frequency", # "hashgrid" seems worse
encoding_bg="hashgrid",
num_layers=2,
hidden_dim=64,
geo_feat_dim=15,
num_layers_color=3,
hidden_dim_color=64,
num_layers_bg=2,
hidden_dim_bg=64,
                 num_layers_deform=5, # a deeper MLP is necessary for good performance
hidden_dim_deform=128
):

super().__init__(opt)
self.bound = opt.bound

# deformation network
self.num_layers_deform = num_layers_deform
self.hidden_dim_deform = hidden_dim_deform
self.encoder_deform, self.in_dim_deform = get_encoder(encoding_deform, multires=10)
self.encoder_time, self.in_dim_time = get_encoder(encoding_time, input_dim=1, multires=6)
self.density_scale = 1
# time stamps for density grid
self.time_size = 64
self.times = ((torch.arange(self.time_size, dtype=torch.float32) + 0.5) / self.time_size).view(-1, 1, 1) # [T, 1, 1]
self.density_grid = torch.zeros(self.time_size, self.cascade, self.grid_size ** 3) # [T, CAS, H * H * H]

deform_net = []
for l in range(num_layers_deform):
if l == 0:
in_dim = self.in_dim_deform + self.in_dim_time # grid dim + time
else:
in_dim = hidden_dim_deform

if l == num_layers_deform - 1:
out_dim = 3 # deformation for xyz
else:
out_dim = hidden_dim_deform

deform_net.append(nn.Linear(in_dim, out_dim, bias=False))

self.deform_net = nn.ModuleList(deform_net)


# sigma network
self.num_layers = num_layers
self.hidden_dim = hidden_dim
self.geo_feat_dim = geo_feat_dim
self.encoder, self.in_dim = get_encoder(encoding, desired_resolution=2048 * self.bound)

sigma_net = []
for l in range(num_layers):
if l == 0:
in_dim = self.in_dim + self.in_dim_time + self.in_dim_deform # concat everything
else:
in_dim = hidden_dim

if l == num_layers - 1:
out_dim = 1 + self.geo_feat_dim # 1 sigma + features for color
else:
out_dim = hidden_dim

sigma_net.append(nn.Linear(in_dim, out_dim, bias=False))

self.sigma_net = nn.ModuleList(sigma_net)

# color network
self.num_layers_color = num_layers_color
self.hidden_dim_color = hidden_dim_color
self.encoder_dir, self.in_dim_dir = get_encoder(encoding_dir)

color_net = []
for l in range(num_layers_color):
if l == 0:
in_dim = self.in_dim_dir + self.geo_feat_dim
else:
in_dim = hidden_dim_color

if l == num_layers_color - 1:
out_dim = 3 # 3 rgb
else:
out_dim = hidden_dim_color

color_net.append(nn.Linear(in_dim, out_dim, bias=False))

self.color_net = nn.ModuleList(color_net)

# background network
if self.bg_radius > 0:
self.num_layers_bg = num_layers_bg
self.hidden_dim_bg = hidden_dim_bg
self.encoder_bg, self.in_dim_bg = get_encoder(encoding_bg, input_dim=2, num_levels=4, log2_hashmap_size=19, desired_resolution=2048) # much smaller hashgrid

bg_net = []
for l in range(num_layers_bg):
if l == 0:
in_dim = self.in_dim_bg + self.in_dim_dir
else:
in_dim = hidden_dim_bg

if l == num_layers_bg - 1:
out_dim = 3 # 3 rgb
else:
out_dim = hidden_dim_bg

bg_net.append(nn.Linear(in_dim, out_dim, bias=False))

            self.bg_net = nn.ModuleList(bg_net)
else:
self.bg_net = None

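    # Forward pass: encode the original positions and the timestamp, predict a
    # per-point offset with the deformation MLP, warp x into the canonical
    # frame, then query sigma and color there as in a static NeRF.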

def forward(self, x, d, t=None, l=None, ratio=1, shading='albedo'):
        # x: [N, 3], in [-bound, bound]
        # d: [N, 3], normalized in [-1, 1]
# t: [1, 1], in [0, 1]

# deform
enc_ori_x = self.encoder_deform(x, bound=self.bound) # [N, C]
enc_t = self.encoder_time(t) # [1, 1] --> [1, C']
if enc_t.shape[0] == 1:
enc_t = enc_t.repeat(x.shape[0], 1) # [1, C'] --> [N, C']

deform = torch.cat([enc_ori_x, enc_t], dim=1) # [N, C + C']
for l in range(self.num_layers_deform):
deform = self.deform_net[l](deform)
if l != self.num_layers_deform - 1:
deform = F.relu(deform, inplace=True)

x = x + deform

# sigma
x = self.encoder(x, bound=self.bound)
h = torch.cat([x, enc_ori_x, enc_t], dim=1)
for l in range(self.num_layers):
h = self.sigma_net[l](h)
if l != self.num_layers - 1:
h = F.relu(h, inplace=True)

#sigma = F.relu(h[..., 0])
sigma = trunc_exp(h[..., 0])
geo_feat = h[..., 1:]

# color
d = self.encoder_dir(d)
h = torch.cat([d, geo_feat], dim=-1)
for l in range(self.num_layers_color):
h = self.color_net[l](h)
if l != self.num_layers_color - 1:
h = F.relu(h, inplace=True)

# sigmoid activation for rgb
rgbs = torch.sigmoid(h)

return sigma, rgbs, deform

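    # density() mirrors the deformation and sigma stages of forward() without
    # the color head, and additionally returns the raw per-point deformation.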
def density(self, x, t=None):
# x: [N, 3], in [-bound, bound]
# t: [1, 1], in [0, 1]

results = {}

# deformation
enc_ori_x = self.encoder_deform(x, bound=self.bound) # [N, C]
enc_t = self.encoder_time(t) # [1, 1] --> [1, C']
if enc_t.shape[0] == 1:
enc_t = enc_t.repeat(x.shape[0], 1) # [1, C'] --> [N, C']

deform = torch.cat([enc_ori_x, enc_t], dim=1) # [N, C + C']
for l in range(self.num_layers_deform):
deform = self.deform_net[l](deform)
if l != self.num_layers_deform - 1:
deform = F.relu(deform, inplace=True)

x = x + deform
results['deform'] = deform

# sigma
x = self.encoder(x, bound=self.bound)
h = torch.cat([x, enc_ori_x, enc_t], dim=1)
for l in range(self.num_layers):
h = self.sigma_net[l](h)
if l != self.num_layers - 1:
h = F.relu(h, inplace=True)

#sigma = F.relu(h[..., 0])
sigma = trunc_exp(h[..., 0])
geo_feat = h[..., 1:]

results['sigma'] = sigma
results['geo_feat'] = geo_feat

return results

def background(self, d, x):
        # d: [N, 3], view directions
        # x: [N, 2], in [-1, 1]

h = self.encoder_bg(x) # [N, C]
d = self.encoder_dir(d)

h = torch.cat([d, h], dim=-1)
for l in range(self.num_layers_bg):
h = self.bg_net[l](h)
if l != self.num_layers_bg - 1:
h = F.relu(h, inplace=True)

# sigmoid activation for rgb
rgbs = torch.sigmoid(h)

return rgbs

# allow masked inference
def color(self, x, d, mask=None, geo_feat=None, **kwargs):
        # x: [N, 3], in [-bound, bound]
        # mask: [N,], bool, indicates where we actually need to compute rgb.

if mask is not None:
rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3]
# in case of empty mask
if not mask.any():
return rgbs
x = x[mask]
d = d[mask]
geo_feat = geo_feat[mask]

d = self.encoder_dir(d)
h = torch.cat([d, geo_feat], dim=-1)
for l in range(self.num_layers_color):
h = self.color_net[l](h)
if l != self.num_layers_color - 1:
h = F.relu(h, inplace=True)

# sigmoid activation for rgb
h = torch.sigmoid(h)

if mask is not None:
rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32
else:
rgbs = h

return rgbs

# optimizer utils
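    # Two-LR split: encoders (grid, direction, deformation, time, background)
    # train at `lr`, while the MLPs (sigma/color/deform/bg nets) train at
    # `lr_net`, which main.py and gradio_app.py wire to --lr2.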
def get_params(self, lr, lr_net):

params = [
{'params': self.encoder.parameters(), 'lr': lr},
{'params': self.sigma_net.parameters(), 'lr': lr_net},
{'params': self.encoder_dir.parameters(), 'lr': lr},
{'params': self.color_net.parameters(), 'lr': lr_net},
{'params': self.encoder_deform.parameters(), 'lr': lr},
{'params': self.encoder_time.parameters(), 'lr': lr},
{'params': self.deform_net.parameters(), 'lr': lr_net},
]
if self.bg_radius > 0:
params.append({'params': self.encoder_bg.parameters(), 'lr': lr})
params.append({'params': self.bg_net.parameters(), 'lr': lr_net})

return params
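Finally, a self-contained illustration of the deformation mechanism this file implements: a frequency encoding of (x, t) feeds an MLP that predicts a per-point offset, and the warped point is what the canonical (static) NeRF sees. The freq_encode helper below is a simplified stand-in for the repo's get_encoder('frequency', ...), not its actual implementation; the 10 and 6 frequency counts mirror the multires=10 and multires=6 arguments above.

import torch
import torch.nn as nn

def freq_encode(v, n_freqs):
    # [N, C] -> [N, C * 2 * n_freqs]: sin/cos features at octave frequencies.
    feats = []
    for i in range(n_freqs):
        feats += [torch.sin((2.0 ** i) * v), torch.cos((2.0 ** i) * v)]
    return torch.cat(feats, dim=-1)

N = 1024
x = torch.rand(N, 3) * 2 - 1            # positions in [-1, 1]
t = torch.full((N, 1), 0.25)            # one shared timestamp in [0, 1]

enc = torch.cat([freq_encode(x, 10), freq_encode(t, 6)], dim=-1)
deform_net = nn.Sequential(
    nn.Linear(enc.shape[-1], 128), nn.ReLU(inplace=True),
    nn.Linear(128, 3),                  # 3D offset, like the last deform layer above
)
x_canonical = x + deform_net(enc)       # the static NeRF is queried at x_canonical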