Skip to content
Snippets Groups Projects
Unverified Commit 21fef035 authored by Wenwei Zhang's avatar Wenwei Zhang Committed by GitHub
Browse files

Fix parrots compatibility issues (#4143)

* fix parrots compatibility

* add comments
parent 1cbe18b6
No related branches found
No related tags found
No related merge requests found
......@@ -5,11 +5,6 @@ import numpy as np
import torch
from mmcv.runner import load_checkpoint
try:
from mmcv.onnx.symbolic import register_extra_symbolics
except ModuleNotFoundError:
raise NotImplementedError('please update mmcv to version>=v1.0.4')
def generate_inputs_and_wrap_model(config_path, checkpoint_path, input_config):
"""Prepare sample input and wrap model for ONNX export.
......@@ -51,6 +46,12 @@ def generate_inputs_and_wrap_model(config_path, checkpoint_path, input_config):
# pytorch has some bug in pytorch1.3, we have to fix it
# by replacing these existing op
opset_version = 11
# put the import within the function thus it will not cause import error
# when not using this function
try:
from mmcv.onnx.symbolic import register_extra_symbolics
except ModuleNotFoundError:
raise NotImplementedError('please update mmcv to version>=v1.0.4')
register_extra_symbolics(opset_version)
return model, tensor_data
......
......@@ -103,8 +103,11 @@ class DistributedGroupSampler(Sampler):
if size > 0:
indice = np.where(self.flag == i)[0]
assert len(indice) == size
indice = indice[list(torch.randperm(int(size),
generator=g))].tolist()
# add .numpy() to avoid bug when selecting indice in parrots.
# TODO: check whether torch.randperm() can be replaced by
# numpy.random.permutation().
indice = indice[list(
torch.randperm(int(size), generator=g).numpy())].tolist()
extra = int(
math.ceil(
size * 1.0 / self.samples_per_gpu / self.num_replicas)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment