Skip to content
Snippets Groups Projects
Commit cd0d37cc authored by Cao Yuhang, committed by Kai Chen
Browse files

log reduced loss (#1782)

parent fb983fe6
No related branches found
No related tags found
No related merge requests found
@@ -3,6 +3,7 @@ import re
 from collections import OrderedDict
 import torch
+import torch.distributed as dist
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 from mmcv.runner import DistSamplerSeedHook, Runner, obj_from_dict
@@ -28,8 +29,12 @@ def parse_losses(losses):
     loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)
     log_vars['loss'] = loss
-    for name in log_vars:
-        log_vars[name] = log_vars[name].item()
+    for loss_name, loss_value in log_vars.items():
+        # reduce loss when distributed training
+        if dist.is_initialized():
+            loss_value = loss_value.data.clone()
+            dist.all_reduce(loss_value.div_(dist.get_world_size()))
+        log_vars[loss_name] = loss_value.item()
     return loss, log_vars
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment