Commit 67a0ffcb authored by Thor Johnsen

More detailed output

Parent bc9114c9
@@ -4,6 +4,7 @@ from maskrcnn_benchmark.modeling.backbone.resnet import Bottleneck
from maskrcnn_benchmark.layers.nhwc import nhwc_to_nchw_transform, nchw_to_nhwc_transform
from maskrcnn_benchmark.layers.nhwc.batch_norm import FrozenBatchNorm2d_NHWC
from apex.contrib.bottleneck import Bottleneck as FastBottleneck
from apex.contrib.bottleneck import SpatialBottleneck
def single_module_test(ref, rank, world_size, numtype, device, shape, fast, spatial_group_size, in_channels, bottleneck_channels, out_channels, num_groups, stride_in_1x1, stride, dilation, norm_func, nhwc):
@@ -22,16 +23,25 @@ def single_module_test(ref, rank, world_size, numtype, device, shape, fast, spat
# fast = False
if fast:
bottleneck = FastBottleneck(
in_channels=in_channels,
bottleneck_channels=bottleneck_channels,
out_channels=out_channels,
stride=stride,
dilation=dilation,
explicit_nhwc=nhwc,
use_cudnn=True)
if spatial_group_size > 1:
print("WARNING! spatial_group_size ignored by FastBottleneck")
if spatial_group_size == 1:
bottleneck = FastBottleneck(
in_channels=in_channels,
bottleneck_channels=bottleneck_channels,
out_channels=out_channels,
stride=stride,
dilation=dilation,
explicit_nhwc=nhwc,
use_cudnn=True)
else:
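# spatial parallelism path: SpatialBottleneck operates on this rank's H-shard of the input,
# one shard per rank in the spatial group (see the all_gather logic further down)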
bottleneck = SpatialBottleneck(
in_channels=in_channels,
bottleneck_channels=bottleneck_channels,
out_channels=out_channels,
stride=stride,
dilation=dilation,
explicit_nhwc=nhwc,
use_cudnn=True,
spatial_group_size=spatial_group_size)
else:
bottleneck = Bottleneck(
in_channels,
@@ -89,15 +99,24 @@ def single_module_test(ref, rank, world_size, numtype, device, shape, fast, spat
for k in weights.keys():
torch.distributed.broadcast(weights[k],0)
else:
# gather dgrad (x.grad), sum wgrad (weights)
# gather dgrad (x.grad), sum wgrad (weights) and out
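# each rank holds an (N,Hs,W,C) H-shard; all_gather into slices of a preallocated (N,H,W,C) tensor reassembles the full dgrad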
N,Hs,W,C = dgrad.shape
H = Hs * spatial_group_size
dgrad_gathered = torch.empty((N,H,W,C),dtype=dgrad.dtype,device=dgrad.device)
dgrad_tensors = [dgrad_gathered[:,i*Hs:(i+1)*Hs,:,:] for i in range(spatial_group_size)]
torch.distributed.all_gather(dgrad_tensors, dgrad)
dgrad = dgrad_gathered
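# gather the output shards along H the same way, so the distributed result can be compared with the single-GPU reference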
N,Hs,W,C = list(out.shape)
H = Hs * spatial_group_size
out_gathered = torch.empty((N,H,W,C),dtype=dgrad.dtype,device=dgrad.device)
out_tensors = [out_gathered[:,i*Hs:(i+1)*Hs,:,:] for i in range(spatial_group_size)]
torch.distributed.all_gather(out_tensors, out)
out = out_gathered
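# weight gradients are summed across spatial ranks; the all_reduce accumulates in float64 to limit rounding error, then casts back to the original dtype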
for k in wgrad.keys():
torch.distributed.all_reduce(wgrad[k])
w = wgrad[k].to(dtype=torch.float64)
torch.distributed.all_reduce(w)
wgrad[k].copy_(w.to(dtype=wgrad[k].dtype))
#torch.distributed.all_reduce(wgrad[k])
return x, out, grad_out, weights, dgrad, wgrad
@@ -118,7 +137,7 @@ def module_tests(rank, world_size, numtype, device, fast, spatial_group_sizes, i
ref = x, grad_out, weights
if rank == 0:
rr.append( (out, dgrad, wgrad) )
torch.distributed.barrier()
if world_size > 1: torch.distributed.barrier()
r.append(rr)
return r
@@ -138,11 +157,11 @@ def main():
else:
rank, local_rank, is_master, world_size, spatial_group_size = 0, 0, True, 1, 1
#torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = True
#torch.backends.cudnn.deterministic = True
#torch.backends.cuda.matmul.allow_tf32 = False
#torch.backends.cudnn.allow_tf32 = False
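# force deterministic kernels and disable cuDNN benchmark autotuning and TF32 so repeated runs produce directly comparable results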
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
norm_func = FrozenBatchNorm2d_NHWC
@@ -164,6 +183,7 @@ def main():
(1, 84, 50, 1024, 1024, 512, 2048, 1, True, 2, 1, norm_func, True),
(1, 42, 25, 2048, 2048, 512, 2048, 1, True, 1, 1, norm_func, True),
]
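# run only the first configuration from the list above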
init_args = init_args[0:1]
# pad H to account for spatial distribution
padded_init_args = []
@@ -182,16 +202,70 @@ def main():
if spatial_group_size > 1:
spatial_group_sizes.append(spatial_group_size)
numtype, device, fast = torch.float16, 'cuda', False
numtype, device, fast = torch.float16, 'cuda', True
r = module_tests(rank, world_size, numtype, device, fast, spatial_group_sizes, init_args)
torch.distributed.barrier()
if world_size > 1: torch.distributed.barrier()
if rank == 0:
for rr in r:
print("***")
for out, dgrad, wgrad in rr:
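# report the float64 L2 norms of the output, the input gradient, and every weight gradient for each run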
gr = [("dgrad",dgrad.norm(p=2,dtype=torch.float64).item())] + [(k+".wgrad",wgrad[k].norm(p=2,dtype=torch.float64).item()) for k in wgrad.keys()]
gr = [("out",out.norm(p=2,dtype=torch.float64).item())]
gr = gr + [("dgrad",dgrad.norm(p=2,dtype=torch.float64).item())]
gr = gr + [(k+".wgrad",wgrad[k].norm(p=2,dtype=torch.float64).item()) for k in wgrad.keys()]
print(gr)
if len(rr) == 2:
out1, dgrad1, wgrad1 = rr[0]
out2, dgrad2, wgrad2 = rr[1]
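# when two results were collected (presumably the single-group reference and the spatially distributed run),
# compare them element-wise; the absolute tolerance for each tensor is rtol times the first result's max absolute value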
rtol = 1e-1
out_atol = out1.abs().max().item() * rtol
dgrad_atol = dgrad1.abs().max().item() * rtol
wgrad_atol = {}
for k in wgrad1.keys():
wgrad_atol[k] = wgrad1[k].abs().max().item() * rtol
gr = [("out",torch.allclose(out1,out2,rtol,out_atol,equal_nan=True))]
gr = gr + [("dgrad",torch.allclose(dgrad1,dgrad2,rtol,dgrad_atol,equal_nan=True))]
gr = gr + [(k+".wgrad",torch.allclose(wgrad1[k],wgrad2[k],rtol,wgrad_atol[k],equal_nan=True)) for k in wgrad1.keys()]
print(gr)
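# also print the L2 norm of the element-wise difference for each tensor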
gr = [("out",(out1-out2).norm(p=2,dtype=torch.float64).item())]
gr = gr + [("dgrad",(dgrad1-dgrad2).norm(p=2,dtype=torch.float64).item())]
gr = gr + [(k+".wgrad",(wgrad1[k]-wgrad2[k]).norm(p=2,dtype=torch.float64).item()) for k in wgrad1.keys()]
print(gr)
torch.distributed.barrier()
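# dump a few rows around the first H split boundary (Ht = Hs-2 .. Hs+1) from both runs, so boundary effects of the spatial partition are visible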
N,H,W,C = out1.shape
Hs = H // spatial_group_size
Ht = Hs-2
print("out1@%d:%d=%s" % (Ht,H,str(out1[0,Ht,:8,:5])))
print("out2@%d:%d=%s" % (Ht,H,str(out2[0,Ht,:8,:5])))
Ht = Hs-1
print("out1@%d:%d=%s" % (Ht,H,str(out1[0,Ht,:8,:5])))
print("out2@%d:%d=%s" % (Ht,H,str(out2[0,Ht,:8,:5])))
Ht = Hs
print("out1@%d:%d=%s" % (Ht,H,str(out1[0,Ht,:8,:5])))
print("out2@%d:%d=%s" % (Ht,H,str(out2[0,Ht,:8,:5])))
Ht = Hs+1
print("out1@%d:%d=%s" % (Ht,H,str(out1[0,Ht,:8,:5])))
print("out2@%d:%d=%s" % (Ht,H,str(out2[0,Ht,:8,:5])))
N,H,W,C = dgrad1.shape
Hs = H // spatial_group_size
Ht = Hs-2
print("dgrad1@%d:%d=%s" % (Ht,H,str(dgrad1[0,Ht,:8,:5])))
print("dgrad2@%d:%d=%s" % (Ht,H,str(dgrad2[0,Ht,:8,:5])))
Ht = Hs-1
print("dgrad1@%d:%d=%s" % (Ht,H,str(dgrad1[0,Ht,:8,:5])))
print("dgrad2@%d:%d=%s" % (Ht,H,str(dgrad2[0,Ht,:8,:5])))
Ht = Hs
print("dgrad1@%d:%d=%s" % (Ht,H,str(dgrad1[0,Ht,:8,:5])))
print("dgrad2@%d:%d=%s" % (Ht,H,str(dgrad2[0,Ht,:8,:5])))
Ht = Hs+1
print("dgrad1@%d:%d=%s" % (Ht,H,str(dgrad1[0,Ht,:8,:5])))
print("dgrad2@%d:%d=%s" % (Ht,H,str(dgrad2[0,Ht,:8,:5])))
if world_size > 1: torch.distributed.barrier()
if __name__ == "__main__":
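For context, a minimal, self-contained sketch of the two collectives this test leans on: reassembling per-rank H-shards of an NHWC tensor with all_gather, and summing weight gradients with the accumulation done in float64. The helper names (gather_h_shards, allreduce_wgrad_fp64) are illustrative only and assume an already-initialized torch.distributed process group; they are not part of apex or maskrcnn_benchmark.

import torch
import torch.distributed as dist

def gather_h_shards(shard, spatial_group_size):
    # shard: this rank's (N, Hs, W, C) slice of an NHWC tensor split along H.
    # Returns the full (N, Hs * spatial_group_size, W, C) tensor on every rank.
    shards = [torch.empty_like(shard) for _ in range(spatial_group_size)]
    dist.all_gather(shards, shard.contiguous())
    return torch.cat(shards, dim=1)  # concatenate along the H dimension

def allreduce_wgrad_fp64(wgrad):
    # Sum each weight gradient across ranks, accumulating in float64
    # before casting back to the original dtype (mirrors the test above).
    for k in wgrad:
        w = wgrad[k].to(dtype=torch.float64)
        dist.all_reduce(w)
        wgrad[k].copy_(w.to(dtype=wgrad[k].dtype))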