Skip to content
Snippets Groups Projects
Unverified Commit 01f1d42d authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Fix pytorch 1.7 imcompatibility issues (#4103)

* fix pytorch 1.7 imcompatibility issues

* remove useless files
parent 048c41a8
No related branches found
No related tags found
No related merge requests found
......@@ -125,18 +125,25 @@ async def async_inference_detector(model, img):
Args:
model (nn.Module): The loaded detector.
imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
images.
img (str | ndarray): Either image files or loaded images.
Returns:
Awaitable detection results.
"""
cfg = model.cfg
device = next(model.parameters()).device # model device
# prepare data
if isinstance(img, np.ndarray):
# directly add img
data = dict(img=img)
cfg = cfg.copy()
# set loading pipeline type
cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
else:
# add information into dict
data = dict(img_info=dict(filename=img), img_prefix=None)
# build the data pipeline
test_pipeline = Compose(cfg.data.test.pipeline)
# prepare data
data = dict(img_info=dict(filename=img), img_prefix=None)
data = test_pipeline(data)
data = scatter(collate([data], samples_per_gpu=1), [device])[0]
......
......@@ -73,20 +73,14 @@ def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6):
>>> [0, 10, 10, 19],
>>> [10, 10, 20, 20],
>>> ])
>>> bbox_overlaps(bboxes1, bboxes2)
tensor([[0.5000, 0.0000, 0.0000],
[0.0000, 0.0000, 1.0000],
[0.0000, 0.0000, 0.0000]])
>>> bbox_overlaps(bboxes1, bboxes2, mode='giou', eps=1e-7)
tensor([[0.5000, 0.0000, -0.5000],
[-0.2500, -0.0500, 1.0000],
[-0.8371, -0.8766, -0.8214]])
>>> overlaps = bbox_overlaps(bboxes1, bboxes2)
>>> assert overlaps.shape == (3, 3)
>>> overlaps = bbox_overlaps(bboxes1, bboxes2, is_aligned=True)
>>> assert overlaps.shape == (3, )
Example:
>>> empty = torch.FloatTensor([])
>>> nonempty = torch.FloatTensor([
>>> [0, 0, 10, 9],
>>> ])
>>> empty = torch.empty(0, 4)
>>> nonempty = torch.FloatTensor([[0, 0, 10, 9]])
>>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1)
>>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0)
>>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0)
......
......@@ -928,12 +928,12 @@ class InterpolateModule(nn.Module):
Any arguments you give it just get passed along for the ride.
"""
def __init__(self, *args, **kwdargs):
def __init__(self, *args, **kwargs):
super().__init__()
self.args = args
self.kwdargs = kwdargs
self.kwargs = kwargs
def forward(self, x):
"""Forward features from the upstream network."""
return F.interpolate(x, *self.args, **self.kwdargs)
return F.interpolate(x, *self.args, **self.kwargs)
......@@ -43,7 +43,7 @@ def accuracy(pred, target, topk=1, thresh=None):
correct = correct & (pred_value > thresh).t()
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / pred.size(0)))
return res[0] if return_single else res
......
......@@ -67,7 +67,8 @@ def sigmoid_focal_loss(pred,
"""
# Function.apply does not accept keyword arguments, so the decorator
# "weighted_loss" is not applicable
loss = _sigmoid_focal_loss(pred, target, gamma, alpha, None, 'none')
loss = _sigmoid_focal_loss(pred.contiguous(), target, gamma, alpha, None,
'none')
if weight is not None:
if weight.shape != loss.shape:
if weight.size(0) == loss.size(0):
......
from os.path import dirname, exists, join, relpath
import pytest
import torch
from mmcv.runner import build_optimizer
......@@ -75,142 +76,6 @@ def test_config_build_detector():
# _check_bbox_head(head_config, detector.bbox_head)
def test_config_data_pipeline():
"""Test whether the data pipeline is valid and can process corner cases.
CommandLine:
xdoctest -m tests/test_config.py test_config_build_data_pipeline
"""
from mmcv import Config
from mmdet.datasets.pipelines import Compose
import numpy as np
config_dpath = _get_config_directory()
print(f'Found config_dpath = {config_dpath}')
# Only tests a representative subset of configurations
# TODO: test pipelines using Albu, current Albu throw None given empty GT
config_names = [
'wider_face/ssd300_wider_face.py',
'pascal_voc/ssd300_voc0712.py',
'pascal_voc/ssd512_voc0712.py',
# 'albu_example/mask_rcnn_r50_fpn_1x.py',
'foveabox/fovea_align_r50_fpn_gn-head_mstrain_640-800_4x4_2x_coco.py',
'mask_rcnn/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_coco.py',
'mask_rcnn/mask_rcnn_r50_caffe_fpn_mstrain_1x_coco.py',
'fp16/mask_rcnn_r50_fpn_fp16_1x_coco.py'
]
def dummy_masks(h, w, num_obj=3, mode='bitmap'):
assert mode in ('polygon', 'bitmap')
if mode == 'bitmap':
masks = np.random.randint(0, 2, (num_obj, h, w), dtype=np.uint8)
masks = BitmapMasks(masks, h, w)
else:
masks = []
for i in range(num_obj):
masks.append([])
masks[-1].append(
np.random.uniform(0, min(h - 1, w - 1), (8 + 4 * i, )))
masks[-1].append(
np.random.uniform(0, min(h - 1, w - 1), (10 + 4 * i, )))
masks = PolygonMasks(masks, h, w)
return masks
print(f'Using {len(config_names)} config files')
for config_fname in config_names:
config_fpath = join(config_dpath, config_fname)
config_mod = Config.fromfile(config_fpath)
# remove loading pipeline
loading_pipeline = config_mod.train_pipeline.pop(0)
loading_ann_pipeline = config_mod.train_pipeline.pop(0)
config_mod.test_pipeline.pop(0)
train_pipeline = Compose(config_mod.train_pipeline)
test_pipeline = Compose(config_mod.test_pipeline)
print(f'Building data pipeline, config_fpath = {config_fpath}')
print(f'Test training data pipeline: \n{train_pipeline!r}')
img = np.random.randint(0, 255, size=(888, 666, 3), dtype=np.uint8)
if loading_pipeline.get('to_float32', False):
img = img.astype(np.float32)
mode = 'bitmap' if loading_ann_pipeline.get('poly2mask',
True) else 'polygon'
results = dict(
filename='test_img.png',
ori_filename='test_img.png',
img=img,
img_shape=img.shape,
ori_shape=img.shape,
gt_bboxes=np.array([[35.2, 11.7, 39.7, 15.7]], dtype=np.float32),
gt_labels=np.array([1], dtype=np.int64),
gt_masks=dummy_masks(img.shape[0], img.shape[1], mode=mode),
)
results['img_fields'] = ['img']
results['bbox_fields'] = ['gt_bboxes']
results['mask_fields'] = ['gt_masks']
output_results = train_pipeline(results)
assert output_results is not None
print(f'Test testing data pipeline: \n{test_pipeline!r}')
results = dict(
filename='test_img.png',
ori_filename='test_img.png',
img=img,
img_shape=img.shape,
ori_shape=img.shape,
gt_bboxes=np.array([[35.2, 11.7, 39.7, 15.7]], dtype=np.float32),
gt_labels=np.array([1], dtype=np.int64),
gt_masks=dummy_masks(img.shape[0], img.shape[1], mode=mode),
)
results['img_fields'] = ['img']
results['bbox_fields'] = ['gt_bboxes']
results['mask_fields'] = ['gt_masks']
output_results = test_pipeline(results)
assert output_results is not None
# test empty GT
print('Test empty GT with training data pipeline: '
f'\n{train_pipeline!r}')
results = dict(
filename='test_img.png',
ori_filename='test_img.png',
img=img,
img_shape=img.shape,
ori_shape=img.shape,
gt_bboxes=np.zeros((0, 4), dtype=np.float32),
gt_labels=np.array([], dtype=np.int64),
gt_masks=dummy_masks(
img.shape[0], img.shape[1], num_obj=0, mode=mode),
)
results['img_fields'] = ['img']
results['bbox_fields'] = ['gt_bboxes']
results['mask_fields'] = ['gt_masks']
output_results = train_pipeline(results)
assert output_results is not None
print(f'Test empty GT with testing data pipeline: \n{test_pipeline!r}')
results = dict(
filename='test_img.png',
ori_filename='test_img.png',
img=img,
img_shape=img.shape,
ori_shape=img.shape,
gt_bboxes=np.zeros((0, 4), dtype=np.float32),
gt_labels=np.array([], dtype=np.int64),
gt_masks=dummy_masks(
img.shape[0], img.shape[1], num_obj=0, mode=mode),
)
results['img_fields'] = ['img']
results['bbox_fields'] = ['gt_bboxes']
results['mask_fields'] = ['gt_masks']
output_results = test_pipeline(results)
assert output_results is not None
def _check_roi_head(config, head):
# check consistency between head_config and roi_head
assert config['type'] == head.__class__.__name__
......@@ -368,3 +233,135 @@ def _check_anchorhead(config, head):
assert (config.in_channels == head.conv_reg.in_channels)
assert (head.conv_cls.out_channels == num_classes * head.num_anchors)
assert head.fc_reg.out_channels == 4 * head.num_anchors
# Only tests a representative subset of configurations
# TODO: test pipelines using Albu, current Albu throw None given empty GT
@pytest.mark.parametrize(
'config_rpath',
[
'wider_face/ssd300_wider_face.py',
'pascal_voc/ssd300_voc0712.py',
'pascal_voc/ssd512_voc0712.py',
# 'albu_example/mask_rcnn_r50_fpn_1x.py',
'foveabox/fovea_align_r50_fpn_gn-head_mstrain_640-800_4x4_2x_coco.py',
'mask_rcnn/mask_rcnn_r50_caffe_fpn_mstrain-poly_1x_coco.py',
'mask_rcnn/mask_rcnn_r50_caffe_fpn_mstrain_1x_coco.py',
'fp16/mask_rcnn_r50_fpn_fp16_1x_coco.py'
])
def test_config_data_pipeline(config_rpath):
"""Test whether the data pipeline is valid and can process corner cases.
CommandLine:
xdoctest -m tests/test_config.py test_config_build_data_pipeline
"""
from mmcv import Config
from mmdet.datasets.pipelines import Compose
import numpy as np
config_dpath = _get_config_directory()
print(f'Found config_dpath = {config_dpath}')
def dummy_masks(h, w, num_obj=3, mode='bitmap'):
assert mode in ('polygon', 'bitmap')
if mode == 'bitmap':
masks = np.random.randint(0, 2, (num_obj, h, w), dtype=np.uint8)
masks = BitmapMasks(masks, h, w)
else:
masks = []
for i in range(num_obj):
masks.append([])
masks[-1].append(
np.random.uniform(0, min(h - 1, w - 1), (8 + 4 * i, )))
masks[-1].append(
np.random.uniform(0, min(h - 1, w - 1), (10 + 4 * i, )))
masks = PolygonMasks(masks, h, w)
return masks
config_fpath = join(config_dpath, config_rpath)
cfg = Config.fromfile(config_fpath)
# remove loading pipeline
loading_pipeline = cfg.train_pipeline.pop(0)
loading_ann_pipeline = cfg.train_pipeline.pop(0)
cfg.test_pipeline.pop(0)
train_pipeline = Compose(cfg.train_pipeline)
test_pipeline = Compose(cfg.test_pipeline)
print(f'Building data pipeline, config_fpath = {config_fpath}')
print(f'Test training data pipeline: \n{train_pipeline!r}')
img = np.random.randint(0, 255, size=(888, 666, 3), dtype=np.uint8)
if loading_pipeline.get('to_float32', False):
img = img.astype(np.float32)
mode = 'bitmap' if loading_ann_pipeline.get('poly2mask',
True) else 'polygon'
results = dict(
filename='test_img.png',
ori_filename='test_img.png',
img=img,
img_shape=img.shape,
ori_shape=img.shape,
gt_bboxes=np.array([[35.2, 11.7, 39.7, 15.7]], dtype=np.float32),
gt_labels=np.array([1], dtype=np.int64),
gt_masks=dummy_masks(img.shape[0], img.shape[1], mode=mode),
)
results['img_fields'] = ['img']
results['bbox_fields'] = ['gt_bboxes']
results['mask_fields'] = ['gt_masks']
output_results = train_pipeline(results)
assert output_results is not None
print(f'Test testing data pipeline: \n{test_pipeline!r}')
results = dict(
filename='test_img.png',
ori_filename='test_img.png',
img=img,
img_shape=img.shape,
ori_shape=img.shape,
gt_bboxes=np.array([[35.2, 11.7, 39.7, 15.7]], dtype=np.float32),
gt_labels=np.array([1], dtype=np.int64),
gt_masks=dummy_masks(img.shape[0], img.shape[1], mode=mode),
)
results['img_fields'] = ['img']
results['bbox_fields'] = ['gt_bboxes']
results['mask_fields'] = ['gt_masks']
output_results = test_pipeline(results)
assert output_results is not None
# test empty GT
print('Test empty GT with training data pipeline: '
f'\n{train_pipeline!r}')
results = dict(
filename='test_img.png',
ori_filename='test_img.png',
img=img,
img_shape=img.shape,
ori_shape=img.shape,
gt_bboxes=np.zeros((0, 4), dtype=np.float32),
gt_labels=np.array([], dtype=np.int64),
gt_masks=dummy_masks(img.shape[0], img.shape[1], num_obj=0, mode=mode),
)
results['img_fields'] = ['img']
results['bbox_fields'] = ['gt_bboxes']
results['mask_fields'] = ['gt_masks']
output_results = train_pipeline(results)
assert output_results is not None
print(f'Test empty GT with testing data pipeline: \n{test_pipeline!r}')
results = dict(
filename='test_img.png',
ori_filename='test_img.png',
img=img,
img_shape=img.shape,
ori_shape=img.shape,
gt_bboxes=np.zeros((0, 4), dtype=np.float32),
gt_labels=np.array([], dtype=np.int64),
gt_masks=dummy_masks(img.shape[0], img.shape[1], num_obj=0, mode=mode),
)
results['img_fields'] = ['img']
results['bbox_fields'] = ['gt_bboxes']
results['mask_fields'] = ['gt_masks']
output_results = test_pipeline(results)
assert output_results is not None
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment