wanggh / apex

Commit b6980a0d (September 01, 2021), author: Thor Johnsen

Add functions to compute grad_out1, grad_out1_halo

Parent: ed713c84
Changes: 2 files
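The Python changes below operate on a bottleneck block whose activation rows (the H dimension) are split across spatial_group_size ranks, so each rank needs one row of its neighbours' data (a halo) before it can run the 3x3 convolution on its boundary rows. As a refresher, here is a minimal, self-contained sketch of that idea in plain NCHW PyTorch (single process, no NHWC layout, no distributed setup; all names and values are illustrative and not part of this commit): a 3x3 convolution over a rank's slab padded with one neighbour row on each side reproduces the corresponding rows of the full convolution.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W, group = 1, 4, 8, 8, 2            # H is split into `group` slabs of Hs rows each
x = torch.randn(N, C, H, W)
w = torch.randn(C, C, 3, 3)

full = F.conv2d(x, w, padding=1)             # reference: convolution over the full height

Hs = H // group
for rank in range(group):
    lo, hi = rank * Hs, (rank + 1) * Hs
    top = x[:, :, lo - 1:lo, :] if rank > 0 else torch.zeros(N, C, 1, W)          # halo from the rank above
    btm = x[:, :, hi:hi + 1, :] if rank < group - 1 else torch.zeros(N, C, 1, W)  # halo from the rank below
    padded = torch.cat([top, x[:, :, lo:hi, :], btm], dim=2)   # Hs + 2 rows, like padded_out1 below
    local = F.conv2d(padded, w, padding=(0, 1))                # pad W only; the halo rows pad H
    assert torch.allclose(local, full[:, :, lo:hi, :], atol=1e-5)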
apex/contrib/bottleneck/bottleneck.py
...
...
@@ -237,8 +237,6 @@ class SpatialBottleneckFunction(torch.autograd.Function):
        if spatial_group_size > 1:
            out1 = outputs[0]
            N,Hs,W,C = list(out1.shape)
            padded_out1 = torch.empty((N,Hs+2,W,C), dtype=out1.dtype, device=out1.device)
            padded_out1[:,1:Hs+1,:,:].copy_(out1)
            stream1.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream1):
                # copy halos to send buffer
...
...
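The stream1.wait_stream(...) / with torch.cuda.stream(stream1): pair above, together with the matching torch.cuda.current_stream().wait_stream(stream1) in the next hunk, is the usual side-stream overlap idiom: halo packing and exchange run on stream1 while the main stream keeps going, and the main stream only blocks when it needs the results. A standalone sketch of the idiom (illustrative only, independent of this file):

import torch

if torch.cuda.is_available():
    stream1 = torch.cuda.Stream()
    x = torch.randn(64, 64, device="cuda")
    stream1.wait_stream(torch.cuda.current_stream())   # stream1 must see x's initialization
    with torch.cuda.stream(stream1):
        halos = torch.cat([x[:1], x[-1:]], dim=0)      # e.g. pack boundary rows on the side stream
    torch.cuda.current_stream().wait_stream(stream1)   # main stream waits before consuming halos
    print(halos.shape)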
@@ -248,22 +246,17 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                all_halos = torch.empty((N,2*spatial_group_size,W,C), dtype=out1.dtype, device=out1.device)
                all_halos = [all_halos[:,i*2:(i+1)*2,:,:] for i in range(spatial_group_size)]
                dist.all_gather(all_halos, send_halos)
                padded_out1_top_halo = padded_out1[:,:1,:,:]
                fat_halo = torch.empty((N,3,W,C), dtype=out1.dtype, device=out1.device)
                if local_rank > 0:
                    top_halo = all_halos[local_rank-1][:,1:,:,:]
                    padded_out1_top_halo.copy_(top_halo)
                    fat_top_halo = padded_out1[:,:3,:,:]
                    top_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_top_halo, args)
                else:
                    padded_out1_top_halo.zero_()
                padded_out1_btm_halo = padded_out1[:,Hs+1:,:,:]
                    fat_halo[:,:1,:,:].copy_(top_halo)
                    fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
                    top_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_halo, args)
                if local_rank < spatial_group_size-1:
                    btm_halo = all_halos[local_rank+1][:,:1,:,:]
                    padded_out1_btm_halo.copy_(btm_halo)
                    fat_btm_halo = padded_out1[:,Hs-1:,:,:]
                    btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_btm_halo, args)
                else:
                    padded_out1_btm_halo.zero_()
                    fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:])
                    fat_halo[:,2:,:,:].copy_(btm_halo)
                    btm_out2 = fast_bottleneck.forward_out2_halo(nhwc, fat_halo, args)
            torch.cuda.current_stream().wait_stream(stream1)
            out2 = outputs[1]
            if local_rank > 0:
...
...
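The exchange above packs each rank's top and bottom output rows into send_halos, gathers every rank's pair of rows with dist.all_gather, and then reads only the two neighbouring slices. A reduced, runnable sketch of the same pattern on the gloo backend (plain tensors instead of NHWC activations; names, port, and values are illustrative, not from this commit):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29531"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    N, Hs, W, C = 1, 4, 4, 2
    out1 = torch.full((N, Hs, W, C), float(rank))            # this rank's slab of rows
    send_halos = torch.cat([out1[:, :1], out1[:, -1:]], 1)   # top row + bottom row -> [N,2,W,C]

    # The diff gathers into slices of one big buffer; separate tensors show the same pattern.
    all_halos = [torch.empty(N, 2, W, C) for _ in range(world_size)]
    dist.all_gather(all_halos, send_halos)

    if rank > 0:                                   # bottom row of the rank above
        top_halo = all_halos[rank - 1][:, 1:]
        assert top_halo.eq(rank - 1).all()
    if rank < world_size - 1:                      # top row of the rank below
        btm_halo = all_halos[rank + 1][:, :1]
        assert btm_halo.eq(rank + 1).all()
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)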
@@ -272,10 +265,8 @@ class SpatialBottleneckFunction(torch.autograd.Function):
                out2[:,Hs-1:,:,:].copy_(btm_out2)
        fast_bottleneck.forward_rest(nhwc, stride_1x1, args, outputs)
        if spatial_group_size > 1:
            ctx.save_for_backward(*(args+outputs+[padded_out1]))
        else:
            ctx.save_for_backward(*(args+outputs))
        # TODO: save halos for backward pass
        ctx.save_for_backward(*(args+outputs))
        # save relu outputs for drelu
        ctx.nhwc = nhwc
        ctx.stride_1x1 = stride_1x1
...
...
@@ -289,10 +280,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
    # only support dgrad
    @staticmethod
    def backward(ctx, grad_o):
        if ctx.spatial_group_size > 1:
            outputs = ctx.saved_tensors[-4:-1]
        else:
            outputs = ctx.saved_tensors[-3:]
        outputs = ctx.saved_tensors[-3:]
        if ctx.downsample:
            grad_conv3, grad_conv4 = drelu_dscale2(grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11])
...
...
@@ -315,7 +303,23 @@ class SpatialBottleneckFunction(torch.autograd.Function):
        grads = fast_bottleneck.backward_init(ctx.nhwc, ctx.stride_1x1, t_list)
        grad_out2 = fast_bottleneck.backward_grad_out2(ctx.nhwc, ctx.stride_1x1, t_list, grads)
        # do halo exchange of grad_out2 here
        fast_bottleneck.backward_rest(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
        # need fast_bottleneck.backward_grad_out2_halo
        # testing
        N,H,W,C = grad_out2.shape
        grad_out2_halo = torch.empty([N,3,W,C], dtype=grad_out2.dtype, device=grad_out2.device)
        grad_out2_halo[:,:1,:,:].zero_()
        grad_out2_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:])
        grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2_halo)
        # print("grad_out2_halo.shape = %s -> grad_out1_halo.shape = %s" % (str(list(grad_out2_halo.shape)), str(list(grad_out1_halo.shape))))
        wgrad2 = fast_bottleneck.backward_wgrad2(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
        # apply wgrad2 halos here
        # no need for custom wgrad2_halo function, this is just a backwards data convolution
        grad_out1 = fast_bottleneck.backward_grad_out1(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2)
        # apply grad_out1 halos here
        fast_bottleneck.backward_rest(ctx.nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1, wgrad2)
        return (None, None, None, None, None, None, None, None, *grads)
...
...
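The backward path above mirrors the forward split: backward_grad_out2 produces each rank's grad_out2 slab, and the new backward_grad_out1_halo binding turns a 3-row grad_out2 halo into the matching 3-row grad_out1 halo. The test code in the hunk only covers the top boundary with a zero neighbour row; below is a hedged sketch of how a halo input could be assembled once the grad_out2 exchange exists, reusing the same layout (neighbour row first, then this rank's first two rows) that the forward pass uses for fat_halo. Everything here is an assumption for illustration, not code from this commit.

import torch

def build_top_grad_out2_halo(grad_out2, top_grad_row=None):
    # grad_out2: this rank's [N,H,W,C] slab; returns a [N,3,W,C] halo input.
    N, H, W, C = grad_out2.shape
    halo = torch.empty((N, 3, W, C), dtype=grad_out2.dtype, device=grad_out2.device)
    if top_grad_row is None:
        halo[:, :1].zero_()                 # top-most rank: no neighbour, zero row (as in the hunk above)
    else:
        halo[:, :1].copy_(top_grad_row)     # interior rank: neighbour's bottom grad_out2 row
    halo[:, 1:].copy_(grad_out2[:, :2])     # this rank's first two rows
    return halo

# grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.nhwc, ctx.stride_1x1, t_list, grads, halo)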
apex/contrib/csrc/bottleneck/bottleneck.cpp
...
...
@@ -1746,7 +1746,7 @@ std::vector<at::Tensor> bottleneck_forward_init(bool explicit_nhwc, int stride_1
  std::vector<at::Tensor> outputs;
  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  printf("outdim1 = (%d,%d,%d,%d)\n",forward_state.outdim1[0],forward_state.outdim1[1],forward_state.outdim1[2],forward_state.outdim1[3]);
  //printf("outdim1 = (%d,%d,%d,%d)\n",forward_state.outdim1[0],forward_state.outdim1[1],forward_state.outdim1[2],forward_state.outdim1[3]);
  auto out1 = at::empty(forward_state.outdim1, inputs[0].type(), output_format);
  auto out2 = at::empty(forward_state.outdim2, inputs[0].type(), output_format);
  auto out3 = at::empty(forward_state.outdim3, inputs[0].type(), output_format);
...
...
@@ -1837,12 +1837,12 @@ void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at:
  auto out2 = outputs[1];
  at::Half* y2 = out2.data_ptr<at::Half>();

  printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
  printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
  printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
  printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
  printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
  printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);
  //printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
  //printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
  //printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
  //printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
  //printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
  //printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);

  run_conv_scale_bias_add_activation(forward_state.outdimA1,
                                     forward_state.padA1,
                                     forward_state.convstrideA,
...
...
@@ -1934,12 +1934,15 @@ struct bottleneck_backward_state {
  int axis[4];

  int64_t outdimA1[4];
  int64_t outdimA2[4];
  int64_t outdimA1[4];  // grad_out1
  int64_t outdimA2[4];  // grad_out2
  int64_t outdimA3[4];
  int64_t outdimA1h[4]; // output: grad_out1 halo (H=3)
  int64_t outdimA2h[4]; // input : grad_out2 halo cells (H=3)

  int64_t padA[2];
  int64_t padA1[2];
  int64_t padA2[2];
  int64_t dilationA[2];
  int64_t convstrideA[2];
  int64_t convstride1X1[2];
...
...
@@ -1947,6 +1950,8 @@ struct bottleneck_backward_state {
  int64_t outdim1[4];
  int64_t outdim2[4];
  int64_t outdim3[4];
  int64_t outdim1h[4];
  int64_t outdim2hh[4];

  void init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {

    // setup dimensions
...
...
@@ -1985,6 +1990,8 @@ struct bottleneck_backward_state {
    outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
    outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
    outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
    outdimA1h[0] = outdimA1h[1] = outdimA1h[2] = outdimA1h[3] = 0;
    outdimA2h[0] = outdimA2h[1] = outdimA2h[2] = outdimA2h[3] = 0;

    // use these fixed value for test run
    padA[0] = 0; padA[1] = 0;
...
...
@@ -2012,10 +2019,21 @@ struct bottleneck_backward_state {
      outdimA3[dim+2] = getFwdConvOutputDim(outdimA2[dim+2], padA[dim], filterdimA3[dim+2], convstrideA[dim], dilationA[dim]);
    }

    for (int dim=0;dim<4;dim++) {
      if (dim == 2) {
        outdimA1h[dim] = 3;
        outdimA2h[dim] = 3;
      } else {
        outdimA1h[dim] = outdimA1[dim];
        outdimA2h[dim] = outdimA2[dim];
      }
    }

    // Create output tensor in the correct shape in pytorch's view
    outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
    outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
    outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
    outdim1h[0] = outdim1h[1] = outdim1h[2] = outdim1h[3] = 0;
    if (explicit_nhwc) {
      axis[0] = 0;
      axis[1] = 2;
...
...
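The loop above copies every entry of outdimA1/outdimA2 into the halo variants except the H axis (index 2 in these NCHW-ordered dim arrays), which is pinned to 3 rows for both the grad_out2 halo input and the grad_out1 halo output. The same computation in Python, with made-up per-rank dims (values purely illustrative, not from the source):

outdimA1 = [8, 64, 56, 56]                                   # hypothetical {N, C, Hs, W}
outdimA1h = [3 if i == 2 else d for i, d in enumerate(outdimA1)]
assert outdimA1h == [8, 64, 3, 56]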
@@ -2026,6 +2044,7 @@ struct bottleneck_backward_state {
      outdim1[dim] = outdimA1[axis[dim]];
      outdim2[dim] = outdimA2[axis[dim]];
      outdim3[dim] = outdimA3[axis[dim]];
      outdim1h[dim] = outdimA1h[axis[dim]];
    }
  }
};
...
...
@@ -2117,7 +2136,78 @@ at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std
  return grad_out2;
}

void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {

  bool requires_grad = inputs[0].requires_grad();

  std::cout << std::fixed;
  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  // dgrad
  at::Half* dy2 = grad_out2.data_ptr<at::Half>();

  // dgrad
  auto grad_out1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format);
  at::Half* dy1 = grad_out1.data_ptr<at::Half>();

  at::Half* w = inputs[2].data_ptr<at::Half>();
  at::Half* z = inputs[4].data_ptr<at::Half>();

  at::Half* relu1 = inputs[12].data_ptr<at::Half>();

  // fused dgrad
  run_dconv_drelu_dscale(backward_state.outdimA1,
                         backward_state.padA1,
                         backward_state.convstrideA,
                         backward_state.dilationA,
                         backward_state.filterdimA2,
                         backward_state.outdimA2,
                         CUDNN_DATA_HALF,
                         dy1,
                         w,
                         dy2,
                         z,
                         relu1);

  return grad_out1;
}

// perform backward data 3x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,3,W,C] with padding=(1,1) to produce output of shape [N,3,W,C]
at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2_halo) {

  bool requires_grad = inputs[0].requires_grad();

  std::cout << std::fixed;
  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  // dgrad
  at::Half* dy2h = grad_out2_halo.data_ptr<at::Half>();

  // dgrad
  auto grad_out1_halo = at::empty(backward_state.outdim1h, inputs[0].type(), output_format);
  at::Half* dy1h = grad_out1_halo.data_ptr<at::Half>();

  at::Half* w = inputs[2].data_ptr<at::Half>();
  at::Half* z = inputs[4].data_ptr<at::Half>();

  at::Half* relu1 = inputs[12].data_ptr<at::Half>();

  // fused dgrad
  //printf("backward_state.outdimA1h = {%d,%d,%d,%d}\n",backward_state.outdimA1h[0],backward_state.outdimA1h[1],backward_state.outdimA1h[2],backward_state.outdimA1h[3]);
  //printf("backward_state.outdimA2h = {%d,%d,%d,%d}\n",backward_state.outdimA2h[0],backward_state.outdimA2h[1],backward_state.outdimA2h[2],backward_state.outdimA2h[3]);
  run_dconv_drelu_dscale(backward_state.outdimA1h,
                         backward_state.padA1,
                         backward_state.convstrideA,
                         backward_state.dilationA,
                         backward_state.filterdimA2,
                         backward_state.outdimA2h,
                         CUDNN_DATA_HALF,
                         dy1h,
                         w,
                         dy2h,
                         z,
                         relu1);

  return grad_out1_halo;
}

at::Tensor bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {

  bool requires_grad = inputs[0].requires_grad();
...
...
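The comment on bottleneck_backward_grad_out1_halo above describes the halo dgrad as a 3x3 convolution of grad_out2 with the 180-degree-rotated filter, applied to a [N,3,W,C] slab with padding (1,1). That identity (the data gradient of a stride-1, pad-1 3x3 convolution equals a convolution with the channel-transposed, spatially flipped filter) can be checked in a few lines of plain PyTorch. The sketch below is illustrative only and uses NCHW instead of the extension's NHWC layout:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, Cin, Cout, H, W = 2, 4, 5, 3, 8                  # H = 3, same height as the halo tensors
x = torch.randn(N, Cin, H, W, requires_grad=True)
w = torch.randn(Cout, Cin, 3, 3)

y = F.conv2d(x, w, padding=1)
grad_y = torch.randn_like(y)
y.backward(grad_y)

w_rot180 = w.transpose(0, 1).flip(2, 3)             # swap in/out channels, rotate the 3x3 taps by 180 degrees
dgrad = F.conv2d(grad_y, w_rot180, padding=1)
assert torch.allclose(dgrad, x.grad, atol=1e-4)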
@@ -2134,7 +2224,7 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at
  auto wgrad2 = outputs[2];
  at::Half* dw2 = wgrad2.data_ptr<at::Half>();

  printf("outdimA1 = (%d,%d,%d,%d)\n",backward_state.outdimA1[0],backward_state.outdimA1[1],backward_state.outdimA1[2],backward_state.outdimA1[3]);
  //printf("outdimA1 = (%d,%d,%d,%d)\n",backward_state.outdimA1[0],backward_state.outdimA1[1],backward_state.outdimA1[2],backward_state.outdimA1[3]);
  run_dconv(backward_state.outdimA1,
            backward_state.padA1,
            backward_state.convstrideA,
...
...
@@ -2147,26 +2237,19 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at
            dy2,
            CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);

  return wgrad2;
}

void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2, at::Tensor grad_out1, at::Tensor wgrad2) {

  bool requires_grad = inputs[0].requires_grad();

  std::cout << std::fixed;
  auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;

  // dgrad
  auto grad_out1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format);
  at::Half* dy2 = grad_out2.data_ptr<at::Half>();
  at::Half* dy1 = grad_out1.data_ptr<at::Half>();

  at::Half* w = inputs[2].data_ptr<at::Half>();
  at::Half* z = inputs[4].data_ptr<at::Half>();

  at::Half* relu1 = inputs[12].data_ptr<at::Half>();

  // fused dgrad
  run_dconv_drelu_dscale(backward_state.outdimA1,
                         backward_state.padA1,
                         backward_state.convstrideA,
                         backward_state.dilationA,
                         backward_state.filterdimA2,
                         backward_state.outdimA2,
                         CUDNN_DATA_HALF,
                         dy1,
                         w,
                         dy2,
                         z,
                         relu1);

  /*
  // backward strided conv cannot be fused
...
...
@@ -2215,6 +2298,8 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at
  // x used for dconv1 and dconv4 wgrad
  at::Half* x = inputs[0].data_ptr<at::Half>();
  at::Half* w = NULL;
  if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]){
    w = inputs[14].data_ptr<at::Half>();
    at::Half* dy_conv4 = inputs[11].data_ptr<at::Half>();
...
...
@@ -2327,5 +2412,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward_rest", &bottleneck_forward_rest, "Bottleneck block forward");
  m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init");
  m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward");
  m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward");
  m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward");
  m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward");
  m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward");
}