Skip to content
Snippets Groups Projects
Unverified Commit b3f1e05c authored by yuzhj's avatar yuzhj Committed by GitHub
Browse files

fix rpn transforming bug in two stage networks (#3754)

* fix rpn transforming bug in two_stage
parent bb514fa0
No related branches found
No related tags found
No related merge requests found
......@@ -57,7 +57,7 @@ def reorder_cls_channel(val, num_classes=81):
# fc_cls
elif out_channels == num_classes:
new_val = torch.cat((val[1:], val[:1]), dim=0)
# agnostic | retina_cls | rpn_cls
# agnostic | retina_cls
else:
new_val = val
......@@ -89,7 +89,7 @@ def truncate_cls_channel(val, num_classes=81):
def truncate_reg_channel(val, num_classes=81):
# bias
if val.dim() == 1:
# fc_reg|rpn_reg
# fc_reg
if val.size(0) % num_classes == 0:
new_val = val.reshape(num_classes, -1)[:num_classes - 1]
new_val = new_val.reshape(-1)
......@@ -99,7 +99,7 @@ def truncate_reg_channel(val, num_classes=81):
# weight
else:
out_channels, in_channels = val.shape[:2]
# fc_reg|rpn_reg
# fc_reg
if out_channels % num_classes == 0:
new_val = val.reshape(num_classes, -1, in_channels,
*val.shape[2:])[1:]
......@@ -137,14 +137,14 @@ def convert(in_file, out_file, num_classes):
# classification
m = re.search(
r'(conv_cls|retina_cls|rpn_cls|fc_cls|fcos_cls|'
r'(conv_cls|retina_cls|fc_cls|fcos_cls|'
r'fovea_cls).(weight|bias)', new_key)
if m is not None:
print(f'reorder cls channels of {new_key}')
new_val = reorder_cls_channel(val, num_classes)
# regression
m = re.search(r'(fc_reg|rpn_reg).(weight|bias)', new_key)
m = re.search(r'(fc_reg).(weight|bias)', new_key)
if m is not None and not reg_cls_agnostic:
print(f'truncate regression channels of {new_key}')
new_val = truncate_reg_channel(val, num_classes)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment