mmdet阅读笔记
mmdet
后續陸續增加源碼注釋
-- mmdetection.configs
注意: _base_里面的文件都是基礎的配置,后面的配置文件調用之后可以修改,以后面的為準
configs/base/dataset: 基礎數據的配置文件
configs/base/models: 基礎模型的配置文件
configs/base/schedules: 基礎超參數的配置文件
configs/base/default_runtime.py: 基礎實時配置文件,包括:模型保存間隔,dist后端配置....etc
configs/others: 上層配置文件,調用base里面的配置,然后針對不同模型不同情況重新封裝,實際調用以這個配置參數為準,基礎只是通用配置。
-- mmdetection.demo
/demo/all: 主要是前向計算測試文件
-- mmdetection.mmdet
/mmdet/apis: 訓練和前向計算實例化
/mmdet/core: anchor和bbox等操作具體實現,并被包裹到registry
/mmdet/datasets: 數據讀取處理函數
/datasets/pipelines: 數據增強具體實現和Compose
/datasets/samplers:
-- distributed_sampler.py: 重寫了distributed_sampler類,和torch原版一點沒變,僅僅改了名字。
-- group_sampler.py:
class GroupSampler(Sampler):
# samples_per_gpu: 使用的GPU數量
def __init__(self, dataset, samples_per_gpu=3):
assert hasattr(dataset, 'flag') # 數據中的變量,用來分配類別,在datasets/cumtom.py定義
self.dataset = dataset
self.samples_per_gpu = samples_per_gpu
self.flag = dataset.flag.astype(np.int64)
self.group_sizes = np.bincount(self.flag)
self.num_samples = 0
for i, size in enumerate(self.group_sizes):
self.num_samples += int(np.ceil(
size / self.samples_per_gpu)) * self.samples_per_gpu # 不是整數取最大值
def __iter__(self):
indices = []
for i, size in enumerate(self.group_sizes):
if size == 0:
continue
indice = np.where(self.flag == i)[0]
assert len(indice) == size
np.random.shuffle(indice) # random sample
num_extra = int(np.ceil(size / self.samples_per_gpu)
) * self.samples_per_gpu - len(indice) # 不能整除的額外數據 數量
indice = np.concatenate(
[indice, np.random.choice(indice, num_extra)]) # 不能整除的額外數據 使用前面數據隨機取出的數補充
indices.append(indice)
indices = np.concatenate(indices)
indices = [
indices[i * self.samples_per_gpu:(i + 1) * self.samples_per_gpu]
for i in np.random.permutation(
range(len(indices) // self.samples_per_gpu)) # 分配到每個GPU
]
indices = np.concatenate(indices)
indices = indices.astype(np.int64).tolist()
assert len(indices) == self.num_samples
return iter(indices)
/torch/utils/data/dataset:
class ConcatDataset(Dataset):
def __init__(self, datasets):
self.cumulative_sizes = self.cumsum(self.datasets) # 疊加長度總和[len_1, len_1+len_2, len_1+len_2+len_3]
def __len__(self):
return self.cumulative_sizes[-1]#總長度
def __getitem__(self, idx):
# 反向索引
if idx < 0:
if -idx > len(self):
raise ValueError("absolute value of index should not exceed dataset length")
idx = len(self) + idx
# 二分查找子數據集
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return self.datasets[dataset_idx][sample_idx] # 獲得 指定子數據集 的 指定位置數據
# 老版本名字已更改,可以更改數據集長度
def cummulative_sizes(self):
warnings.warn("cummulative_sizes attribute is renamed to "
"cumulative_sizes", DeprecationWarning, stacklevel=2)
return self.cumulative_sizes
/datasets/builder: 實例化數據相關任務:sample、dataloader、dataset
/datasets/dataset_wrappers.py: 重寫concatDataset、RepeatDataset上面已經詳細說明,增加數據類別平衡類(具體沒看)
/datasets/custom.py:
@DATASETS.register_module()
class CustomDataset(Dataset):
CLASSES = None #種類名稱,可以直接定義(常用直接類內定義),也可以外部傳入
# 讀取全部標簽,格式如下:
‘’‘
{
'filename': 'a.jpg',
'width': 1280,
'height': 720,
'ann':
{
'bboxes': <np.ndarray> (n, 4),
'labels': <np.ndarray> (n, ),
'bboxes_ignore': <np.ndarray> (k, 4), (optional field)
'labels_ignore': <np.ndarray> (k, 4) (optional field)
}
},
’‘’
def load_annotations(self, ann_file):
pass
# 暫不確定用途
def load_proposals(self, proposal_file):
pass
# 過濾不符合條件數據
def _filter_imgs(self, min_size=32):
pass
# 獲取單個train數據
def prepare_train_img(self, idx):
pass
# 獲取單個test數據
def prepare_test_img(self, idx):
# 獲得單個圖像標注信息
def get_ann_info(self, idx):
pass
# 隨機選擇數據,會使用_set_group_flag
def _rand_another(self, idx):
pass
# 按特定格式給圖像分類(原始使用長寬比)
def _set_group_flag(self):
pass
整個數據讀取流程比較清晰:
graph TD
A_1[準備特定格式label] --> A_2
A_2[讀取全部label]
--> A_3(過濾不合適label)
A_3 --> C{train/test}
C -->|train | D[讀取圖像信息+label信息]
C -->|test| E[和train類似]
D --> D_1{合適/不合適}
D_1 --> |不合適| D_2(隨機選取)
D_1 --> |合適| D_3(直接選取)
/mmdet/models: 模型實際實現函數
/mmdet.ops: 需要快速實現的操作,如:NMS、ROIPooling、ROIAlign....
/mmdet/utils: 一些輔助操作,環境變量和版本等
-- mmdetection.tests
/tests/all: 測試腳本,可以用來查看原理和測試
-- mmdetection.tools
/tools/all: 雜七雜八文件,包括:訓練+測試(僅是入口,實際操作在apis之內),數據轉換、計算MAC、轉換模型ONNX.....
/tools/train.py: 單機單卡
/tools/dist_train.py: 單機單多卡,使用distribution
/tools/slurm_train.py: 多機多卡
大致流程:
準備數據集,在mmdet/datasets
準備模型,在mmdet/models, loss函數在models里面實現
準備特殊函數,在/mmdet/core,一些mmdet沒有的操作
配置參數,在/configs, 基礎配置可選,后面的參數必須配置
訓練模型,在/mmdet/tools, 調用評估可在配置里設置
前向推理,在/demo
Already open...
...
總結
- 上一篇: 计算机网络(十),HTTP的关键问题
- 下一篇: add和addAll的区别