1 Star 2 Fork 0

OpenXRLab / xrnerf

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
config.md 7.80 KB
一键复制 编辑 原始数据 按行查看 历史
ZZL 提交于 2023-01-30 17:30 . fix formatting issues (#27)

教程 1: 如何编写配置文件

XRNeRF 使用 python 文件作为配置文件。其配置文件系统的设计将模块化与继承整合进来,方便用户进行各种实验。 XRNeRF 提供的所有配置文件都放置在 $PROJECT/configs 文件夹下。

配置文件组成部分

配置文件的内容在逻辑上可以分为3个部分:

  • 训练
  • 模型
  • 数据

下面的内容将会逐部分介绍配置文件

  • 训练 训练配置部分包含了控制训练过程的各类参数,包括optimizer, hooks, runner等等

    import os
    from datetime import datetime
    
    method = 'nerf' # nerf方法
    
    # optimizer 参数
    optimizer = dict(type='Adam', lr=5e-4, betas=(0.9, 0.999))
    optimizer_config = dict(grad_clip=None)
    
    max_iters = 20000 # 训练多少个iter
    lr_config = dict(policy='step', step=500 * 1000, gamma=0.1, by_epoch=False) # 学习率和衰减
    checkpoint_config = dict(interval=5000, by_epoch=False) # 保存checkpoint的间隔
    log_level = 'INFO'
    log_config = dict(interval=5000,
                    by_epoch=False,
                    hooks=[dict(type='TextLoggerHook')])
    workflow = [('train', 5000), ('val', 1)] # 循环: 每训练 5000 iters, 验证 1 iter
    
    # hooks
    # 'params' 是数值型参数, 'variables' 是代码运行上下面出现的变量
    train_hooks = [
        dict(type='SetValPipelineHook',
            params=dict(),
            variables=dict(valset='valset')),
        dict(type='ValidateHook',
            params=dict(save_folder='visualizations/validation')),
        dict(type='SaveSpiralHook',
            params=dict(save_folder='visualizations/spiral')),
        dict(type='PassIterHook', params=dict()),  # 将当前iter数告诉dataset
    ]
    test_hooks = [
        dict(type='SetValPipelineHook',
            params=dict(),
            variables=dict(valset='testset')),
        dict(type='TestHook', params=dict()),
    ]
    
    # runner
    train_runner = dict(type='NerfTrainRunner')
    test_runner = dict(type='NerfTestRunner')
    
    # runtime settings
    num_gpus = 1
    distributed = (num_gpus > 1)  # 是否使用 ddp
    work_dir = './work_dirs/nerfsv3/nerf_#DATANAME#_base01/' # 保存运行时产生文件的位置
    timestamp = datetime.now().strftime('%d-%b-%H-%M') # 保证每次的workspace都不同
    
    # some shared params by model and data, to avoid define twice
    dataset_type = 'blender'
    no_batching = True  # 每次选择1张图片来抽取射线
    no_ndc = True
    
    white_bkgd = True  # 渲染时背景设定为全白
    is_perturb = True  # set to 0. for no jitter, 1. for jitter
    use_viewdirs = True  # use full 5D input instead of 3D
    N_rand_per_sampler = 1024 * 4  # 在取多少根射线 在 get_item() 函数中使用
    lindisp = False  # sampling linearly in disparity rather than depth
    N_samples = 64  # 在coarse模型中输入多少根射线
    
    # resume_from = os.path.join(work_dir, 'latest.pth')
    # load_from = os.path.join(work_dir, 'latest.pth')
    
  • 模型 模型部分的配置信息,定义了网络模型结构,一个network通常由embedder, mlp 和 render组成。

    model = dict(
        type='NerfNetwork', # network 类名字
        cfg=dict(
            phase='train',  # 'train' or 'test'
            N_importance=128,  # number of additional fine samples per ray
            is_perturb=is_perturb, # see above
            chunk=1024 * 32,  # mainly work for val, to avoid oom
            bs_data='rays_o',  # the data's shape indicates the real batch-size, this's also the num of rays
        ),
        mlp=dict(  # coarse mlp model
            type='NerfMLP', # mlp class name
            skips=[4],
            netdepth=8,  # layers in network
            netwidth=256,  # channels per layer
            netchunk=1024 * 32,  # to avoid oom
            output_ch=5,  # 5 if cfg.N_importance>0 else 4
            use_viewdirs=use_viewdirs,
            embedder=dict(
                type='BaseEmbedder', # embedder class name
                i_embed=0,  # set 0 for default positional encoding, -1 for none
                multires=10,  # log2 of max freq for positional encoding (3D location)
                multires_dirs=4,  # this is 'multires_views' in origin codes, log2 of max freq for positional encoding (2D direction)
            ),
        ),
        mlp_fine=dict(  # fine model
            type='NerfMLP',
            skips=[4],
            netdepth=8,
            netwidth=256,
            netchunk=1024 * 32,
            output_ch=5,
            use_viewdirs=use_viewdirs,
            embedder=dict(
                type='BaseEmbedder',
                i_embed=0,
                multires=10,
                multires_dirs=4,
            ),
        ),
        render=dict(
            type='NerfRender', # render cloass name
            white_bkgd=white_bkgd,  # see above
            raw_noise_std=0,  # std dev of noise added to regularize sigma_a output, 1e0 recommended
        ),
    )
  • 数据 数据部分的配置信息,定义了数据集类型,数据的处理流程,batchsize等等信息。

    basedata_cfg = dict(
        dataset_type=dataset_type,
        datadir='data/nerf_synthetic/#DATANAME#',
        half_res=True,  # load blender synthetic data at 400x400 instead of 800x800
        testskip=
        8,  # will load 1/N images from test/val sets, useful for large datasets like deepvoxels
        white_bkgd=white_bkgd,
        is_batching=False,  # True for blender, False for llff
        mode='train',
    )
    
    traindata_cfg = basedata_cfg.copy()
    valdata_cfg = basedata_cfg.copy()
    testdata_cfg = basedata_cfg.copy()
    
    traindata_cfg.update(dict())
    valdata_cfg.update(dict(mode='val'))
    testdata_cfg.update(dict(mode='test', testskip=0))
    
    train_pipeline = [
        dict(type='Sample'),
        dict(type='DeleteUseless', keys=['images', 'poses', 'i_data', 'idx']),
        dict(type='ToTensor', keys=['pose', 'target_s']),
        dict(type='GetRays'),
        dict(type='SelectRays',
            sel_n=N_rand_per_sampler,
            precrop_iters=500,
            precrop_frac=0.5),  # in the first 500 iter, select rays inside center of image
        dict(type='GetViewdirs', enable=use_viewdirs),
        dict(type='ToNDC', enable=(not no_ndc)),
        dict(type='GetBounds'),
        dict(type='GetZvals', lindisp=lindisp,
            N_samples=N_samples),  # N_samples: number of coarse samples per ray
        dict(type='PerturbZvals', enable=is_perturb),
        dict(type='GetPts'),
        dict(type='DeleteUseless', keys=['pose', 'iter_n']),
    ]
    
    test_pipeline = [
        dict(type='ToTensor', keys=['pose']),
        dict(type='GetRays'),
        dict(type='FlattenRays'),
        dict(type='GetViewdirs', enable=use_viewdirs),
        dict(type='ToNDC', enable=(not no_ndc)),
        dict(type='GetBounds'),
        dict(type='GetZvals', lindisp=lindisp, N_samples=N_samples),
        dict(type='PerturbZvals', enable=False),  # do not perturb when test
        dict(type='GetPts'),
        dict(type='DeleteUseless', keys=['pose']),
    ]
    data = dict(
        train_loader=dict(batch_size=1, num_workers=4),
        train=dict(
            type='SceneBaseDataset',
            cfg=traindata_cfg,
            pipeline=train_pipeline,
        ),
        val_loader=dict(batch_size=1, num_workers=0),
        val=dict(
            type='SceneBaseDataset',
            cfg=valdata_cfg,
            pipeline=test_pipeline,
        ),
        test_loader=dict(batch_size=1, num_workers=0),
        test=dict(
            type='SceneBaseDataset',
            cfg=testdata_cfg,
            pipeline=test_pipeline,  # same pipeline as validation
        ),
    )
1
https://gitee.com/OpenXRLab/xrnerf.git
git@gitee.com:OpenXRLab/xrnerf.git
OpenXRLab
xrnerf
xrnerf
main

搜索帮助

53164aa7 5694891 3bd8fe86 5694891