从零实现一个3D目标检测算法(2):点云数据预处理
在上一篇文章《從零實現一個3D目標檢測算法(1):3D目標檢測概述》對3D目標檢測研究現狀和PointPillars模型進行了介紹,在本文中我們開始寫代碼一步步實現PointPillars,這里我們先實現如何對點云數據進行預處理。
在圖像目標檢測中,一般不需要對圖像進行預處理操作,直接輸入原始圖像即可得到最終的檢測結果。
但是在點云3D目標檢測中,往往需要對點云進行一定的預處理,本文將介紹在PointPillars模型中如何對點云進行預處理。這里的點云數據預處理操作同樣也適用其它的基于Voxels的3D檢測模型中。
文章目錄
- 1. 模型配置文件config.py
- 1.1 將模型參數保存在日志文件
- 1.2 加載模型配置文件
- 1.3 解析終端命令修改模型配置參數
- 2. 點云數據預處理
- 2.1 DatasetTemplate類
- 2.2 KittiDataset類
- 2.3 KITTI數據加載器
1. 模型配置文件config.py
在這里我們將首先編寫在整個工程中最重要的config.py文件,該文件主要包括三個函數。
作用分別是:加載模型配置文件pointpillar.yaml、將模型參數保存在日志文件中、以及解析終端命令修改模型配置參數。
關于上述三個函數,只需要會使用即可。首先導入需要的Python庫:
from easydict import EasyDict from pathlib import Path import yaml1.1 將模型參數保存在日志文件
這一部分是將整個網絡模型的全部參數保存到日志文件中,在模型訓練過程中每一個模塊的代碼往往會修改很多次。有了日志文件,我們就能很方便地查看每次所修改的地方,如果有疑問的話,可以借助日志文件快速定位問題,代碼如下:
def log_config_to_file(cfg, pre='cfg', logger=None):for key, val in cfg.items():if isinstance(cfg[key], EasyDict):logger.info('\n%s.%s = edict()' % (pre, key))log_config_to_file(cfg[key], pre=pre+ '.' + key, logger=logger)continuelogger.info('%s.%s: %s' % (pre, key, val))1.2 加載模型配置文件
下面一個函數是從配置文件pointpillar.yaml中加載網絡模型參數。
在Python中我們使用字典這種數據類型來存儲網絡的各種參數,只需要命名好參數名稱即可,如測試集,訓練集名稱,網絡各子模塊名稱,損失函數名稱等,在修改時也只需要修改對應參數的變量值即可,這是一個很方便的調參方式。代碼如下:
def cfg_from_yaml_file(cfg_file, config):with open(cfg_file, 'r') as f:try:new_config = yaml.load(f, Loader=yaml.FullLoader)except:new_config = yaml.load(f)config.update(EasyDict(new_config))return config1.3 解析終端命令修改模型配置參數
除了對模型配置文件.yaml進行修改外,也可以在執行時通過終端來修改模型的參數。
這時就要求程序能夠獲取終端信息,包括參數名稱以及參數值,通常是成對出現,代碼如下:
def cfg_from_list(cfg_list, config):"""Set config keys via list (e.g., from command line)."""from ast import literal_evalassert len(cfg_list) % 2 == 0for k, v in zip(cfg_list[0::2], cfg_list[1::2]):key_list = k.split('.')d = configfor subkey in key_list[:-1]:assert subkey in d, 'NotFoundKey: %s' % subkeyd = d[subkey]subkey = key_list[-1]assert subkey in d, 'NotFoundKey: %s' % subkeytry: value = literal_eval(v)except:value = v if type(value) != type(d[subkey]) and isinstance(d[subkey], EasyDict):key_val_list = value.split('.')for src in key_val_list:cur_key, cur_val = src.split(':')val_type = type(d[subkey][cur_key])cur_val = val_type(cur_val)d[subkey][cur_key] = cur_valelif type(value) != type(d[subkey]) and isinstance(d[subkey], list):val_list = value.split('.')for k, x in enumerate(val_list):val_list[k] = type(d[subkey][0])(x)d[subkey] = val_listelse:assert type(value) == type(d[subkey]), \'type {} dose not match original type {}'.format(type(value), type(d[subkey]))d[subkey] = value下面我們來定義模型參數配置變量cfg,其本身是一個字典,現在我們先定義它的根路徑。
至此配置文件代碼編寫完畢,不妨可以調用cfg_from_yaml_file函數加載yaml文件看看模型參數加載是否正確。
cfg = EasyDict() cfg.ROOT_DIR = (Path(__file__).resolve().parent / '../').resolve() cfg.LOCAL_RANK = 0if __name__=='__main__':pass2. 點云數據預處理
現在我們對KITTI數據集進行預處理,最終將其加載到PyTorch的DataLoader中。
2.1 DatasetTemplate類
首先是dataset.py文件,我們使用Python中的Class來對點云數據進行預處理,數據的預處理操作都定義為Class的成員函數。
先首先定義一個DatasetTemplate類,當做點云數據的一個基本類,后面處理其它點云數據集時可以在此基礎上進行不同的操作,導入必要的Python庫:
import numpy as np from collections import defaultdict import torch.utils.data as torch_data import sys sys.path.append('../') sys.path.append('../../') from utils import common_utils from config import cfgclass DatasetTemplate(torch_data.Dataset):def __init__(self):super().__init__()在DatasetTemplate中我們定義兩個成員函數,一個是數據準備函數prepare_data。
輸入的是點云數據幀編號(idx)(idx)(idx)和原始點云數據(N,3+C1)(N,3+C1)(N,3+C1),以字典形式傳輸,輸出為:
- Voxels
- Voxels坐標
- 每個Voxels中點的個數
- Voxels中心坐標(全局坐標)
- 原始點云數據
輸出同樣以字典形式輸出。
def prepare_data(self, input_dict):""":param input_dict:sample_idx: stringpoints: (N, 3 + C1):return:voxels: (N, max_points_of_each_voxel, 3 + C2), floatnum_points: (N), intcoordinates: (N, 3), [idx_z, idx_y, idx_x]voxel_centers: (N, 3)points: (M, 3 + C)"""sample_idx = input_dict['sample_idx']points = input_dict['points']points = points[:, :cfg.DATA_CONFIG.NUM_POINT_FEATURES['use']] # voxels, coordinates, num_pointsvoxels, coordinates, num_points = self.voxel_generator.generate(points, \max_voxels=cfg.DATA_CONFIG[self.mode].MAX_NUMBER_OF_VOXELS) # voxel_centersvoxel_centers = (coordinates[:, ::-1] + 0.5) * self.voxel_generator.voxel_size \+ self.voxel_generator.point_cloud_range[0:3]print('voxel_centers.shape is: ', voxel_centers.shape) # (11719, 3)if cfg.DATA_CONFIG.MASK_POINTS_BY_RANGE:points = common_utils.mask_points_by_range(points, cfg.DATA_CONFIG.POINT_CLOUD_RANGE)example = {}example.update({'voxels': voxels,'num_points': num_points,'coordinates': coordinates,'voxel_centers': voxel_centers,'points': points})return example另一個函數是collate_batch,作用是在加載數據集時如何選取數據。
@staticmethod def collate_batch(batch_list, _unused=False):example_merged = defaultdict(list)for example in batch_list:for k, v in example.items():example_merged[k].append(v)ret = {}for key, elems in example_merged.items():if key in ['voxels', 'num_points', 'voxel_centers', 'seg_labels', 'part_labels', 'bbox_reg_labels']:ret[key] = np.concatenate(elems, axis=0)elif key in ['coordinates', 'points']:coors = []for i, coor in enumerate(elems):coor_pad = np.pad(coor, ((0, 0), (1, 0)), mode='constant', constant_values=i)coors.append(coor_pad)ret[key] = np.concatenate(coors, axis=0)elif key in ['gt_boxes']:max_gt = 0batch_size = elems.__len__()for k in range(batch_size):max_gt = max(max_gt, elems[k].__len__())batch_gt_boxes3d = np.zeros((batch_size, max_gt, elems[0].shape[-1]), dtype=np.float32)for k in range(batch_size):batch_gt_boxes3d[k, :elems[k].__len__(), :] = elems[k]ret[key] = batch_gt_boxes3delse:ret[key] = np.stack(elems, axis=0)ret['batch_size'] = batch_list.__len__()return ret2.2 KittiDataset類
現在我們編寫kitti_dataset.py,主要目的是創造KittiDataset類,首先是導入所需庫:
import os import sys import pickle import copy import numpy as np from pathlib import Path import torch import sys sys.path.append('../') sys.path.append('../../') from config import cfg from spconv.utils import VoxelGenerator from ..dataset import DatasetTemplate在這里我們首先定義一個BaseKittiDataset類,這里初始化只有一個參數,就是點云數據的存儲路徑root_path。
class BaseKittiDataset(DatasetTemplate):def __init__(self, root_path):super().__init__()self.root_path = root_path現在我們編寫獲取點云數據的get_lidar函數,KITTI中點云數據是以二進制格式保存的,每個點有4個信息:(x,y,z,r)(x,y,z,r)(x,y,z,r),數據類型為float32,代碼如下:
def get_lidar(self, idx):lidar_file = os.path.join(self.root_path, 'velodyne', '%06d.bin' % idx)assert os.path.exists(lidar_file)return np.fromfile(lidar_file, dtype=np.float32).reshape([-1, 4])此外我們也可以編寫函數get_infos來獲取點云信息,具體為:
def get_infos(self, idx):import concurrent.futures as futuresinfo = {}pc_info = {'num_features':4, 'lidar_idx': idx}info['point_cloud'] = pc_inforeturn info這里有一個生成最終預測結果的函數,因為模型計算時使用的是GPU,而要保存時需要轉化為CPU可訪問的數據。
預測信息有box尺寸box3d_lidar,分值scores,目標類型標簽label_preds,以及點云編號sample_idx。
@staticmethod def generate_prediction_dict(input_dict, index, record_dict):# finally generate predictions.sample_idx = input_dict['sample_idx'][index] if 'sample_idx' in input_dict else -1boxes3d_lidar_preds = record_dict['boxes'].cpu().numpy()if boxes3d_lidar_preds.shape[0] == 0:return {'sample_idx': sample_idx}predictions_dict ={'box3d_lidar': boxes3d_lidar_preds,'scores': record_dict['scores'].cpu.numpy(),'label_preds': record_dict['labels'].cpu().numpy(),'sample_idx': sample_idx}return predictions_dict現在我們就可以創建KittiDataset類了,同樣初始化時需要設置數據路徑,這里我們需要將模式設置為TEST:
class KittiDataset(BaseKittiDataset):def __init__(self, root_path, logger=None):super().__init__(root_path=root_path)self.logger = loggerself.mode = 'TEST'self.kitti_infos = []self.include_kitti_data(self.mode, logger)self.dataset_init(logger)在初始化時,有一個dataset_init函數,這個函數是用來生成voxel_generator的,使用的庫為Spconv,在上面的prepare_data函數中會使用這個voxel_generator,代碼如下:
def dataset_init(self, logger):voxel_generator_cfg = cfg.DATA_CONFIG.VOXEL_GENERATORself.voxel_generator = VoxelGenerator(voxel_size=voxel_generator_cfg.VOXEL_SIZE,point_cloud_range=cfg.DATA_CONFIG.POINT_CLOUD_RANGE,max_num_points=voxel_generator_cfg.MAX_POINTS_PER_VOXEL)include_kitti_data函數是用來加載pkl文件的,我們會將待處理的點云信息存儲在pkl文件中,這樣測試模型時只需使用這一個文件就可以訪問全部點云數據了:
def include_kitti_data(self, mode, logger):if cfg.LOCAL_RANK == 0 and logger is not None:logger.info('Loading KITTI dataset')kitti_infos = []for info_path in cfg.DATA_CONFIG[mode].INFO_PATH: info_path = cfg.ROOT_DIR / info_pathwith open(info_path, 'rb') as f:infos = pickle.load(f)kitti_infos.append(infos)self.kitti_infos.extend(kitti_infos)if cfg.LOCAL_RANK == 0 and logger is not None:logger.info('Total samples for KITTI dataset: %d' % (len(kitti_infos)))此外我們也可以對點云進行篩選,下面的代碼為選取xxx在[0,70.4][0, 70.4][0,70.4],yyy在[?40,40][-40, 40][?40,40],zzz在[?3,1][-3, 1][?3,1]范圍的點,這個一般要根據具體應用場景來設置。
@staticmethod def get_valid_flag(pts_lidar):'''Valid points should be in the PC_AREA_SCOPE'''val_flag_x = np.logical_and(pts_lidar[:, 0]>=0, pts_lidar[:, 0]<=70.4)val_flag_y = np.logical_and(pts_lidar[:, 1]>=-40, pts_lidar[:, 1]<=40)val_flag_z = np.logical_and(pts_lidar[:, 2]>=-3, pts_lidar[:, 2]<=1)val_flag_merge = np.logical_and(val_flag_x, val_flag_y, val_flag_z)pts_valid_flag = val_flag_mergereturn pts_valid_flag然后,就是編寫__getitem__函數:
def __len__(self):return len(self.kitti_infos)def __getitem__(self, index):info = copy.deepcopy(self.kitti_infos[index])sample_idx = info['point_cloud']['lidar_idx']points = self.get_lidar(sample_idx)pts_valid_flag = self.get_valid_flag(points[:, 0:3])points = points[pts_valid_flag]input_dict = {'points': points, 'sample_idx': sample_idx}example = self.prepare_data(input_dict=input_dict)example['sample_idx'] = sample_idxreturn example下面是create_kitti_infos函數:
def create_kitti_infos(data_path, save_path):dataset = BaseKittiDataset(root_path=data_path)val_filename = save_path / ('kitti_infos_val.pkl')print('val_filename is: ', val_filename)print('---------------Start to generate data infos---------------')kitti_infos_val = dataset.get_infos(idx)print(kitti_infos_val)with open(val_filename, 'wb') as f:pickle.dump(kitti_infos_val, f)print('Kitti info val file is saved to %s' % val_filename)最后編寫main函數,函數主要作用是獲取終端信息,生成kitti_infos。
if __name__=='__main__':if sys.argv.__len__() > 1 and sys.argv[1] == 'create_kitti_infos':create_kitti_infos(data_path=cfg.ROOT_DIR / 'data',save_path=cfg.ROOT_DIR / 'data')生成后的kitti_infos如下:
{'point_cloud': {'num_features': 4, 'lidar_idx': '000010'}}2.3 KITTI數據加載器
現在編寫__init__.py,這里的作用是通過DataLoader加載點云數據,這在PyTorch是十分常見的,代碼如下:
import os from pathlib import Path import torch from torch.utils.data import DataLoader from .kitti.kitti_dataset import KittiDataset, BaseKittiDataset from config import cfg__all__ = {'BaseKittiDataset': BaseKittiDataset,'KittiDataset': KittiDataset}def build_dataloader(data_dir, batch_size, logger=None):data_dir = Path(data_dir) if os.path.isabs(data_dir) else cfg.ROOT_DIR / data_dirdataset = __all__[cfg.DATA_CONFIG.DATASET](root_path=data_dir, logger=logger)dataloader = DataLoader(dataset, batch_size=batch_size, pin_memory=True, shuffle=False, collate_fn=dataset.collate_batch, drop_last=False)return dataset, dataloader至此,點云數據預處理部分我們就已經完成了,預處理后的點云數據將變成如下形式:
- 原始points
- voxels及其坐標
- voxels中心位置
- 點的數量
- 點云幀編號
- batch_size
下一篇文章中我們將開始實現PointPillars的網絡部分。
input_dict`:{`voxels`, `num_points`, `coordinates`, `voxel_centers` , `points`, `sample_idx`, `batch_size`}總結
以上是生活随笔為你收集整理的从零实现一个3D目标检测算法(2):点云数据预处理的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 2017兴业信用卡优惠 6月份优惠活动大
- 下一篇: 复杂背景下的自动驾驶目标检测数据集