Skip to content
Snippets Groups Projects
Unverified Commit e27046d8 authored by Jerry Jiarui XU's avatar Jerry Jiarui XU Committed by GitHub
Browse files

reimplement cityscapes (#2089)

* reimplement cityscapes

* fixed gt bbox mode

* convert cityscapes to coco style, add cityscapes eval

* add cityscapes convert script

* add doc

* Update INSTALL.md

* Update INSTALL.md

* update fater rcnn

* fix cityscapes eval

* support format only in cityscapes

* add docs

* remove redundancy

* resolve eval

* update cityscapes md

* more doc and rename

* update doc and cfg

* change to test set
parent 1b5c991f
No related branches found
No related tags found
No related merge requests found
......@@ -3,6 +3,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = mmdet
known_third_party = asynctest,cv2,matplotlib,mmcv,numpy,pycocotools,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision
known_third_party = asynctest,cityscapesscripts,cv2,matplotlib,mmcv,numpy,pycocotools,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
......@@ -3,25 +3,19 @@
- All baselines were trained using 8 GPU with a batch size of 8 (1 images per GPU) using the [linear scaling rule](https://arxiv.org/abs/1706.02677) to scale the learning rate.
- All models were trained on `cityscapes_train`, and tested on `cityscapes_val`.
- 1x training schedule indicates 64 epochs which corresponds to slightly less than the 24k iterations reported in the original schedule from the [Mask R-CNN paper](https://arxiv.org/abs/1703.06870)
- All pytorch-style pretrained backbones on ImageNet are from PyTorch model zoo.
## Baselines
Download links and more models with different backbones and training schemes will be added to the model zoo.
- COCO pre-trained weights are used to initialize.
- A conversion [script](../../tools/convert_datasets/cityscapes.py) is provided to convert Cityscapes into COCO format. Please refer to [INSTALL.md](../../docs/INSTALL.md#prepare-datasets) for details.
- `CityscapesDataset` implemented three evaluation methods. `bbox` and `segm` are standard COCO bbox/mask AP. `cityscapes` is the cityscapes dataset official evaluation, which may be slightly higher than COCO.
### Faster R-CNN
| Backbone | Style | Lr schd | Scale | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | Download |
| :-------------: | :-----: | :-----: | :---: | :------: | :-----------------: | :------------: | :----: | :------: |
| R-50-FPN | pytorch | 1x | 800-1024 | 4.9 | 0.345 | 8.8 | 36.0 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/cityscapes/faster_rcnn_r50_fpn_1x_city_20190727-7b9c0534.pth) |
| R-50-FPN | pytorch | 1x | 800-1024 | 4.9 | - | - | 41.6 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/cityscapes/faster_rcnn_r50_fpn_1x_cityscapes_20200227-362cfbbf.pth) |
### Mask R-CNN
| Backbone | Style | Lr schd | Scale | Mem (GB) | Train time (s/iter) | Inf time (fps) | box AP | mask AP | Download |
| :-------------: | :-----: | :-----: | :------: | :------: | :-----------------: | :------------: | :----: | :-----: | :------: |
| R-50-FPN | pytorch | 1x | 800-1024 | 4.9 | 0.609 | 2.5 | 37.4 | 32.5 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/cityscapes/mask_rcnn_r50_fpn_1x_city_20190727-9b3c56a5.pth) |
**Notes:**
- In the original paper, the mask AP of Mask R-CNN R-50-FPN is 31.5.
| R-50-FPN | pytorch | 1x | 800-1024 | 4.9 | - | - | 41.9 | 37.1 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/models/cityscapes/mask_rcnn_r50_fpn_1x_cityscapes_20200227-afe51d5a.pth) |
# model settings
model = dict(
type='FasterRCNN',
pretrained='modelzoo://resnet50',
pretrained=None,
backbone=dict(
type='ResNet',
depth=50,
......@@ -138,19 +138,19 @@ data = dict(
type=dataset_type,
ann_file=data_root +
'annotations/instancesonly_filtered_gtFine_train.json',
img_prefix=data_root + 'train/',
img_prefix=data_root + 'leftImg8bit/train/',
pipeline=train_pipeline)),
val=dict(
type=dataset_type,
ann_file=data_root +
'annotations/instancesonly_filtered_gtFine_val.json',
img_prefix=data_root + 'val/',
img_prefix=data_root + 'leftImg8bit/val/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root +
'annotations/instancesonly_filtered_gtFine_val.json',
img_prefix=data_root + 'val/',
'annotations/instancesonly_filtered_gtFine_test.json',
img_prefix=data_root + 'leftImg8bit/test/',
pipeline=test_pipeline))
evaluation = dict(interval=1, metric='bbox')
# optimizer
......@@ -163,7 +163,8 @@ lr_config = dict(
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[6])
# [7] yields higher performance than [6]
step=[7])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
......@@ -178,6 +179,7 @@ total_epochs = 8 # actual epoch = 8 * 8 = 64
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/cityscapes/faster_rcnn_r50_fpn_1x_cityscapes'
load_from = None
# For better, more stable performance initialize from COCO
load_from = 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/faster_rcnn_r50_fpn_2x_20181010-443129e1.pth' # noqa
resume_from = None
workflow = [('train', 1)]
# model settings
model = dict(
type='MaskRCNN',
pretrained='modelzoo://resnet50',
pretrained=None,
backbone=dict(
type='ResNet',
depth=50,
......@@ -152,19 +152,19 @@ data = dict(
type=dataset_type,
ann_file=data_root +
'annotations/instancesonly_filtered_gtFine_train.json',
img_prefix=data_root + 'train/',
img_prefix=data_root + 'leftImg8bit/train/',
pipeline=train_pipeline)),
val=dict(
type=dataset_type,
ann_file=data_root +
'annotations/instancesonly_filtered_gtFine_val.json',
img_prefix=data_root + 'val/',
img_prefix=data_root + 'leftImg8bit/val/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root +
'annotations/instancesonly_filtered_gtFine_val.json',
img_prefix=data_root + 'val/',
'annotations/instancesonly_filtered_gtFine_test.json',
img_prefix=data_root + 'leftImg8bit/test/',
pipeline=test_pipeline))
evaluation = dict(interval=1, metric=['bbox', 'segm'])
# optimizer
......@@ -177,7 +177,8 @@ lr_config = dict(
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[6])
# [7] yields higher performance than [6]
step=[7])
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
......@@ -192,6 +193,7 @@ total_epochs = 8 # actual epoch = 8 * 8 = 64
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/cityscapes/mask_rcnn_r50_fpn_1x_cityscapes'
load_from = None
# For better, more stable performance initialize from COCO
load_from = 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/mmdetection/models/mask_rcnn_r50_fpn_2x_20181010-41d35c05.pth' # noqa
resume_from = None
workflow = [('train', 1)]
......@@ -5,7 +5,7 @@ For installation instructions, please see [INSTALL.md](INSTALL.md).
## Inference with pretrained models
We provide testing scripts to evaluate a whole dataset (COCO, PASCAL VOC, etc.),
We provide testing scripts to evaluate a whole dataset (COCO, PASCAL VOC, Cityscapes, etc.),
and also some high-level apis for easier integration to other projects.
### Test a dataset
......@@ -26,7 +26,7 @@ python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [--out ${RESULT_FILE}] [-
Optional arguments:
- `RESULT_FILE`: Filename of the output results in pickle format. If not specified, the results will not be saved to a file.
- `EVAL_METRICS`: Items to be evaluated on the results. Allowed values depend on the dataset, e.g., `proposal_fast`, `proposal`, `bbox`, `segm` are available for COCO and `mAP`, `recall` for PASCAL VOC.
- `EVAL_METRICS`: Items to be evaluated on the results. Allowed values depend on the dataset, e.g., `proposal_fast`, `proposal`, `bbox`, `segm` are available for COCO, `mAP`, `recall` for PASCAL VOC. Cityscapes could be evaluated by `cityscapes` as well as all COCO metrics.
- `--show`: If specified, detection results will be plotted on the images and shown in a new window. It is only applicable to single GPU testing and used for debugging and visualization. Please make sure that GUI is available in your environment, otherwise you may encounter the error like `cannot connect to X server`.
If you would like to evaluate the dataset, do not specify `--show` at the same time.
......@@ -69,6 +69,16 @@ python tools/test.py configs/pascal_voc/faster_rcnn_r50_fpn_1x_voc.py \
You will get two json files `mask_rcnn_test-dev_results.bbox.json` and `mask_rcnn_test-dev_results.segm.json`.
5. Test Mask R-CNN on Cityscapes test with 8 GPUs, and generate the txt and png files to be submit to the official evaluation server.
```shell
./tools/dist_test.sh configs/cityscapes/mask_rcnn_r50_fpn_1x_cityscapes.py \
checkpoints/mask_rcnn_r50_fpn_1x_cityscapes_20200227-afe51d5a.pth \
8 --format_only --options "outfile_prefix=./mask_rcnn_cityscapes_test_results"
```
The generated png and txt would be under `./mask_rcnn_cityscapes_test_results` directory.
### Webcam demo
We provide a webcam demo to illustrate the results.
......
......@@ -87,20 +87,24 @@ mmdetection
│ │ ├── test2017
│ ├── cityscapes
│ │ ├── annotations
│ │ ├── train
│ │ ├── val
│ │ ├── leftImg8bit
│ │ │ ├── train
│ │ │ ├── val
│ │ ├── gtFine
│ │ │ ├── train
│ │ │ ├── val
│ ├── VOCdevkit
│ │ ├── VOC2007
│ │ ├── VOC2012
```
The cityscapes annotations have to be converted into the coco format using the [cityscapesScripts](https://github.com/mcordts/cityscapesScripts) toolbox.
We plan to provide an easy to use conversion script. For the moment we recommend following the instructions provided in the
[maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark/tree/master/maskrcnn_benchmark/data) toolbox. When using this script all images have to be moved into the same folder. On linux systems this can e.g. be done for the train images with:
The cityscapes annotations have to be converted into the coco format using `tools/convert_datasets/cityscapes.py`:
```shell
cd data/cityscapes/
mv train/*/* train/
pip install cityscapesscripts
python tools/convert_datasets/cityscapes.py ./data/cityscapes --nproc 8 --out_dir ./data/cityscapes/annotations
```
Current the config files in `cityscapes` use COCO pre-trained weights to initialize.
You could download the pre-trained models in advance if network is unavailable or slow, otherwise it would cause errors at the beginning of training.
### A from-scratch setup script
......
from .class_names import (coco_classes, dataset_aliases, get_classes,
imagenet_det_classes, imagenet_vid_classes,
voc_classes)
from .class_names import (cityscapes_classes, coco_classes, dataset_aliases,
get_classes, imagenet_det_classes,
imagenet_vid_classes, voc_classes)
from .eval_hooks import DistEvalHook
from .mean_ap import average_precision, eval_map, print_map_summary
from .recall import (eval_recalls, plot_iou_recall, plot_num_recall,
......@@ -8,7 +8,8 @@ from .recall import (eval_recalls, plot_iou_recall, plot_num_recall,
__all__ = [
'voc_classes', 'imagenet_det_classes', 'imagenet_vid_classes',
'coco_classes', 'dataset_aliases', 'get_classes', 'DistEvalHook',
'average_precision', 'eval_map', 'print_map_summary', 'eval_recalls',
'print_recall_summary', 'plot_num_recall', 'plot_iou_recall'
'coco_classes', 'cityscapes_classes', 'dataset_aliases', 'get_classes',
'DistEvalHook', 'average_precision', 'eval_map', 'print_map_summary',
'eval_recalls', 'print_recall_summary', 'plot_num_recall',
'plot_iou_recall'
]
# Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/data/datasets/cityscapes.py # noqa
# and https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalInstanceLevelSemanticLabeling.py # noqa
import glob
import os
import os.path as osp
import tempfile
import cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling as CSEval
import cityscapesscripts.helpers.labels as CSLabels
import mmcv
import numpy as np
import pycocotools.mask as maskUtils
from mmdet.utils import print_log
from .coco import CocoDataset
from .registry import DATASETS
......@@ -7,3 +22,252 @@ class CityscapesDataset(CocoDataset):
CLASSES = ('person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
'bicycle')
def _filter_imgs(self, min_size=32):
"""Filter images too small or without ground truths."""
valid_inds = []
ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
for i, img_info in enumerate(self.img_infos):
img_id = img_info['id']
ann_ids = self.coco.getAnnIds(imgIds=[img_id])
ann_info = self.coco.loadAnns(ann_ids)
all_iscrowd = all([_['iscrowd'] for _ in ann_info])
if self.filter_empty_gt and (self.img_ids[i] not in ids_with_ann
or all_iscrowd):
continue
if min(img_info['width'], img_info['height']) >= min_size:
valid_inds.append(i)
return valid_inds
def _parse_ann_info(self, img_info, ann_info):
"""Parse bbox and mask annotation.
Args:
img_info (dict): Image info of an image.
ann_info (list[dict]): Annotation info of an image.
Returns:
dict: A dict containing the following keys: bboxes, bboxes_ignore,
labels, masks, seg_map.
"masks" are already decoded into binary masks.
"""
gt_bboxes = []
gt_labels = []
gt_bboxes_ignore = []
gt_masks_ann = []
for i, ann in enumerate(ann_info):
if ann.get('ignore', False):
continue
x1, y1, w, h = ann['bbox']
if ann['area'] <= 0 or w < 1 or h < 1:
continue
bbox = [x1, y1, x1 + w - 1, y1 + h - 1]
if ann.get('iscrowd', False):
gt_bboxes_ignore.append(bbox)
else:
gt_bboxes.append(bbox)
gt_labels.append(self.cat2label[ann['category_id']])
gt_masks_ann.append(ann['segmentation'])
if gt_bboxes:
gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
gt_labels = np.array(gt_labels, dtype=np.int64)
else:
gt_bboxes = np.zeros((0, 4), dtype=np.float32)
gt_labels = np.array([], dtype=np.int64)
if gt_bboxes_ignore:
gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
else:
gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
ann = dict(
bboxes=gt_bboxes,
labels=gt_labels,
bboxes_ignore=gt_bboxes_ignore,
masks=gt_masks_ann,
seg_map=img_info['segm_file'])
return ann
def results2txt(self, results, outfile_prefix):
"""Dump the detection results to a txt file.
Args:
results (list[list | tuple | ndarray]): Testing results of the
dataset.
outfile_prefix (str): The filename prefix of the json files.
If the prefix is "somepath/xxx",
the txt files will be named "somepath/xxx.txt".
Returns:
list[str: str]: result txt files which contains corresponding
instance segmentation images.
"""
result_files = []
os.makedirs(outfile_prefix, exist_ok=True)
prog_bar = mmcv.ProgressBar(len(self))
for idx in range(len(self)):
result = results[idx]
filename = self.img_infos[idx]['filename']
basename = osp.splitext(osp.basename(filename))[0]
pred_txt = osp.join(outfile_prefix, basename + '_pred.txt')
bbox_result, segm_result = result
bboxes = np.vstack(bbox_result)
segms = mmcv.concat_list(segm_result)
labels = [
np.full(bbox.shape[0], i, dtype=np.int32)
for i, bbox in enumerate(bbox_result)
]
labels = np.concatenate(labels)
assert len(bboxes) == len(segms) == len(labels)
num_instances = len(bboxes)
prog_bar.update()
with open(pred_txt, 'w') as fout:
for i in range(num_instances):
pred_class = labels[i]
classes = self.CLASSES[pred_class]
class_id = CSLabels.name2label[classes].id
score = bboxes[i, -1]
mask = maskUtils.decode(segms[i]).astype(np.uint8)
png_filename = osp.join(
outfile_prefix,
basename + '_{}_{}.png'.format(i, classes))
mmcv.imwrite(mask, png_filename)
fout.write('{} {} {}\n'.format(
osp.basename(png_filename), class_id, score))
result_files.append(pred_txt)
return result_files
def format_results(self, results, txtfile_prefix=None):
"""Format the results to txt (standard format for Cityscapes evaluation).
Args:
results (list): Testing results of the dataset.
txtfile_prefix (str | None): The prefix of txt files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None.
Returns:
tuple: (result_files, tmp_dir), result_files is a dict containing
the json filepaths, tmp_dir is the temporal directory created
for saving txt/png files when txtfile_prefix is not specified.
"""
assert isinstance(results, list), 'results must be a list'
assert len(results) == len(self), (
'The length of results is not equal to the dataset len: {} != {}'.
format(len(results), len(self)))
assert isinstance(results, list), 'results must be a list'
assert len(results) == len(self), (
'The length of results is not equal to the dataset len: {} != {}'.
format(len(results), len(self)))
if txtfile_prefix is None:
tmp_dir = tempfile.TemporaryDirectory()
txtfile_prefix = osp.join(tmp_dir.name, 'results')
else:
tmp_dir = None
result_files = self.results2txt(results, txtfile_prefix)
return result_files, tmp_dir
def evaluate(self,
results,
metric='bbox',
logger=None,
outfile_prefix=None,
classwise=False,
proposal_nums=(100, 300, 1000),
iou_thrs=np.arange(0.5, 0.96, 0.05)):
"""Evaluation in Cityscapes protocol.
Args:
results (list): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
outfile_prefix (str | None):
classwise (bool): Whether to evaluating the AP for each class.
proposal_nums (Sequence[int]): Proposal number used for evaluating
recalls, such as recall@100, recall@1000.
Default: (100, 300, 1000).
iou_thrs (Sequence[float]): IoU threshold used for evaluating
recalls. If set to a list, the average recall of all IoUs will
also be computed. Default: 0.5.
Returns:
dict[str: float]
"""
eval_results = dict()
metrics = metric.copy() if isinstance(metric, list) else [metric]
if 'cityscapes' in metrics:
eval_results.update(
self._evaluate_cityscapes(results, outfile_prefix, logger))
metrics.remove('cityscapes')
# left metrics are all coco metric
if len(metric) > 0:
# create CocoDataset with CityscapesDataset annotation
self_coco = CocoDataset(self.ann_file, self.pipeline.transforms,
self.data_root, self.img_prefix,
self.seg_prefix, self.proposal_file,
self.test_mode, self.filter_empty_gt)
eval_results.update(
self_coco.evaluate(results, metric, logger, outfile_prefix,
classwise, proposal_nums, iou_thrs))
return eval_results
def _evaluate_cityscapes(self, results, txtfile_prefix, logger):
msg = 'Evaluating in Cityscapes style'
if logger is None:
msg = '\n' + msg
print_log(msg, logger=logger)
result_files, tmp_dir = self.format_results(results, txtfile_prefix)
if tmp_dir is None:
result_dir = osp.join(txtfile_prefix, 'results')
else:
result_dir = osp.join(tmp_dir.name, 'results')
eval_results = {}
print_log(
'Evaluating results under {} ...'.format(result_dir),
logger=logger)
# set global states in cityscapes evaluation API
CSEval.args.cityscapesPath = os.path.join(self.img_prefix, '../..')
CSEval.args.predictionPath = os.path.abspath(result_dir)
CSEval.args.predictionWalk = None
CSEval.args.JSONOutput = False
CSEval.args.colorized = False
CSEval.args.gtInstancesFile = os.path.join(result_dir,
'gtInstances.json')
CSEval.args.groundTruthSearch = os.path.join(
self.img_prefix.replace('leftImg8bit', 'gtFine'),
'*/*_gtFine_instanceIds.png')
groundTruthImgList = glob.glob(CSEval.args.groundTruthSearch)
assert len(groundTruthImgList), \
'Cannot find ground truth images in {}.'.format(
CSEval.args.groundTruthSearch)
predictionImgList = []
for gt in groundTruthImgList:
predictionImgList.append(CSEval.getPrediction(gt, CSEval.args))
CSEval_results = CSEval.evaluateImgLists(predictionImgList,
groundTruthImgList,
CSEval.args)['averages']
eval_results['mAP'] = CSEval_results['allAp']
eval_results['AP@50'] = CSEval_results['allAp50%']
if tmp_dir is not None:
tmp_dir.cleanup()
return eval_results
albumentations>=0.3.2
cityscapesscripts
imagecorruptions
import argparse
import glob
import os.path as osp
from shutil import copyfile, move
import cityscapesscripts.helpers.labels as CSLabels
import mmcv
import numpy as np
import pycocotools.mask as maskUtils
def collect_files(img_dir, gt_dir):
suffix = 'leftImg8bit.png'
files = []
for img_file in glob.glob(osp.join(img_dir, '**/*.png')):
assert img_file.endswith(suffix), img_file
inst_file = gt_dir + img_file[
len(img_dir):-len(suffix)] + 'gtFine_instanceIds.png'
# Note that labelIds are not converted to trainId for seg map
segm_file = gt_dir + img_file[
len(img_dir):-len(suffix)] + 'gtFine_labelIds.png'
files.append((img_file, inst_file, segm_file))
assert len(files), 'No images found in {}'.format(img_dir)
print('Loaded {} images from {}'.format(len(files), img_dir))
return files
def collect_annotations(files, nproc=1):
print('Loading annotation images')
if nproc > 1:
images = mmcv.track_parallel_progress(
load_img_info, files, nproc=nproc)
else:
images = mmcv.track_progress(load_img_info, files)
return images
def load_img_info(files):
img_file, inst_file, segm_file = files
inst_img = mmcv.imread(inst_file, 'unchanged')
# ids < 24 are stuff labels (filtering them first is about 5% faster)
unique_inst_ids = np.unique(inst_img[inst_img >= 24])
anno_info = []
for inst_id in unique_inst_ids:
# For non-crowd annotations, inst_id // 1000 is the label_id
# Crowd annotations have <1000 instance ids
label_id = inst_id // 1000 if inst_id >= 1000 else inst_id
label = CSLabels.id2label[label_id]
if not label.hasInstances or label.ignoreInEval:
continue
category_id = label.id
iscrowd = int(inst_id < 1000)
mask = np.asarray(inst_img == inst_id, dtype=np.uint8, order='F')
mask_rle = maskUtils.encode(mask[:, :, None])[0]
area = maskUtils.area(mask_rle)
# convert to COCO style XYWH format
bbox = maskUtils.toBbox(mask_rle)
# for json encoding
mask_rle['counts'] = mask_rle['counts'].decode()
anno = dict(
iscrowd=iscrowd,
category_id=category_id,
bbox=bbox.tolist(),
area=area.tolist(),
segmentation=mask_rle)
anno_info.append(anno)
img_info = dict(
# remove img_prefix for filename
file_name=osp.basename(img_file),
height=inst_img.shape[0],
width=inst_img.shape[1],
anno_info=anno_info,
segm_file=osp.basename(segm_file))
return img_info
def cvt_annotations(image_infos, out_json_name):
out_json = dict()
img_id = 0
ann_id = 0
out_json['images'] = []
out_json['categories'] = []
out_json['annotations'] = []
for image_info in image_infos:
image_info['id'] = img_id
anno_infos = image_info.pop('anno_info')
out_json['images'].append(image_info)
for anno_info in anno_infos:
anno_info['image_id'] = img_id
anno_info['id'] = ann_id
out_json['annotations'].append(anno_info)
ann_id += 1
img_id += 1
for label in CSLabels.labels:
if label.hasInstances and not label.ignoreInEval:
cat = dict(id=label.id, name=label.name)
out_json['categories'].append(cat)
if len(out_json['annotations']) == 0:
out_json.pop('annotations')
mmcv.dump(out_json, out_json_name)
return out_json
def organize_files(files, target_dir, copy=True):
for img_file, _, segm_file in files:
if copy:
copyfile(img_file, osp.join(target_dir, osp.basename(img_file)))
copyfile(segm_file, osp.join(target_dir, osp.basename(segm_file)))
else:
move(img_file, osp.join(target_dir, osp.basename(img_file)))
move(segm_file, osp.join(target_dir, osp.basename(segm_file)))
def parse_args():
parser = argparse.ArgumentParser(
description='Convert Cityscapes annotations to COCO format')
parser.add_argument('cityscapes_path', help='cityscapes data path')
parser.add_argument('--img_dir', default='leftImg8bit', type=str)
parser.add_argument('--gt_dir', default='gtFine', type=str)
parser.add_argument('-o', '--out_dir', help='output path')
parser.add_argument(
'--nproc', default=1, type=int, help='number of process')
parser.add_argument(
'--clean',
action='store_true',
help='whether delete img_dir and gt_dir')
args = parser.parse_args()
return args
def main():
args = parse_args()
cityscapes_path = args.cityscapes_path
out_dir = args.out_dir if args.out_dir else cityscapes_path
mmcv.mkdir_or_exist(out_dir)
img_dir = osp.join(cityscapes_path, args.img_dir)
gt_dir = osp.join(cityscapes_path, args.gt_dir)
set_name = dict(
train='instancesonly_filtered_gtFine_train.json',
val='instancesonly_filtered_gtFine_val.json',
test='instancesonly_filtered_gtFine_test.json')
for split, json_name in set_name.items():
print('Converting {} into {}'.format(split, json_name))
with mmcv.Timer(
print_tmpl='It tooks {}s to convert Cityscapes annotation'):
files = collect_files(
osp.join(img_dir, split), osp.join(gt_dir, split))
image_infos = collect_annotations(files, nproc=args.nproc)
cvt_annotations(image_infos, osp.join(out_dir, json_name))
organize_files(
files,
target_dir=osp.join(img_dir, split),
copy=not args.clean)
if __name__ == '__main__':
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment