Unverified Commit da09e7e5 authored by Jerry Jiarui XU, committed by GitHub

Support CPU mode for inference (#2385)


* add CPU only mode which can be activated during install

* fixed flake8 errors for too long lines, still have to deal with "import not at top of file"

* reverted changes in MinIoURandomCrop that are not relevant to the CPU_ONLY pull request

* moving the CPU_ONLY checks into deeper parts of the code

* completing previous commit

* using isort for imports sorting

* yapf fix

* followed @xvjiarui's suggestions for the pull request

* use mmdet.CPU_ONLY and replace the "--cpu" flag in setup.py with an automatic check for CUDA

* make setup code cleaner

* back to original implementation of MinIoURandomCrop

* build all extensions with CUDA, if available

* fixed DC

* update doc

* fixed masked_conv2d_ext

* set warning once, update comment

Co-authored-by: Yossi Biton <yossi.biton@alibaba-inc.com>
Co-authored-by: Yossi Biton <yossibit10@gmail.com>
parent 8549d10c
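The commit message above mentions replacing the manual `--cpu` install flag with an automatic CUDA check in `setup.py` and building all extensions with CUDA when available. The `setup.py` diff itself is not part of this excerpt, so the snippet below is only a rough sketch of that kind of auto-detection; the helper name `make_extension` and the `FORCE_CUDA` override are illustrative assumptions rather than the repository's actual code.

```python
# Hypothetical sketch of choosing between a CUDA and a CPU-only extension
# build at install time; not the actual mmdetection setup.py.
import os

import torch
from torch.utils.cpp_extension import CppExtension, CUDAExtension


def make_extension(name, sources, sources_cuda=()):
    """Return a CUDAExtension when CUDA is usable, else a CppExtension."""
    if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
        # Compile the .cu kernels as well and expose them behind WITH_CUDA.
        return CUDAExtension(
            name=name,
            sources=list(sources) + list(sources_cuda),
            define_macros=[('WITH_CUDA', None)])
    # CPU-only build: skip CUDA sources and leave WITH_CUDA undefined.
    return CppExtension(name=name, sources=list(sources))
```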
@@ -10,7 +10,8 @@ def parse_args():
     parser = argparse.ArgumentParser(description='MMDetection webcam demo')
     parser.add_argument('config', help='test config file path')
     parser.add_argument('checkpoint', help='checkpoint file')
-    parser.add_argument('--device', type=int, default=0, help='CUDA device id')
+    parser.add_argument(
+        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
     parser.add_argument(
         '--camera-id', type=int, default=0, help='camera device id')
     parser.add_argument(
@@ -22,8 +23,9 @@ def parse_args():
 def main():
     args = parse_args()
-    model = init_detector(
-        args.config, args.checkpoint, device=torch.device('cuda', args.device))
+    device = torch.device(args.device)
+    model = init_detector(args.config, args.checkpoint, device=device)
     camera = cv2.VideoCapture(args.camera_id)
......
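The hunks above change the demo's `--device` flag from an integer GPU id to a device string. A minimal sketch of why that is enough for CPU support: the parsed string feeds straight into `torch.device`, which accepts both `'cpu'` and `'cuda:<id>'` (the argparse lines mirror the diff; everything else is illustrative).

```python
# Minimal sketch of the new string-valued --device flag; the add_argument
# call mirrors the diff above, everything else is illustrative.
import argparse

import torch

parser = argparse.ArgumentParser(description='device flag sketch')
parser.add_argument(
    '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
args = parser.parse_args(['--device', 'cpu'])

# torch.device accepts either 'cpu' or 'cuda:<id>', which is why the flag
# switched from an integer GPU id to a free-form device string.
device = torch.device(args.device)
print(device)  # prints: cpu
```

Running the demo on CPU would then look like `python demo/webcam_demo.py <config> <checkpoint> --device cpu`, with the config and checkpoint paths filled in.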
@@ -60,6 +60,19 @@ you can install it before installing MMCV.
 4. Some dependencies are optional. Simply running `pip install -v -e .` will only install the minimum runtime requirements. To use optional dependencies like `albumentations` and `imagecorruptions` either install them manually with `pip install -r requirements/optional.txt` or specify desired extras when calling `pip` (e.g. `pip install -v -e .[optional]`). Valid keys for the extras field are: `all`, `tests`, `build`, and `optional`.
+## Install with CPU only
+The code can be built in a CPU-only environment (where CUDA is not available).
+In CPU mode you can, for example, run `demo/webcam_demo.py`.
+However, the following functionality is not available in this mode:
+* Deformable Convolution
+* Deformable RoI Pooling
+* CARAFE: Content-Aware ReAssembly of FEatures
+* nms_cuda
+* sigmoid_focal_loss_cuda
+So if you try to run inference with a model that contains deformable convolution, you will get an error.
+Note: we set `use_torchvision=True` on-the-fly for `RoIPool` and `RoIAlign` in CPU mode.
 ### Another option: Docker Image
 We provide a [Dockerfile](https://github.com/open-mmlab/mmdetection/blob/master/docker/Dockerfile) to build an image.
......
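To make the CPU-only workflow described above concrete, here is a small sketch of single-image inference on CPU through the high-level API. The config, checkpoint, and image paths are placeholders, and a model that uses the CUDA-only ops listed above (e.g. deformable convolution) will still fail:

```python
# Sketch of CPU-only inference via the high-level API; the paths below are
# placeholders and must point at a real config, checkpoint and image.
import torch

from mmdet.apis import inference_detector, init_detector

config_file = 'configs/faster_rcnn_r50_fpn_1x.py'           # placeholder
checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x.pth'  # placeholder

# device='cpu' is the new code path; 'cuda:0' keeps the old behaviour.
model = init_detector(config_file, checkpoint_file, device=torch.device('cpu'))
result = inference_detector(model, 'demo/demo.jpg')
```

Under the hood this is exactly the branch added to `inference_detector` below: on CPU, `scatter` is skipped and `RoIPool`/`RoIAlign` are switched to their torchvision implementations.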
@@ -11,6 +11,7 @@ from mmcv.runner import load_checkpoint
 from mmdet.core import get_classes
 from mmdet.datasets.pipelines import Compose
 from mmdet.models import build_detector
+from mmdet.ops import RoIAlign, RoIPool
 
 
 def init_detector(config, checkpoint=None, device='cuda:0'):
@@ -37,6 +38,7 @@ def init_detector(config, checkpoint=None, device='cuda:0'):
         if 'CLASSES' in checkpoint['meta']:
             model.CLASSES = checkpoint['meta']['CLASSES']
         else:
+            warnings.simplefilter('once')
             warnings.warn('Class names are not saved in the checkpoint\'s '
                           'meta data, use COCO classes by default.')
             model.CLASSES = get_classes('coco')
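The added `warnings.simplefilter('once')` (the "set warning once" item in the commit message) makes repeated identical warnings print only a single time. A minimal illustration:

```python
import warnings

warnings.simplefilter('once')
for _ in range(3):
    # With the 'once' filter only the first occurrence of this exact warning
    # is printed; the two repeats are suppressed.
    warnings.warn('Class names are not saved in the checkpoint meta data, '
                  'use COCO classes by default.')
```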
@@ -80,7 +82,20 @@ def inference_detector(model, img):
     # prepare data
     data = dict(img=img)
     data = test_pipeline(data)
-    data = scatter(collate([data], samples_per_gpu=1), [device])[0]
+    data = collate([data], samples_per_gpu=1)
+    if next(model.parameters()).is_cuda:
+        # scatter to specified GPU
+        data = scatter(data, [device])[0]
+    else:
+        # Use torchvision ops for CPU mode instead
+        for m in model.modules():
+            if isinstance(m, (RoIPool, RoIAlign)):
+                # set use_torchvision on-the-fly
+                m.use_torchvision = True
+                warnings.warn('We set use_torchvision=True in CPU mode.')
+        # just get the actual data from DataContainer
+        data['img_metas'] = data['img_metas'][0].data
     # forward the model
     with torch.no_grad():
         result = model(return_loss=False, rescale=True, **data)
......
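The `use_torchvision = True` switch flipped in the hunk above makes `RoIPool` and `RoIAlign` delegate to `torchvision.ops` instead of the compiled CUDA kernels, which is what keeps CPU inference working. A simplified, self-contained illustration of that fallback (not the actual `mmdet.ops.RoIAlign` forward):

```python
# Simplified illustration of the torchvision fallback selected by
# use_torchvision=True; not the actual mmdet.ops.RoIAlign implementation.
import torch
from torchvision.ops import roi_align

feats = torch.rand(1, 256, 50, 50)                # dummy feature map
rois = torch.tensor([[0., 10., 10., 40., 40.]])   # (batch_idx, x1, y1, x2, y2)

# Pure PyTorch/torchvision op, so this runs on CPU without any CUDA extension.
pooled = roi_align(
    feats, rois, output_size=(7, 7), spatial_scale=1.0, sampling_ratio=2)
print(pooled.shape)  # torch.Size([1, 256, 7, 7])
```

The other CPU-only detail in the hunk is the manual unwrapping of `img_metas`: `scatter` normally extracts the payload from the `DataContainer` while moving data to the GPU, so on CPU the code reaches into `data['img_metas'][0].data` directly.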
@@ -47,8 +47,8 @@ int masked_col2im_forward(const at::Tensor col,
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("masked_im2col_forward", &masked_im2col_forward_cuda,
-        "masked_im2col forward (CUDA)");
-  m.def("masked_col2im_forward", &masked_col2im_forward_cuda,
-        "masked_col2im forward (CUDA)");
+  m.def("masked_im2col_forward", &masked_im2col_forward,
+        "masked_im2col forward");
+  m.def("masked_col2im_forward", &masked_col2im_forward,
+        "masked_col2im forward");
 }
 // modified from
 // https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/vision.cpp
-#include <cuda_runtime_api.h>
 #include <torch/extension.h>
+#ifdef WITH_CUDA
+#include <cuda_runtime_api.h>
 int get_cudart_version() { return CUDART_VERSION; }
+#endif
......