Skip to content
Snippets Groups Projects
Commit 9baf0a8b authored by wangg12's avatar wangg12
Browse files

fix some problems; support multiple proposal_files and img_prefixes

parent 6b25743a
No related branches found
No related tags found
No related merge requests found
......@@ -107,3 +107,5 @@ venv.bak/
mmdet/ops/nms/*.cpp
mmdet/version.py
data
.vscode
.idea
import bisect
import numpy as np
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
......@@ -19,11 +18,3 @@ class ConcatDataset(_ConcatDataset):
for i in range(0, len(datasets)):
flags.append(datasets[i].flag)
self.flag = np.concatenate(flags)
def get_idxs(self, idx):
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return dataset_idx, sample_idx
from collections import Sequence
import copy
from collections import Sequence
import mmcv
from mmcv.runner import obj_from_dict
import torch
......@@ -73,19 +74,38 @@ def show_ann(coco, img, ann_info):
def get_dataset(data_cfg):
if isinstance(data_cfg['ann_file'], list) or \
isinstance(data_cfg['ann_file'], tuple):
if isinstance(data_cfg['ann_file'], (list, tuple)):
ann_files = data_cfg['ann_file']
dsets = []
for ann_file in ann_files:
data_info = copy.deepcopy(data_cfg)
data_info['ann_file'] = ann_file
dset = obj_from_dict(data_info, datasets)
dsets.append(dset)
if len(dsets) > 1:
dset = ConcatDataset(dsets)
num_dset = len(ann_files)
else:
ann_files = [data_cfg['ann_file']]
num_dset = 1
if 'proposal_file' in data_cfg.keys():
if isinstance(data_cfg['proposal_file'], (list, tuple)):
proposal_files = data_cfg['proposal_file']
else:
dset = dsets[0]
proposal_files = [data_cfg['proposal_file']]
else:
proposal_files = [None] * num_dset
assert len(proposal_files) == num_dset
if isinstance(data_cfg['img_prefix'], (list, tuple)):
img_prefixes = data_cfg['img_prefix']
else:
img_prefixes = [data_cfg['img_prefix']] * num_dset
assert len(img_prefixes) == num_dset
dsets = []
for i in range(num_dset):
data_info = copy.deepcopy(data_cfg)
data_info['ann_file'] = ann_files[i]
data_info['proposal_file'] = proposal_files[i]
data_info['img_prefix'] = img_prefixes[i]
dset = obj_from_dict(data_info, datasets)
dsets.append(dset)
if len(dsets) > 1:
dset = ConcatDataset(dsets)
else:
dset = obj_from_dict(data_cfg, datasets)
dset = dsets[0]
return dset
......@@ -3,7 +3,8 @@ from __future__ import division
import argparse
from mmcv import Config
from mmdet import datasets, __version__
from mmdet import __version__
from mmdet.datasets import get_dataset
from mmdet.apis import (train_detector, init_dist, get_root_logger,
set_random_seed)
from mmdet.models import build_detector
......@@ -66,7 +67,7 @@ def main():
model = build_detector(
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
train_dataset = datasets.get_dataset(cfg.data.train)
train_dataset = get_dataset(cfg.data.train)
train_detector(
model,
train_dataset,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment