From 9baf0a8bb739ac3b5741769cbd4838e788e546cd Mon Sep 17 00:00:00 2001
From: wangg12 <guwang12@gmail.com>
Date: Fri, 30 Nov 2018 10:28:52 +0800
Subject: [PATCH] datasets: support multiple proposal_files and img_prefixes;
 drop ConcatDataset.get_idxs

---
 .gitignore                       |  2 ++
 mmdet/datasets/concat_dataset.py |  9 -------
 mmdet/datasets/utils.py          | 46 +++++++++++++++++++++++---------
 tools/train.py                   |  5 ++--
 4 files changed, 38 insertions(+), 24 deletions(-)

diff --git a/.gitignore b/.gitignore
index 01c47d6e..f189e1d5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -107,3 +107,5 @@ venv.bak/
 mmdet/ops/nms/*.cpp
 mmdet/version.py
 data
+.vscode
+.idea
diff --git a/mmdet/datasets/concat_dataset.py b/mmdet/datasets/concat_dataset.py
index 605d112f..e42b6098 100644
--- a/mmdet/datasets/concat_dataset.py
+++ b/mmdet/datasets/concat_dataset.py
@@ -1,4 +1,3 @@
-import bisect
 import numpy as np
 from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
 
@@ -19,11 +18,3 @@ class ConcatDataset(_ConcatDataset):
             for i in range(0, len(datasets)):
                 flags.append(datasets[i].flag)
             self.flag = np.concatenate(flags)
-
-    def get_idxs(self, idx):
-        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
-        if dataset_idx == 0:
-            sample_idx = idx
-        else:
-            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
-        return dataset_idx, sample_idx
diff --git a/mmdet/datasets/utils.py b/mmdet/datasets/utils.py
index 92b4e5d5..6f6a0b51 100644
--- a/mmdet/datasets/utils.py
+++ b/mmdet/datasets/utils.py
@@ -1,5 +1,6 @@
-from collections import Sequence
 import copy
+from collections import Sequence
+
 import mmcv
 from mmcv.runner import obj_from_dict
 import torch
@@ -73,19 +74,46 @@ def show_ann(coco, img, ann_info):
 
 
 def get_dataset(data_cfg):
-    if isinstance(data_cfg['ann_file'], list) or \
-            isinstance(data_cfg['ann_file'], tuple):
+    """Build the dataset described by ``data_cfg``.
+
+    ``ann_file``, ``proposal_file`` and ``img_prefix`` may each be a single
+    value or a list/tuple. One dataset is built per ``ann_file`` entry; a
+    single ``proposal_file`` or ``img_prefix`` is shared by all of them.
+    More than one dataset is wrapped in a ``ConcatDataset``.
+    """
+    if isinstance(data_cfg['ann_file'], (list, tuple)):
         ann_files = data_cfg['ann_file']
-        dsets = []
-        for ann_file in ann_files:
-            data_info = copy.deepcopy(data_cfg)
-            data_info['ann_file'] = ann_file
-            dset = obj_from_dict(data_info, datasets)
-            dsets.append(dset)
-        if len(dsets) > 1:
-            dset = ConcatDataset(dsets)
+        num_dset = len(ann_files)
+    else:
+        ann_files = [data_cfg['ann_file']]
+        num_dset = 1
+
+    if 'proposal_file' in data_cfg.keys():
+        if isinstance(data_cfg['proposal_file'], (list, tuple)):
+            proposal_files = data_cfg['proposal_file']
         else:
-            dset = dsets[0]
+            # A single value is shared by every dataset, like img_prefix.
+            proposal_files = [data_cfg['proposal_file']] * num_dset
+    else:
+        proposal_files = [None] * num_dset
+    assert len(proposal_files) == num_dset
+
+    if isinstance(data_cfg['img_prefix'], (list, tuple)):
+        img_prefixes = data_cfg['img_prefix']
+    else:
+        img_prefixes = [data_cfg['img_prefix']] * num_dset
+    assert len(img_prefixes) == num_dset
+
+    dsets = []
+    for i in range(num_dset):
+        data_info = copy.deepcopy(data_cfg)
+        data_info['ann_file'] = ann_files[i]
+        data_info['proposal_file'] = proposal_files[i]
+        data_info['img_prefix'] = img_prefixes[i]
+        dset = obj_from_dict(data_info, datasets)
+        dsets.append(dset)
+    if len(dsets) > 1:
+        dset = ConcatDataset(dsets)
     else:
-        dset = obj_from_dict(data_cfg, datasets)
+        dset = dsets[0]
     return dset
diff --git a/tools/train.py b/tools/train.py
index 49c46f05..bd47e66b 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -3,7 +3,8 @@ from __future__ import division
 import argparse
 from mmcv import Config
 
-from mmdet import datasets, __version__
+from mmdet import __version__
+from mmdet.datasets import get_dataset
 from mmdet.apis import (train_detector, init_dist, get_root_logger,
                         set_random_seed)
 from mmdet.models import build_detector
@@ -66,7 +67,7 @@ def main():
 
     model = build_detector(
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
-    train_dataset = datasets.get_dataset(cfg.data.train)
+    train_dataset = get_dataset(cfg.data.train)
     train_detector(
         model,
         train_dataset,
-- 
GitLab