Understanding the NeRF Code

NeRF code (PyTorch-based). Flowcharts: the training flowchart and the rendering flowchart (made with draw.io).

Training pipeline

if use_batching

config_parser()

def config_parser():
    import configargparse
    parser = configargparse.ArgumentParser()
    # the value of this command-line option ends up in args.config
    parser.add_argument('--config', is_config_file=True,
                        help='config file path')
    parser.add_argument("--expname", type=str,
                        help='experiment name')
    parser.add_argument("--basedir", type=str, default='./logs/',
                        help='where to store ckpts and logs')
    parser.add_argument("--datadir", type=str, default='./data/llff/fern',
                        help='input data directory')

    # training options
    parser.add_argument("--netdepth", type=int, default=8,
                        help='layers in network')
    parser.add_argument("--netwidth", type=int, default=256,
                        help='channels per layer')
    parser.add_argument("--lrate", type=float, default=5e-4,
                        help='learning rate')
    parser.add_argument("--no_batching", action='store_true',
                        help='only take random rays from 1 image at a time')
    parser.add_argument("--no_reload", action='store_true',
                        help='do not reload weights from saved ckpt')

    ...
    return parser


Usage:
parser = config_parser()
args = parser.parse_args()
basedir = args.basedir

load_???_data()

load_llff_data()

Inputs:

  • args.datadir: './data/llff/fern'
  • args.factor
  • recenter=True
  • bd_factor=.75
  • spherify=args.spherify
  • path_zflat=False

Outputs:

  • images, poses, bds, render_poses, i_test (the images and camera poses, the scene's near/far depth bounds, the spiral-path render poses, and the held-out test view index)

load_blender_data(basedir, half_res=False, testskip=1)

Inputs:

  • basedir: dataset path, e.g. 'E:\\3\\Work\\dataset\\nerf_synthetic\\chair'

    [!info]- chair folder

    • chair:
      • test
        • 200 PNGs, 800x800x4
      • train
        • 100 PNGs
      • val
        • 100 PNGs
      • .DS_Store
      • transforms_test.json
      • transforms_train.json
      • transforms_val.json
  • half_res: whether to downsample the images to half resolution
  • testskip: load only every testskip-th image of the test/val sets

Outputs:

  • imgs: the image data of the train, val and test splits combined, imgs.shape: (400, 800, 800, 4)
  • poses: camera extrinsics (camera poses), poses.shape: (400, 4, 4), one 4x4 extrinsic matrix per image
  • render_poses: render poses, the camera poses used to generate the video (torch.Size([40, 4, 4]): 40 frames)
  • [H, W, focal]: image height, width and focal length
  • i_split: three arrays
    1. 0, 1, 2, …, 99 (100 train images)
    2. 100, 101, …, 199 (100 val images)
    3. 200, 201, …, 399 (200 test images)
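
For reference, a condensed sketch of what load_blender_data does (paraphrasing the reference implementation; the half_res downsampling branch is omitted, and the helper name is my own):

import os, json
import numpy as np
import imageio.v2 as imageio

def load_blender_data_sketch(basedir, testskip=1):
    splits = ['train', 'val', 'test']
    all_imgs, all_poses, counts = [], [], [0]
    for s in splits:
        with open(os.path.join(basedir, f'transforms_{s}.json')) as fp:
            meta = json.load(fp)
        skip = 1 if s == 'train' else testskip
        imgs, poses = [], []
        for frame in meta['frames'][::skip]:
            fname = os.path.join(basedir, frame['file_path'] + '.png')
            imgs.append(imageio.imread(fname))                 # (800, 800, 4) RGBA
            poses.append(np.array(frame['transform_matrix']))  # (4, 4) c2w matrix
        all_imgs.append(np.array(imgs, dtype=np.float32) / 255.)
        all_poses.append(np.array(poses, dtype=np.float32))
        counts.append(counts[-1] + len(imgs))
    i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)]
    imgs = np.concatenate(all_imgs, 0)    # (400, 800, 800, 4)
    poses = np.concatenate(all_poses, 0)  # (400, 4, 4)
    # focal length from the horizontal field of view (shared across splits)
    H, W = imgs[0].shape[:2]
    focal = .5 * W / np.tan(.5 * float(meta['camera_angle_x']))
    return imgs, poses, [H, W, focal], i_split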

create_nerf(args)

Inputs:

  • args: the argparse namespace combining command-line options and defaults

    parser = config_parser()
    args = parser.parse_args()
Outputs:

  • render_kwargs_train
  • render_kwargs_test

    render_kwargs_train = {
        'network_query_fn' : network_query_fn,
        'perturb' : args.perturb,             # default 1; --perturb: jitter the sample positions during training
        'N_importance' : args.N_importance,   # --N_importance: number of additional fine samples per ray
        'network_fine' : model_fine,
        'N_samples' : args.N_samples,         # --N_samples: number of coarse samples per ray
        'network_fn' : model,
        'use_viewdirs' : args.use_viewdirs,   # --use_viewdirs: use the full 5D input instead of 3D
        'white_bkgd' : args.white_bkgd,
        'raw_noise_std' : args.raw_noise_std, # default 0; --raw_noise_std: std of the noise added to the raw density
    }

    # NDC only good for LLFF-style forward facing data
    # NDC stands for Normalized Device Coordinates
    if args.dataset_type != 'llff' or args.no_ndc:
        print('Not ndc!')
        render_kwargs_train['ndc'] = False
        render_kwargs_train['lindisp'] = args.lindisp

    render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
    render_kwargs_test['perturb'] = False
    render_kwargs_test['raw_noise_std'] = 0.

  • start: the global step
  • grad_vars: the list of model parameters (weights and biases)
    • grad_vars = list(model.parameters())
    • if args.N_importance > 0: grad_vars += list(model_fine.parameters())
  • optimizer

    optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))

    # if a checkpoint was loaded:
    # optimizer.load_state_dict(ckpt['optimizer_state_dict'])
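
In train(), the returned kwargs are then extended with the scene's near/far bounds before being used for rendering (as done in the reference repo):

render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)

# near/far come from the dataset loader (e.g. 2. and 6. for blender scenes)
bds_dict = {'near': near, 'far': far}
render_kwargs_train.update(bds_dict)
render_kwargs_test.update(bds_dict)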

get_embedder(args.multires, args.i_embed)

Inputs:

  • args.multires: the L of the positional encoding (number of frequency bands), default 10
  • args.i_embed: default 0, use positional encoding; -1 means no positional encoding

Outputs:

  • embed: the positional-encoding function, mapping (1024 * 32 * 64, 3) to (1024 * 32 * 64, 63)
    • with input_ch = 3 and L = 10, and the raw input included: 3 + 3 * 2 * 10 = 63
  • embedder_obj.out_dim: the output dimension

Embedder()

class Embedder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        ...
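
create_embedding_fn assembles a list of sin/cos functions at log-spaced frequencies $2^0, \dots, 2^{L-1}$ and concatenates their outputs. A minimal self-contained sketch of the same idea (make_embedder is a hypothetical helper, not the repo's API):

import torch

def make_embedder(L=10, input_dims=3, include_input=True):
    # frequencies 2^0, 2^1, ..., 2^(L-1)  (log-spaced sampling)
    freq_bands = 2. ** torch.linspace(0., L - 1, steps=L)
    fns, out_dim = [], 0
    if include_input:
        fns.append(lambda x: x)        # keep the raw input as part of the encoding
        out_dim += input_dims
    for freq in freq_bands:
        for p_fn in (torch.sin, torch.cos):
            fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
            out_dim += input_dims
    embed = lambda x: torch.cat([fn(x) for fn in fns], -1)
    return embed, out_dim              # out_dim = 3 + 3*2*10 = 63

embed, out_dim = make_embedder()
print(embed(torch.rand(5, 3)).shape, out_dim)  # torch.Size([5, 63]) 63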

NeRF(nn.Module)

A class built by subclassing nn.Module.

def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False)

Understanding forward in PyTorch - Zhihu (zhihu.com)

model = NeRF(...)  # instantiate
model(input)       # equivalent to model.forward(input)

Arguments passed in create_nerf:

  • D=args.netdepth, W=args.netwidth
  • input_ch=input_ch, output_ch=output_ch
  • skips=skips
  • input_ch_views=input_ch_views
  • use_viewdirs=args.use_viewdirs
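
A condensed sketch of the architecture for the use_viewdirs=True path (TinyNeRF is my own name for this trimmed class; the real NeRF class also handles use_viewdirs=False):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyNeRF(nn.Module):
    """Trimmed sketch of the NeRF MLP (use_viewdirs=True path only)."""
    def __init__(self, D=8, W=256, input_ch=63, input_ch_views=27, skips=[4]):
        super().__init__()
        self.input_ch, self.input_ch_views, self.skips = input_ch, input_ch_views, skips
        self.pts_linears = nn.ModuleList(
            [nn.Linear(input_ch, W)] +
            [nn.Linear(W + (input_ch if i in skips else 0), W) for i in range(D - 1)])
        self.alpha_linear = nn.Linear(W, 1)      # density depends on position only
        self.feature_linear = nn.Linear(W, W)
        self.views_linear = nn.Linear(W + input_ch_views, W // 2)
        self.rgb_linear = nn.Linear(W // 2, 3)   # color also depends on the view direction

    def forward(self, x):
        input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
        h = input_pts
        for i, layer in enumerate(self.pts_linears):
            h = F.relu(layer(h))
            if i in self.skips:                  # skip connection: re-inject the encoded input
                h = torch.cat([input_pts, h], -1)
        alpha = self.alpha_linear(h)
        h = F.relu(self.views_linear(torch.cat([self.feature_linear(h), input_views], -1)))
        rgb = self.rgb_linear(h)
        return torch.cat([rgb, alpha], -1)       # (..., 4): RGB + sigma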

network_query_fn

network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
                                                                     embed_fn=embed_fn,
                                                                     embeddirs_fn=embeddirs_fn,
                                                                     netchunk=args.netchunk)

if render_only:

render_path()

Inputs:

  • render_poses: render camera poses for the test set, shape (200, 4, 4)
  • hwf
  • K: camera intrinsic matrix
  • chunk: args.chunk = 1024*32
  • render_kwargs = render_kwargs_test
  • gt_imgs=None
  • savedir=None
  • render_factor=0

Outputs:

  • rgbs: (200, H, W, 3)
  • disps: (200, H, W)

render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
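
Internally, render_path just loops render() over the poses and stacks the results (a trimmed sketch; saving images/video and the render_factor downscaling are omitted, and render_path_sketch is my own name):

import numpy as np

def render_path_sketch(render_poses, hwf, K, chunk, render_kwargs):
    H, W, focal = hwf
    rgbs, disps = [], []
    for c2w in render_poses:
        # render one full image from this camera pose
        rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3, :4], **render_kwargs)
        rgbs.append(rgb.cpu().numpy())
        disps.append(disp.cpu().numpy())
    return np.stack(rgbs, 0), np.stack(disps, 0)  # (N, H, W, 3), (N, H, W)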

render()

rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)

get_rays_np(H, W, K, c2w) (NumPy version, ndarray)

Given the image size and camera parameters, computes the ray from the camera origin through each pixel of the image (direction vector d plus camera origin o).

Inputs:

  • H: image height
  • W: image width
  • K: camera intrinsic matrix

    K = np.array([
        [focal, 0, 0.5*W],
        [0, focal, 0.5*H],
        [0, 0, 1]
    ])

  • c2w: camera extrinsic matrix (camera-to-world), e.g.

    c2w = np.array([
        [-0.9980267286300659,     0.04609514772891998, -0.042636688798666,  -0.17187398672103882],
        [-0.06279052048921585,   -0.7326614260673523,   0.6776907444000244,  2.731858730316162],
        [-3.7252898543727042e-09, 0.6790306568145752,   0.7341099381446838,  2.959291696548462],
        [0.0, 0.0, 0.0, 1.0]])

Outputs: the rays $r(t)=\mathbf{o}+t\mathbf{d}$ from the camera origin through each pixel of the 800x800 image

  • rays_o: ray origins (in world coordinates), (800, 800, 3)
  • rays_d: ray direction vectors (in world coordinates), (800, 800, 3)
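
The computation is short (essentially the reference implementation): unproject the pixel coordinates with the intrinsics, flip y and z into the OpenGL-style camera frame, then rotate into world space with c2w:

import numpy as np

def get_rays_np(H, W, K, c2w):
    # pixel grid: i indexes columns (x), j indexes rows (y)
    i, j = np.meshgrid(np.arange(W, dtype=np.float32),
                       np.arange(H, dtype=np.float32), indexing='xy')
    # per-pixel directions in the camera frame; y and z are negated (OpenGL convention)
    dirs = np.stack([(i - K[0][2]) / K[0][0],
                     -(j - K[1][2]) / K[1][1],
                     -np.ones_like(i)], -1)
    # rotate the directions into the world frame
    rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1)
    # all rays share the camera origin, i.e. the translation column of c2w
    rays_o = np.broadcast_to(c2w[:3, -1], np.shape(rays_d))
    return rays_o, rays_d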

The rays passed into render() during training are batched: rays = batch_rays

render()

render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True, near=0., far=1., use_viewdirs=False, c2w_staticcam=None, **kwargs)

Inputs:

  • H
  • W
  • K
  • chunk=args.chunk: the maximum number of rays processed at once
  • rays=batch_rays: [2, batch_size, 3]
  • verbose=i < 10
  • retraw=True
  • render_kwargs_train

    render_kwargs_train = {
        'network_query_fn' : network_query_fn,
        'perturb' : args.perturb,             # default 1; --perturb: jitter the sample positions during training (stratified sampling)
        'N_importance' : args.N_importance,   # --N_importance: number of additional fine samples per ray (128)
        'network_fine' : model_fine,
        'N_samples' : args.N_samples,         # --N_samples: number of coarse samples per ray (64)
        'network_fn' : model,
        'use_viewdirs' : args.use_viewdirs,   # --use_viewdirs: use the full 5D input instead of 3D
        'white_bkgd' : args.white_bkgd,
        'raw_noise_std' : args.raw_noise_std, # default 0; --raw_noise_std: std of the noise added to the raw density
    }

Outputs: ret_list + [ret_dict]

  • rgb: all_ret['rgb_map']
  • disp: all_ret['disp_map']
  • acc: all_ret['acc_map']
  • extras: [{'raw': raw, '...': ...}]
    e.g. all_ret['raw']: (W x H, N_samples + N_importance, 4)
    e.g. all_ret['rgb_map']: (W x H, 3)
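
Inside render(), the rays are flattened and packed into the (W x H, 8) or (W x H, 11) layout that batchify_rays expects; a trimmed sketch of that packing (following the reference implementation, NDC branch omitted):

# o(3) + d(3) + near(1) + far(1) = 8 columns; + viewdirs(3) = 11 columns
rays_o = torch.reshape(rays_o, [-1, 3]).float()
rays_d = torch.reshape(rays_d, [-1, 3]).float()
near = near * torch.ones_like(rays_d[..., :1])
far = far * torch.ones_like(rays_d[..., :1])
rays = torch.cat([rays_o, rays_d, near, far], -1)
if use_viewdirs:
    viewdirs = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)  # unit view directions
    rays = torch.cat([rays, viewdirs], -1)
all_ret = batchify_rays(rays, chunk, **kwargs)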

Rendering pipeline

get_rays(H, W, K, c2w) (PyTorch version, tensor)

Given the image size and camera parameters, computes the ray from the camera origin through each pixel of the image (direction vector d plus camera origin o).

Inputs:

  • H: image height
  • W: image width
  • K: camera intrinsic matrix

    K = np.array([
        [focal, 0, 0.5*W],
        [0, focal, 0.5*H],
        [0, 0, 1]
    ])

  • c2w: camera extrinsic matrix (camera-to-world), e.g.

    c2w = np.array([
        [-0.9980267286300659,     0.04609514772891998, -0.042636688798666,  -0.17187398672103882],
        [-0.06279052048921585,   -0.7326614260673523,   0.6776907444000244,  2.731858730316162],
        [-3.7252898543727042e-09, 0.6790306568145752,   0.7341099381446838,  2.959291696548462],
        [0.0, 0.0, 0.0, 1.0]])

Outputs: the rays $r(t)=\mathbf{o}+t\mathbf{d}$ from the camera origin through each pixel of the 800x800 image

  • rays_o: ray origins (in world coordinates), (800, 800, 3)
  • rays_d: ray direction vectors (in world coordinates), (800, 800, 3)

ndc_rays(H, W, focal, near, rays_o, rays_d)

Only LLFF-style forward-facing data needs this projection into the NDC (Normalized Device Coordinates) coordinate system.

Inputs:

  • H
  • W
  • focal
  • near
  • rays_o
  • rays_d

Outputs (in NDC coordinates):

  • rays_o
  • rays_d
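
The transform itself (essentially the reference implementation): ray origins are first shifted onto the near plane, then origins and directions are projected:

import torch

def ndc_rays(H, W, focal, near, rays_o, rays_d):
    # shift ray origins to the near plane (z = -near)
    t = -(near + rays_o[..., 2]) / rays_d[..., 2]
    rays_o = rays_o + t[..., None] * rays_d
    # perspective projection into NDC
    o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]
    o1 = -1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]
    o2 = 1. + 2. * near / rays_o[..., 2]
    d0 = -1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
    d1 = -1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
    d2 = -2. * near / rays_o[..., 2]
    return torch.stack([o0, o1, o2], -1), torch.stack([d0, d1, d2], -1)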

batchify_rays(rays_flat, chunk=1024*32, **kwargs)

Inputs:

  • rays_flat: (W x H, 8) or (W x H, 11)
  • chunk=1024*32
  • **kwargs

Outputs:

  • all_ret: the per-chunk ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map} dicts concatenated together

ret covers one chunk (1024 * 32 rays) at a time -> all_ret covers all 800x800 rays
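
In code (essentially the reference implementation; render_rays is defined below):

def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
    """Render rays in smaller minibatches to avoid running out of memory."""
    all_ret = {}
    for i in range(0, rays_flat.shape[0], chunk):
        ret = render_rays(rays_flat[i:i+chunk], **kwargs)
        for k in ret:
            all_ret.setdefault(k, []).append(ret[k])
    # concatenate each field's per-chunk results back into full-size tensors
    return {k: torch.cat(all_ret[k], 0) for k in all_ret}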

render_rays()

Inputs:

  • ray_batch
  • the remaining parameters are forwarded from the render() call:
    • network_fn
    • network_query_fn
    • N_samples
    • retraw=False
    • lindisp=False
    • perturb=0.
    • N_importance=0 (number of additional fine samples per ray)
    • network_fine=None
    • white_bkgd=False
    • raw_noise_std=0.
    • verbose=False
    • pytest=False

Outputs:
ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map}

if retraw:
    ret['raw'] = raw

if N_importance > 0:
    ret['rgb0'] = rgb_map_0
    ret['disp0'] = disp_map_0
    ret['acc0'] = acc_map_0
    ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False)  # [N_rays]
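
The coarse sample depths at the start of render_rays are drawn like this (a trimmed sketch for the lindisp=False case): N_samples depths are spaced evenly in [near, far], and with perturb > 0 each one is jittered uniformly within its stratum.

# evenly spaced depths between near and far
t_vals = torch.linspace(0., 1., steps=N_samples)
z_vals = near * (1. - t_vals) + far * t_vals
z_vals = z_vals.expand([N_rays, N_samples])

if perturb > 0.:
    # stratified sampling: jitter each depth within its interval
    mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
    upper = torch.cat([mids, z_vals[..., -1:]], -1)
    lower = torch.cat([z_vals[..., :1], mids], -1)
    z_vals = lower + (upper - lower) * torch.rand(z_vals.shape)

# 3D sample points: r(t) = o + t*d, shape (N_rays, N_samples, 3)
pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]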

run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64)

Inputs:

  • inputs
  • viewdirs
  • fn: network_fn = model, one pass through the MLP
  • embed_fn
  • embeddirs_fn
  • netchunk=1024*64

Outputs:

  • outputs: (chunk, N_samples, 4) (the RGB and σ of every sample point on every ray)

run_network is called here, wrapped by network_query_fn:

network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
                                                                     embed_fn=embed_fn,
                                                                     embeddirs_fn=embeddirs_fn,
                                                                     netchunk=args.netchunk)
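
What run_network does (essentially the reference implementation): flatten the sample points, positionally encode them (and the view directions), push them through the network in netchunk-sized pieces via batchify (next section), and reshape back:

def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
    # flatten (N_rays, N_samples, 3) -> (N_rays*N_samples, 3) and encode: 3 -> 63
    inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
    embedded = embed_fn(inputs_flat)
    if viewdirs is not None:
        # every sample on a ray shares that ray's view direction; encode: 3 -> 27
        input_dirs = viewdirs[:, None].expand(inputs.shape)
        input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
        embedded = torch.cat([embedded, embeddirs_fn(input_dirs_flat)], -1)
    # run the MLP netchunk points at a time, then restore (N_rays, N_samples, 4)
    outputs_flat = batchify(fn, netchunk)(embedded)
    return torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])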

batchify(fn, chunk)

Inputs:

  • fn: model
  • chunk: netchunk = 1024*64

If chunk is None, fn is returned unchanged; otherwise:

def ret(inputs):
    # feed fn at most `chunk` points at a time, then concatenate the results
    return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)

raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False)

The volume rendering happens here.
chunk = N_rays, the number of rays.

Inputs:

  • raw: (chunk, N_samples, 4)
  • z_vals: (chunk, N_samples)
  • rays_d: (chunk, 3)
  • raw_noise_std=0
  • white_bkgd=False
  • pytest=False

Outputs:

  • rgb_map: [chunk, 3]
  • disp_map: [chunk]
  • acc_map: [chunk]
  • weights: [chunk, N_samples]
  • depth_map: [chunk]
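
Concretely, raw2outputs evaluates the discrete volume-rendering quadrature from the NeRF paper. With sample spacings $\delta_i = z_{i+1} - z_i$ (scaled by $\lVert \mathbf{d} \rVert$):

$\alpha_i = 1 - \exp(-\sigma_i \delta_i), \quad T_i = \prod_{j<i}(1 - \alpha_j), \quad w_i = T_i \, \alpha_i$

$\text{rgb\_map} = \sum_i w_i \, \mathbf{c}_i, \quad \text{acc\_map} = \sum_i w_i, \quad \text{depth\_map} = \sum_i w_i z_i$

disp_map is the inverse of the weighted mean depth, 1 / (depth_map / acc_map); with white_bkgd=True, the leftover transmittance (1 - acc_map) is composited onto a white background. (The code adds small epsilons to the products for numerical stability.)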

sample_pdf(bins, weights, N_samples, det=False, pytest=False)

Inputs:

  • bins: (chunk, 63)
  • weights: (chunk, 62)
  • N_samples = N_importance
  • det = (perturb == 0.)
  • pytest=False


Outputs:

  • samples: (chunk, N_importance)
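
sample_pdf performs inverse transform sampling: normalize the weights into a PDF over the coarse bins, build the CDF, draw uniform samples u, and invert the CDF by linear interpolation. A condensed sketch (assuming det=False; sample_pdf_sketch is my own name, and the reference adds extra edge-case handling):

def sample_pdf_sketch(bins, weights, N_samples):
    # weights -> piecewise-constant PDF over the bins, then CDF
    weights = weights + 1e-5                      # avoid division by zero
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # (chunk, 63)

    # uniform samples, then invert the CDF
    u = torch.rand(list(cdf.shape[:-1]) + [N_samples]).contiguous()
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.clamp(inds - 1, min=0)
    above = torch.clamp(inds, max=cdf.shape[-1] - 1)

    # linear interpolation inside the selected CDF segment
    cdf_g0 = torch.gather(cdf, -1, below)
    cdf_g1 = torch.gather(cdf, -1, above)
    bins_g0 = torch.gather(bins, -1, below)
    bins_g1 = torch.gather(bins, -1, above)
    denom = torch.where(cdf_g1 - cdf_g0 < 1e-5, torch.ones_like(cdf_g0), cdf_g1 - cdf_g0)
    t = (u - cdf_g0) / denom
    return bins_g0 + t * (bins_g1 - bins_g0)      # (chunk, N_importance)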

img2mse() and mse2psnr()

img2mse = lambda x, y : torch.mean((x - y) ** 2)
mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))

The math:

$\text{MSE}(x, y) = \frac{1}{N}\sum_{i=1}^{N}(x_i - y_i)^2$

$\text{PSNR} = -10 \cdot \frac{\ln(\text{MSE})}{\ln(10)} = -10\log_{10}(\text{MSE})$, assuming pixel values normalized to $[0, 1]$ (so the peak value is 1)
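
In the training loop these are used as (matching the reference repo; target_s is the batch of ground-truth pixel colors):

img_loss = img2mse(rgb, target_s)  # L2 photometric loss on the rendered pixels
loss = img_loss
psnr = mse2psnr(img_loss)          # PSNR logged as a training metric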
