Skip to content
Snippets Groups Projects
Commit 6b25743a authored by wangg12's avatar wangg12
Browse files

fix flake8

parent 7cbdbc78
No related branches found
No related tags found
No related merge requests found
......@@ -5,6 +5,7 @@ from .utils import to_tensor, random_scale, show_ann, get_dataset
from .concat_dataset import ConcatDataset
__all__ = [
'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler', 'ConcatDataset',
'build_dataloader', 'to_tensor', 'random_scale', 'show_ann', 'get_dataset'
'CustomDataset', 'CocoDataset', 'GroupSampler', 'DistributedGroupSampler',
'ConcatDataset', 'build_dataloader', 'to_tensor', 'random_scale',
'show_ann', 'get_dataset'
]
......@@ -5,7 +5,7 @@ from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
class ConcatDataset(_ConcatDataset):
"""
Same as torch.utils.data.dataset.ConcatDataset, but
Same as torch.utils.data.dataset.ConcatDataset, but
concat the group flag for image aspect ratio.
"""
def __init__(self, datasets):
......@@ -13,7 +13,7 @@ class ConcatDataset(_ConcatDataset):
flag: Images with aspect ratio greater than 1 will be set as group 1,
otherwise group 0.
"""
super(ConcatDataset, self).__init__(datasets)
super(ConcatDataset, self).__init__(datasets)
if hasattr(datasets[0], 'flag'):
flags = []
for i in range(0, len(datasets)):
......@@ -27,4 +27,3 @@ class ConcatDataset(_ConcatDataset):
else:
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
return dataset_idx, sample_idx
......@@ -9,6 +9,7 @@ import numpy as np
from .concat_dataset import ConcatDataset
from .. import datasets
def to_tensor(data):
"""Convert objects of various python types to :obj:`torch.Tensor`.
......@@ -72,7 +73,8 @@ def show_ann(coco, img, ann_info):
def get_dataset(data_cfg):
if isinstance(data_cfg['ann_file'], list) or isinstance(data_cfg['ann_file'], tuple):
if isinstance(data_cfg['ann_file'], list) or \
isinstance(data_cfg['ann_file'], tuple):
ann_files = data_cfg['ann_file']
dsets = []
for ann_file in ann_files:
......@@ -81,9 +83,9 @@ def get_dataset(data_cfg):
dset = obj_from_dict(data_info, datasets)
dsets.append(dset)
if len(dsets) > 1:
dset = ConcatDataset(dsets)
dset = ConcatDataset(dsets)
else:
dset = dsets[0]
else:
dset = obj_from_dict(data_cfg, datasets)
return dset
\ No newline at end of file
return dset
......@@ -2,7 +2,6 @@ from __future__ import division
import argparse
from mmcv import Config
from mmcv.runner import obj_from_dict
from mmdet import datasets, __version__
from mmdet.apis import (train_detector, init_dist, get_root_logger,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment