# [web page notice, not part of the source] Code pull complete; the page will refresh automatically.
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from .base import BaseDataset
from .builder import DATASETS
from .load_data import load_data, load_rays
@DATASETS.register_module()
class SceneBaseDataset(BaseDataset):
    """Scene dataset that serves training, validation or test samples.

    What an item looks like depends on ``self.mode``:

    * ``'train'``: a dict built by ``_fetch_train_data`` and passed through
      ``self.pipeline``.
    * ``'val'``: a single dict holding all test poses/images plus the
      render poses (the dataset length is 1).
    * ``'test'``: a dict holding one test pose/image per index.

    Args:
        cfg: Dataset config. It is forwarded to ``load_data``; the optional
            keys ``mode`` and ``is_batching`` are read here, and
            ``N_rand_per_sampler`` is read when batching in train mode.
        pipeline: Pipeline description forwarded to ``_init_pipeline``
            (not defined in this class; presumably inherited from
            ``BaseDataset`` and responsible for setting ``self.pipeline``
            -- TODO confirm).
    """

    def __init__(self, cfg, pipeline):
        super().__init__()
        self.iter_n = 0
        self.cfg = cfg
        # Only assigned when present in the config. If a key is missing the
        # attribute must come from somewhere else (e.g. the base class),
        # otherwise later accesses raise AttributeError.
        if 'mode' in cfg:
            self.mode = cfg.mode
        if 'is_batching' in cfg:
            self.is_batching = cfg.is_batching
        self._init_load()
        self._init_pipeline(pipeline)

    def _init_load(self):
        """Load the dataset once, at construction time."""
        (self.images, self.poses, self.render_poses, self.hwf, self.K,
         self.near, self.far, self.i_train, self.i_val,
         self.i_test) = load_data(self.cfg)
        if self.is_batching and self.mode == 'train':
            # For a batching dataset the rays must be computed up front.
            self.N_rand = self.cfg.N_rand_per_sampler
            self.rays_rgb = load_rays(self.hwf[0], self.hwf[1], self.K,
                                      self.poses, self.images, self.i_train)

    def get_info(self):
        """Return the scene-level values loaded by ``_init_load`` as a dict.

        ``H``, ``W`` and ``focal`` are the first three entries of ``hwf``.
        """
        res = {
            'H': self.hwf[0],
            'W': self.hwf[1],
            'focal': self.hwf[2],
            'K': self.K,
            'render_poses': self.render_poses,
            'hwf': self.hwf,
            'near': self.near,
            'far': self.far,
        }
        return res

    def _fetch_train_data(self, idx):
        """Build the raw (pre-pipeline) training sample for ``idx``."""
        if self.is_batching:
            # Batching dataset: rays are randomly selected from all images.
            data = {'rays_rgb': self.rays_rgb, 'idx': idx}
        else:
            # Non-batching dataset: rays are selected from a single image.
            data = {
                'poses': self.poses,
                'images': self.images,
                'i_data': self.i_train,
                'idx': idx,
            }
        data['iter_n'] = self.iter_n
        return data

    def _fetch_val_data(self, idx):
        """Return all validation data at once; ``idx`` is not used."""
        data = {
            'spiral_poses': self.render_poses,
            'poses': self.poses[self.i_test],
            'images': self.images[self.i_test],
        }
        return data

    def _fetch_test_data(self, idx):
        """Return one test pose/image pair (unlike val, one image per call)."""
        data = {
            'pose': self.poses[self.i_test][idx],
            'image': self.images[self.i_test][idx],
            'idx': idx,
        }
        return data

    def __getitem__(self, idx):
        """Return the sample for ``idx`` according to ``self.mode``.

        Raises:
            ValueError: If ``self.mode`` is not 'train', 'val' or 'test'.
        """
        if self.mode == 'train':
            data = self._fetch_train_data(idx)
            data = self.pipeline(data)
            return data
        if self.mode == 'val':
            # The pipeline is applied in network.val_step() during val,
            # so the raw data is returned here.
            return self._fetch_val_data(idx)
        if self.mode == 'test':
            # Same as val: the pipeline is applied in network.val_step().
            return self._fetch_test_data(idx)
        # Fail loudly instead of silently handing back None.
        raise ValueError('Unsupported dataset mode: {!r}'.format(self.mode))

    def __len__(self):
        """Return the number of items available in the current mode.

        Raises:
            ValueError: If ``self.mode`` is not 'train', 'val' or 'test'.
        """
        if self.mode == 'train':
            if self.is_batching:
                # One item is one group of N_rand rays.
                return self.rays_rgb.shape[0] // self.N_rand
            return self.i_train.shape[0]
        if self.mode == 'val':
            # Val fetches everything in a single item.
            return 1
        if self.mode == 'test':
            return self.i_test.shape[0]
        # Returning None here would only surface as an opaque TypeError.
        raise ValueError('Unsupported dataset mode: {!r}'.format(self.mode))
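# [web page notice, not part of the source] This location may contain content unsuitable for display, so the page does not show it. You can review and modify it using the relevant editing features.
# If you confirm the content involves no inappropriate language, pure advertising or traffic diversion, violence, vulgar or pornographic material, infringement, piracy, falsehood, or worthless content, and does not violate applicable national laws and regulations, you can click submit to file an appeal and we will handle it as soon as possible.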