Commit 56d40938 authored by suilin0432's avatar suilin0432
Browse files

.

parent 12cb4156
...@@ -27,7 +27,7 @@ from .cityscapes_panoptic import register_all_cityscapes_panoptic ...@@ -27,7 +27,7 @@ from .cityscapes_panoptic import register_all_cityscapes_panoptic
from .coco import load_sem_seg, register_coco_instances, register_coco_instances_wsl from .coco import load_sem_seg, register_coco_instances, register_coco_instances_wsl
from .coco_panoptic import register_coco_panoptic, register_coco_panoptic_separated from .coco_panoptic import register_coco_panoptic, register_coco_panoptic_separated
from .lvis import get_lvis_instances_meta, register_lvis_instances from .lvis import get_lvis_instances_meta, register_lvis_instances
from .pascal_voc import register_pascal_voc, register_pascal_voc_wsl, register_pascal_voc_wsl_top1, register_pascal_voc_wsl_thres, register_pascal_voc_wsl_contain, register_pascal_voc_wsl_contain_total, register_pascal_voc_wsl_mist, register_pascal_voc_wsl_mist_contain, register_pascal_voc_wsl_contain_all, register_pascal_voc_wsl_contain_w2f, register_pascal_voc_wsl_oicr_contain, register_pascal_voc_wsl_oicr_contain_all, register_pascal_voc_wsl_w2f_overlap from .pascal_voc import register_pascal_voc, register_pascal_voc_wsl, register_pascal_voc_wsl_top1, register_pascal_voc_wsl_thres, register_pascal_voc_wsl_contain, register_pascal_voc_wsl_contain_total, register_pascal_voc_wsl_mist, register_pascal_voc_wsl_mist_contain, register_pascal_voc_wsl_contain_all, register_pascal_voc_wsl_contain_w2f, register_pascal_voc_wsl_oicr_contain, register_pascal_voc_wsl_oicr_contain_all, register_pascal_voc_wsl_w2f_overlap, register_pascal_voc_wsl_contain_all_adaptive
# ==== Predefined datasets and splits for COCO ========== # ==== Predefined datasets and splits for COCO ==========
...@@ -405,6 +405,16 @@ def register_all_pascal_voc_wsl_contain_all(root): ...@@ -405,6 +405,16 @@ def register_all_pascal_voc_wsl_contain_all(root):
register_pascal_voc_wsl_contain_all(name, os.path.join(root, dirname), split, year) register_pascal_voc_wsl_contain_all(name, os.path.join(root, dirname), split, year)
MetadataCatalog.get(name).evaluator_type = "pascal_voc" MetadataCatalog.get(name).evaluator_type = "pascal_voc"
def register_all_pascal_voc_wsl_contain_all_adaptive(root):
SPLITS = [
("voc_2007_train_wsl_contain_all_adaptive", "VOC2007", "train"),
("voc_2007_val_wsl_contain_all_adaptive", "VOC2007", "val")
]
for name, dirname, split in SPLITS:
year = 2007 if "2007" in name else 2012
register_pascal_voc_wsl_contain_all_adaptive(name, os.path.join(root, dirname), split, year)
MetadataCatalog.get(name).evaluator_type = "pascal_voc"
def register_all_pascal_voc_w2f(root): def register_all_pascal_voc_w2f(root):
SPLITS = [ SPLITS = [
("voc_2007_train_wsl_w2f", "VOC2007", "train"), ("voc_2007_train_wsl_w2f", "VOC2007", "train"),
...@@ -469,4 +479,5 @@ if __name__.endswith(".builtin"): ...@@ -469,4 +479,5 @@ if __name__.endswith(".builtin"):
register_all_pascal_voc_w2f_overlap(_root) register_all_pascal_voc_w2f_overlap(_root)
register_all_pascal_voc_wsl_oicr_contain(_root) register_all_pascal_voc_wsl_oicr_contain(_root)
register_all_pascal_voc_wsl_oicr_contain_all(_root) register_all_pascal_voc_wsl_oicr_contain_all(_root)
register_all_pascal_voc_wsl_contain_all_adaptive(_root)
register_all_ade20k(_root) register_all_ade20k(_root)
...@@ -12,7 +12,7 @@ from detectron2.data import DatasetCatalog, MetadataCatalog ...@@ -12,7 +12,7 @@ from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import BoxMode from detectron2.structures import BoxMode
from detectron2.utils.file_io import PathManager from detectron2.utils.file_io import PathManager
__all__ = ["load_voc_instances", "register_pascal_voc", "register_pascal_voc_wsl", "register_pascal_voc_wsl_top1", "register_pascal_voc_wsl_thres", "register_pascal_voc_wsl_mist", "register_pascal_voc_wsl_mist_contain", "register_pascal_voc_wsl_contain_all", "register_pascal_voc_wsl_contain_w2f", "register_pascal_voc_wsl_oicr_contain", "register_pascal_voc_wsl_oicr_contain_all", "register_pascal_voc_wsl_w2f_overlap"] __all__ = ["load_voc_instances", "register_pascal_voc", "register_pascal_voc_wsl", "register_pascal_voc_wsl_top1", "register_pascal_voc_wsl_thres", "register_pascal_voc_wsl_mist", "register_pascal_voc_wsl_mist_contain", "register_pascal_voc_wsl_contain_all", "register_pascal_voc_wsl_contain_w2f", "register_pascal_voc_wsl_oicr_contain", "register_pascal_voc_wsl_oicr_contain_all", "register_pascal_voc_wsl_w2f_overlap", "register_pascal_voc_wsl_contain_all_adaptive"]
# fmt: off # fmt: off
...@@ -547,6 +547,83 @@ def load_voc_instances_wsl_contain_all(dirname: str, split: str, class_names: Un ...@@ -547,6 +547,83 @@ def load_voc_instances_wsl_contain_all(dirname: str, split: str, class_names: Un
dicts.append(r) dicts.append(r)
return dicts return dicts
def load_voc_instances_wsl_contain_all_adaptive(dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]):
# 获取 数据集对应划分(train, val, test) 图片 ids
with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f:
fileids = np.loadtxt(f, dtype=np.str)
# 针对 single-input 的文件
# print("load from {}/single_voc07_wsl_{}_contain.json".format(dirname, split))
# annotation_wsl = json.load(open(
# "{}/single_voc07_wsl_{}_contain.json".format(dirname, split), "r"
# ))
# 获取 annotations, wsl 预测之后的结果会保存为 json 的格式
if "07" in dirname:
annotation_wsl = json.load(open(
"{}/voc07_wsl_{}_contain_all_adaptive.json".format(dirname, split), "r"
))
elif "12" in dirname:
annotation_wsl = json.load(open(
"{}/casd_voc12_wsl_{}_contain_all_adaptive.json".format(dirname, split), "r"
))
else:
assert False, "Wrong dirname: {}".format(dirname)
multi_class_labels = None
if "multi_label" in annotation_wsl:
multi_class_labels = annotation_wsl.pop("multi_label")
annotation_dirname = PathManager.get_local_path(os.path.join(dirname, "Annotations/"))
dicts = []
for fileid in fileids:
anno = annotation_wsl[str(int(fileid))]
jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ".jpg")
anno_file = os.path.join(annotation_dirname, fileid + ".xml")
if not os.path.isfile(anno_file):
with Image.open(jpeg_file) as img:
width, height = img.size
r = {"file_name": jpeg_file, "image_id": fileid, "height": height, "width": width}
instances = []
for obj in anno:
bbox = obj["bbox"]
bbox = [int(i) for i in bbox] # 因为 predict 出来的 bbox 是float, 要转化为 int list
category_id = obj["category_id"] - 1 # 因为保存统计时将 index + 1 了从而方便 TIDE 统计了, 因此这里需要 - 1
instances.append(
{
"category_id": category_id, "bbox": bbox, "bbox_mode": BoxMode.XYXY_ABS
}
)
r["annotations"] = instances
if multi_class_labels is not None:
r["multi_label"] = multi_class_labels[str(int(fileid))]
dicts.append(r)
continue
with PathManager.open(anno_file) as f:
tree = ET.parse(f)
r = {
"file_name": jpeg_file,
"image_id": fileid,
"height": int(tree.findall("./size/height")[0].text),
"width": int(tree.findall("./size/width")[0].text),
}
instances = []
# 这里是从 annotation_wsl 中进行 gt 信息的提取, 而不是 从 anno file 中提取真正的 gt 信息出来
for obj in anno:
bbox = obj["bbox"]
bbox = [int(i) for i in bbox]
category_id = obj["category_id"] - 1
instances.append(
{
"category_id": category_id, "bbox": bbox, "bbox_mode": BoxMode.XYXY_ABS
}
)
r["annotations"] = instances
if multi_class_labels is not None:
r["multi_label"] = multi_class_labels[str(int(fileid))]
dicts.append(r)
return dicts
def load_voc_instances_wsl_w2f(dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]): def load_voc_instances_wsl_w2f(dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]):
# 获取 数据集对应划分(train, val, test) 图片 ids # 获取 数据集对应划分(train, val, test) 图片 ids
with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f: with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f:
...@@ -1027,6 +1104,12 @@ def register_pascal_voc_wsl_contain_all(name, dirname, split, year, class_names= ...@@ -1027,6 +1104,12 @@ def register_pascal_voc_wsl_contain_all(name, dirname, split, year, class_names=
thing_classes=list(class_names), dirname=dirname, year=year, split=split thing_classes=list(class_names), dirname=dirname, year=year, split=split
) )
def register_pascal_voc_wsl_contain_all_adaptive(name, dirname, split, year, class_names=CLASS_NAMES):
DatasetCatalog.register(name, lambda: load_voc_instances_wsl_contain_all_adaptive(dirname, split, class_names))
MetadataCatalog.get(name).set(
thing_classes=list(class_names), dirname=dirname, year=year, split=split
)
def register_pascal_voc_wsl_contain_w2f(name, dirname, split, year, class_names=CLASS_NAMES): def register_pascal_voc_wsl_contain_w2f(name, dirname, split, year, class_names=CLASS_NAMES):
DatasetCatalog.register(name, lambda: load_voc_instances_wsl_w2f(dirname, split, class_names)) DatasetCatalog.register(name, lambda: load_voc_instances_wsl_w2f(dirname, split, class_names))
MetadataCatalog.get(name).set( MetadataCatalog.get(name).set(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment