Skip to content
Snippets Groups Projects
Unverified Commit 42420f66 authored by Haian Huang(深度眸)'s avatar Haian Huang(深度眸) Committed by GitHub
Browse files

Add unit test for batch inference (#4526)

* Update inference and add unit test

* Fix lint

* Fix abnormal resource usage

* Update unit test
parent 62b8ae90
No related branches found
No related tags found
No related merge requests found
......@@ -80,30 +80,29 @@ class LoadImage(object):
return results
def inference_detector(model, imglist):
def inference_detector(model, imgs):
"""Inference image(s) with the detector.
Args:
model (nn.Module): The loaded detector.
imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
images.
imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
Either image files or loaded images.
Returns:
If imgs is a str, a generator will be returned, otherwise return the
detection results directly.
If imgs is a list or tuple, the same length list type results
will be returned, otherwise return the detection results directly.
"""
is_batch = False
if isinstance(imglist, list):
if isinstance(imgs, (list, tuple)):
is_batch = True
else:
imglist = [imglist]
imgs = [imgs]
is_batch = False
cfg = model.cfg
device = next(model.parameters()).device # model device
results = []
if isinstance(imglist[0], np.ndarray):
if isinstance(imgs[0], np.ndarray):
cfg = cfg.copy()
# set loading pipeline type
cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
......@@ -111,8 +110,8 @@ def inference_detector(model, imglist):
cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
test_pipeline = Compose(cfg.data.test.pipeline)
datalist = []
for img in imglist:
datas = []
for img in imgs:
# prepare data
if isinstance(img, np.ndarray):
# directly add img
......@@ -122,9 +121,9 @@ def inference_detector(model, imglist):
data = dict(img_info=dict(filename=img), img_prefix=None)
# build the data pipeline
data = test_pipeline(data)
datalist.append(data)
datas.append(data)
data = collate(datalist, samples_per_gpu=len(imglist))
data = collate(datas, samples_per_gpu=len(imgs))
# just get the actual data from DataContainer
data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
data['img'] = [img.data[0] for img in data['img']]
......@@ -143,8 +142,8 @@ def inference_detector(model, imglist):
if not is_batch:
return results[0]
return results
else:
return results
async def async_inference_detector(model, img):
......
......@@ -408,7 +408,7 @@ def test_yolact_forward():
from mmdet.models import build_detector
detector = build_detector(model)
input_shape = (1, 3, 550, 550)
input_shape = (1, 3, 100, 100)
mm_inputs = _demo_mm_inputs(input_shape)
imgs = mm_inputs.pop('imgs')
......@@ -447,7 +447,7 @@ def test_detr_forward():
from mmdet.models import build_detector
detector = build_detector(model)
input_shape = (1, 3, 550, 550)
input_shape = (1, 3, 100, 100)
mm_inputs = _demo_mm_inputs(input_shape)
imgs = mm_inputs.pop('imgs')
......@@ -493,3 +493,61 @@ def test_detr_forward():
rescale=True,
return_loss=False)
batch_results.append(result)
def test_inference_detector():
from mmdet.apis import inference_detector
from mmdet.models import build_detector
from mmcv import ConfigDict
# small RetinaNet
num_class = 3
model_dict = dict(
type='RetinaNet',
pretrained=None,
backbone=dict(
type='ResNet',
depth=18,
num_stages=4,
out_indices=(3, ),
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='pytorch'),
neck=None,
bbox_head=dict(
type='RetinaHead',
num_classes=num_class,
in_channels=512,
stacked_convs=1,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
octave_base_scale=4,
scales_per_octave=3,
ratios=[0.5],
strides=[32]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[1.0, 1.0, 1.0, 1.0]),
),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100))
rng = np.random.RandomState(0)
img1 = rng.rand(100, 100, 3)
img2 = rng.rand(100, 100, 3)
model = build_detector(ConfigDict(model_dict))
config = _get_config_module('retinanet/retinanet_r50_fpn_1x_coco.py')
model.cfg = config
# test single image
result = inference_detector(model, img1)
assert len(result) == num_class
# test multiple image
result = inference_detector(model, [img1, img2])
assert len(result) == 2 and len(result[0]) == num_class
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment