From b5d62ef9ec8470d35025133d80e72b8f6f186db3 Mon Sep 17 00:00:00 2001
From: Jon Crall <erotemic@gmail.com>
Date: Tue, 21 Jan 2020 22:17:36 -0500
Subject: [PATCH] Reorganize requirements, make albumentations optional (#1969)

* reorganize requirements, make albumentations optional

* fix flake8 error

* force older version of Pillow until torchvision is fixed

* make imagecorruptions optional and update INSTALL.dm

* update INSTALL.md

* Add note about pillow version

* Add build requirements to install instructions
---
 .travis.yml                                   |  1 -
 docs/INSTALL.md                               |  8 +-
 mmdet/datasets/pipelines/transforms.py        | 21 +++-
 requirements.txt                              | 14 +--
 requirements/build.txt                        |  4 +
 requirements/optional.txt                     |  2 +
 requirements/runtime.txt                      | 11 +++
 .../tests.txt                                 | 11 ++-
 setup.py                                      | 96 +++++++++++++++++--
 9 files changed, 138 insertions(+), 30 deletions(-)
 create mode 100644 requirements/build.txt
 create mode 100644 requirements/optional.txt
 create mode 100644 requirements/runtime.txt
 rename tests/requirements.txt => requirements/tests.txt (87%)
 mode change 100644 => 100755 setup.py

diff --git a/.travis.yml b/.travis.yml
index f9dba1f4..c2956e73 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -27,7 +27,6 @@ install:
   - pip install Pillow==6.2.2  # remove this line when torchvision>=0.5
   - pip install Cython torch==1.2 torchvision==0.4.0  # TODO: fix CI for pytorch>1.2
   - pip install -r requirements.txt
-  - pip install -r tests/requirements.txt
 
 before_script:
   - flake8 .
diff --git a/docs/INSTALL.md b/docs/INSTALL.md
index 18bbe420..d57be4bc 100644
--- a/docs/INSTALL.md
+++ b/docs/INSTALL.md
@@ -39,11 +39,11 @@ git clone https://github.com/open-mmlab/mmdetection.git
 cd mmdetection
 ```
 
-d. Install mmdetection (other dependencies will be installed automatically).
+d. Install build requirements and then install mmdetection.
 
 ```shell
-pip install mmcv
-python setup.py develop  # or "pip install -v -e ."
+pip install -r requirements/build.txt
+pip install -v -e .  # or "python setup.py develop"
 ```
 
 Note:
@@ -56,6 +56,8 @@ It is recommended that you run step d each time you pull some updates from githu
 3. If you would like to use `opencv-python-headless` instead of `opencv-python`,
 you can install it before installing MMCV.
 
+4. Some dependencies are optional. Simply running `pip install -v -e .` will only install the minimum runtime requirements. To use optional dependencies like `albumentations` and `imagecorruptions` either install them manually with `pip install -r requirements/optional.txt` or specify desired extras when calling `pip` (e.g. `pip install -v -e .[optional]`). Valid keys for the extras field are: `all`, `tests`, `build`, and `optional`.
+
 ### Another option: Docker Image
 
 We provide a [Dockerfile](https://github.com/open-mmlab/mmdetection/blob/master/docker/Dockerfile) to build an image.
diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py
index 8a5c7280..d5a8a144 100644
--- a/mmdet/datasets/pipelines/transforms.py
+++ b/mmdet/datasets/pipelines/transforms.py
@@ -1,15 +1,24 @@
 import inspect
 
-import albumentations
 import mmcv
 import numpy as np
-from albumentations import Compose
-from imagecorruptions import corrupt
 from numpy import random
 
 from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps
 from ..registry import PIPELINES
 
+try:
+    from imagecorruptions import corrupt
+except ImportError:
+    corrupt = None
+
+try:
+    import albumentations
+    from albumentations import Compose
+except ImportError:
+    albumentations = None
+    Compose = None
+
 
 @PIPELINES.register_module
 class Resize(object):
@@ -695,6 +704,8 @@ class Corrupt(object):
         self.severity = severity
 
     def __call__(self, results):
+        if corrupt is None:
+            raise RuntimeError('imagecorruptions is not installed')
         results['img'] = corrupt(
             results['img'].astype(np.uint8),
             corruption_name=self.corruption,
@@ -728,6 +739,8 @@ class Albu(object):
         skip_img_without_anno (bool): whether to skip the image
                                       if no ann left after aug
         """
+        if Compose is None:
+            raise RuntimeError('albumentations is not installed')
 
         self.transforms = transforms
         self.filter_lost_elements = False
@@ -771,6 +784,8 @@ class Albu(object):
 
         obj_type = args.pop("type")
         if mmcv.is_str(obj_type):
+            if albumentations is None:
+                raise RuntimeError('albumentations is not installed')
             obj_cls = getattr(albumentations, obj_type)
         elif inspect.isclass(obj_type):
             obj_cls = obj_type
diff --git a/requirements.txt b/requirements.txt
index 01cfcded..52ee8f55 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,4 @@
-albumentations>=0.3.2
-imagecorruptions
-matplotlib
-mmcv>=0.2.16
-numpy
-pycocotools
-six
-terminaltables
-torch>=1.1
-torchvision
+-r requirements/runtime.txt
+-r requirements/optional.txt
+-r requirements/tests.txt
+-r requirements/build.txt
diff --git a/requirements/build.txt b/requirements/build.txt
new file mode 100644
index 00000000..a24ea0c6
--- /dev/null
+++ b/requirements/build.txt
@@ -0,0 +1,4 @@
+# These must be installed before building mmdetection
+cython
+numpy
+torch>=1.1
diff --git a/requirements/optional.txt b/requirements/optional.txt
new file mode 100644
index 00000000..eb36729e
--- /dev/null
+++ b/requirements/optional.txt
@@ -0,0 +1,2 @@
+albumentations>=0.3.2
+imagecorruptions
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
new file mode 100644
index 00000000..eb66e6d0
--- /dev/null
+++ b/requirements/runtime.txt
@@ -0,0 +1,11 @@
+matplotlib
+mmcv>=0.2.15
+numpy
+pycocotools
+six
+terminaltables
+torch>=1.1
+torchvision
+
+# need older pillow until torchvision is fixed
+Pillow<=6.2.2
diff --git a/tests/requirements.txt b/requirements/tests.txt
similarity index 87%
rename from tests/requirements.txt
rename to requirements/tests.txt
index 6f8c22db..d45e5409 100644
--- a/tests/requirements.txt
+++ b/requirements/tests.txt
@@ -1,10 +1,11 @@
-isort
+asynctest
+codecov
 flake8
-yapf
+isort
+pytest 
 pytest-cov
-codecov
+pytest-runner
 xdoctest >= 0.10.0
-asynctest
-
+yapf
 # Note: used for kwarray.group_items, this may be ported to mmcv in the future.
 kwarray
diff --git a/setup.py b/setup.py
old mode 100644
new mode 100755
index 54d58777..f5ace596
--- a/setup.py
+++ b/setup.py
@@ -1,3 +1,5 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
 import os
 import platform
 import subprocess
@@ -131,11 +133,83 @@ def make_cython_ext(name, module, sources):
     return extension
 
 
-def get_requirements(filename='requirements.txt'):
-    here = os.path.dirname(os.path.realpath(__file__))
-    with open(os.path.join(here, filename), 'r') as f:
-        requires = [line.replace('\n', '') for line in f.readlines()]
-    return requires
+def parse_requirements(fname='requirements.txt', with_version=True):
+    """
+    Parse the package dependencies listed in a requirements file but strips
+    specific versioning information.
+
+    Args:
+        fname (str): path to requirements file
+        with_version (bool, default=False): if True include version specs
+
+    Returns:
+        List[str]: list of requirements items
+
+    CommandLine:
+        python -c "import setup; print(setup.parse_requirements())"
+    """
+    import sys
+    from os.path import exists
+    import re
+    require_fpath = fname
+
+    def parse_line(line):
+        """
+        Parse information from a line in a requirements text file
+        """
+        if line.startswith('-r '):
+            # Allow specifying requirements in other files
+            target = line.split(' ')[1]
+            for info in parse_require_file(target):
+                yield info
+        else:
+            info = {'line': line}
+            if line.startswith('-e '):
+                info['package'] = line.split('#egg=')[1]
+            else:
+                # Remove versioning from the package
+                pat = '(' + '|'.join(['>=', '==', '>']) + ')'
+                parts = re.split(pat, line, maxsplit=1)
+                parts = [p.strip() for p in parts]
+
+                info['package'] = parts[0]
+                if len(parts) > 1:
+                    op, rest = parts[1:]
+                    if ';' in rest:
+                        # Handle platform specific dependencies
+                        # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
+                        version, platform_deps = map(str.strip,
+                                                     rest.split(';'))
+                        info['platform_deps'] = platform_deps
+                    else:
+                        version = rest  # NOQA
+                    info['version'] = (op, version)
+            yield info
+
+    def parse_require_file(fpath):
+        with open(fpath, 'r') as f:
+            for line in f.readlines():
+                line = line.strip()
+                if line and not line.startswith('#'):
+                    for info in parse_line(line):
+                        yield info
+
+    def gen_packages_items():
+        if exists(require_fpath):
+            for info in parse_require_file(require_fpath):
+                parts = [info['package']]
+                if with_version and 'version' in info:
+                    parts.extend(info['version'])
+                if not sys.version.startswith('3.4'):
+                    # apparently package_deps are broken in 3.4
+                    platform_deps = info.get('platform_deps')
+                    if platform_deps is not None:
+                        parts.append(';' + platform_deps)
+                item = ''.join(parts)
+                yield item
+
+    packages = list(gen_packages_items())
+    return packages
 
 
 if __name__ == '__main__':
@@ -161,9 +235,15 @@ if __name__ == '__main__':
             'Programming Language :: Python :: 3.7',
         ],
         license='Apache License 2.0',
-        setup_requires=['pytest-runner', 'cython', 'numpy'],
-        tests_require=['pytest', 'xdoctest', 'asynctest'],
-        install_requires=get_requirements(),
+        setup_requires=parse_requirements('requirements/build.txt'),
+        tests_require=parse_requirements('requirements/tests.txt'),
+        install_requires=parse_requirements('requirements/runtime.txt'),
+        extras_require={
+            'all': parse_requirements('requirements.txt'),
+            'tests': parse_requirements('requirements/tests.txt'),
+            'build': parse_requirements('requirements/build.txt'),
+            'optional': parse_requirements('requirements/optional.txt'),
+        },
         ext_modules=[
             make_cuda_ext(
                 name='compiling_info',
-- 
GitLab