Skip to content
Snippets Groups Projects
Unverified Commit 14d250f9 authored by jingege315's avatar jingege315 Committed by GitHub
Browse files

[Enhance]: use isinstance method to get loading pipeline (#4619)


* use isinstance method to get loading pipeline

* Fix isinstance error

* Add unit test

* Fix lint

* Fix lint

Co-authored-by: default avatarhhaAndroid <1286304229@qq.com>
parent e1599e7c
No related branches found
No related tags found
No related merge requests found
......@@ -4,6 +4,8 @@ import warnings
from mmcv.cnn import VGG
from mmcv.runner.hooks import HOOKS, Hook
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import LoadAnnotations, LoadImageFromFile
from mmdet.models.dense_heads import GARPNHead, RPNHead
from mmdet.models.roi_heads.mask_heads import FusedSemanticHead
......@@ -98,7 +100,10 @@ def get_loading_pipeline(pipeline):
"""
loading_pipeline_cfg = []
for cfg in pipeline:
if cfg['type'].startswith('Load'):
obj_cls = PIPELINES.get(cfg['type'])
# TODO:use more elegant way to distinguish loading modules
if obj_cls is not None and obj_cls in (LoadImageFromFile,
LoadAnnotations):
loading_pipeline_cfg.append(cfg)
assert len(loading_pipeline_cfg) == 2, \
'The data pipeline in your config file must include ' \
......
import pytest
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets import get_loading_pipeline, replace_ImageToTensor
def test_replace_ImageToTensor():
......@@ -59,3 +59,21 @@ def test_replace_ImageToTensor():
]
with pytest.warns(UserWarning):
assert expected_pipelines == replace_ImageToTensor(pipelines)
def test_get_loading_pipeline():
pipelines = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
expected_pipelines = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True)
]
assert expected_pipelines == \
get_loading_pipeline(pipelines)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment