Skip to content
Snippets Groups Projects
Unverified Commit 17c0f8e3 authored by Wang Xinjiang's avatar Wang Xinjiang Committed by GitHub
Browse files

Add custom-defined hooks in train api (#3395)

* Add hooks

* change hooks to custom_hooks
parent 352cf7ff
No related branches found
No related tags found
No related merge requests found
......@@ -16,7 +16,7 @@
```
pip install "git+https://github.com/open-mmlab/cocoapi.git#subdirectory=lvis"
```
or
or
```
pip install -r requirements/optional.txt
```
......
......@@ -1696,4 +1696,4 @@
"outputs": []
}
]
}
\ No newline at end of file
}
......@@ -3,8 +3,9 @@ import random
import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,
build_optimizer)
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
OptimizerHook, build_optimizer)
from mmcv.utils import build_from_cfg
from mmdet.core import DistEvalHook, EvalHook, Fp16OptimizerHook
from mmdet.datasets import build_dataloader, build_dataset
......@@ -121,6 +122,14 @@ def train_detector(model,
eval_hook = DistEvalHook if distributed else EvalHook
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
# user-defined hooks
if cfg.get('custom_hooks', None):
for hook_cfg in cfg.hooks:
hook_cfg = hook_cfg.copy()
priority = hook_cfg.pop('priority', 'NORMAL')
hook = build_from_cfg(hook_cfg, HOOKS)
runner.register_hook(hook, priority=priority)
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
......
......@@ -147,9 +147,10 @@ class CornerHead(nn.Module):
feat_channels, out_channels, 1, norm_cfg=None, act_cfg=None))
def _init_corner_kpt_layers(self):
"""Initialize corner keypoint layers. Including corner heatmap branch
and corner offset branch. Each branch has two parts: prefix `tl_` for
top-left and `br_` for bottom-right.
"""Initialize corner keypoint layers.
Including corner heatmap branch and corner offset branch. Each branch
has two parts: prefix `tl_` for top-left and `br_` for bottom-right.
"""
self.tl_pool, self.br_pool = nn.ModuleList(), nn.ModuleList()
self.tl_heat, self.br_heat = nn.ModuleList(), nn.ModuleList()
......@@ -184,9 +185,10 @@ class CornerHead(nn.Module):
in_channels=self.in_channels))
def _init_corner_emb_layers(self):
"""Initialize corner embedding layers. Only include corner embedding
branch with two parts: prefix `tl_` for top-left and `br_` for
bottom-right.
"""Initialize corner embedding layers.
Only include corner embedding branch with two parts: prefix `tl_` for
top-left and `br_` for bottom-right.
"""
self.tl_emb, self.br_emb = nn.ModuleList(), nn.ModuleList()
......@@ -202,6 +204,7 @@ class CornerHead(nn.Module):
def _init_layers(self):
"""Initialize layers for CornerHead.
Including two parts: corner keypoint layers and corner embedding layers
"""
self._init_corner_kpt_layers()
......
......@@ -2,8 +2,8 @@ matplotlib
numpy
# need older pillow until torchvision is fixed
Pillow<=6.2.2
pycocotools@git+https://github.com/open-mmlab/cocoapi.git#subdirectory=pycocotools
six
terminaltables
torch>=1.3
torchvision
pycocotools@git+https://github.com/open-mmlab/cocoapi.git#subdirectory=pycocotools
......@@ -617,9 +617,7 @@ def _dummy_bbox_sampling(proposal_list, gt_bboxes, gt_labels):
def test_corner_head_loss():
"""
Tests corner head loss when truth is empty and non-empty
"""
"""Tests corner head loss when truth is empty and non-empty."""
s = 256
img_metas = [{
'img_shape': (s, s, 3),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment