From 9baf0a8bb739ac3b5741769cbd4838e788e546cd Mon Sep 17 00:00:00 2001 From: wangg12 <guwang12@gmail.com> Date: Fri, 30 Nov 2018 10:28:52 +0800 Subject: [PATCH] fix some problems; support multiple proposal_files and img_prefixes --- .gitignore | 2 ++ mmdet/datasets/concat_dataset.py | 9 ------- mmdet/datasets/utils.py | 46 +++++++++++++++++++++++--------- tools/train.py | 5 ++-- 4 files changed, 38 insertions(+), 24 deletions(-) diff --git a/.gitignore b/.gitignore index 01c47d6e..f189e1d5 100644 --- a/.gitignore +++ b/.gitignore @@ -107,3 +107,5 @@ venv.bak/ mmdet/ops/nms/*.cpp mmdet/version.py data +.vscode +.idea diff --git a/mmdet/datasets/concat_dataset.py b/mmdet/datasets/concat_dataset.py index 605d112f..e42b6098 100644 --- a/mmdet/datasets/concat_dataset.py +++ b/mmdet/datasets/concat_dataset.py @@ -1,4 +1,3 @@ -import bisect import numpy as np from torch.utils.data.dataset import ConcatDataset as _ConcatDataset @@ -19,11 +18,3 @@ class ConcatDataset(_ConcatDataset): for i in range(0, len(datasets)): flags.append(datasets[i].flag) self.flag = np.concatenate(flags) - - def get_idxs(self, idx): - dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) - if dataset_idx == 0: - sample_idx = idx - else: - sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] - return dataset_idx, sample_idx diff --git a/mmdet/datasets/utils.py b/mmdet/datasets/utils.py index 92b4e5d5..6f6a0b51 100644 --- a/mmdet/datasets/utils.py +++ b/mmdet/datasets/utils.py @@ -1,5 +1,6 @@ -from collections import Sequence import copy +from collections import Sequence + import mmcv from mmcv.runner import obj_from_dict import torch @@ -73,19 +74,38 @@ def show_ann(coco, img, ann_info): def get_dataset(data_cfg): - if isinstance(data_cfg['ann_file'], list) or \ - isinstance(data_cfg['ann_file'], tuple): + if isinstance(data_cfg['ann_file'], (list, tuple)): ann_files = data_cfg['ann_file'] - dsets = [] - for ann_file in ann_files: - data_info = copy.deepcopy(data_cfg) - data_info['ann_file'] = ann_file - dset = obj_from_dict(data_info, datasets) - dsets.append(dset) - if len(dsets) > 1: - dset = ConcatDataset(dsets) + num_dset = len(ann_files) + else: + ann_files = [data_cfg['ann_file']] + num_dset = 1 + + if 'proposal_file' in data_cfg.keys(): + if isinstance(data_cfg['proposal_file'], (list, tuple)): + proposal_files = data_cfg['proposal_file'] else: - dset = dsets[0] + proposal_files = [data_cfg['proposal_file']] + else: + proposal_files = [None] * num_dset + assert len(proposal_files) == num_dset + + if isinstance(data_cfg['img_prefix'], (list, tuple)): + img_prefixes = data_cfg['img_prefix'] + else: + img_prefixes = [data_cfg['img_prefix']] * num_dset + assert len(img_prefixes) == num_dset + + dsets = [] + for i in range(num_dset): + data_info = copy.deepcopy(data_cfg) + data_info['ann_file'] = ann_files[i] + data_info['proposal_file'] = proposal_files[i] + data_info['img_prefix'] = img_prefixes[i] + dset = obj_from_dict(data_info, datasets) + dsets.append(dset) + if len(dsets) > 1: + dset = ConcatDataset(dsets) else: - dset = obj_from_dict(data_cfg, datasets) + dset = dsets[0] return dset diff --git a/tools/train.py b/tools/train.py index 49c46f05..bd47e66b 100644 --- a/tools/train.py +++ b/tools/train.py @@ -3,7 +3,8 @@ from __future__ import division import argparse from mmcv import Config -from mmdet import datasets, __version__ +from mmdet import __version__ +from mmdet.datasets import get_dataset from mmdet.apis import (train_detector, init_dist, get_root_logger, set_random_seed) from mmdet.models import build_detector @@ -66,7 +67,7 @@ def main(): model = build_detector( cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg) - train_dataset = datasets.get_dataset(cfg.data.train) + train_dataset = get_dataset(cfg.data.train) train_detector( model, train_dataset, -- GitLab