wanggh / Swin-Transformer-Object-Detection / Commits

Commit b7968de7, authored 6 years ago by Kai Chen
modify MMDistributedDataParallel so that it no longer inherits from DistributedDataParallel
Parent: e74c260f
No related branches, tags, or merge requests found.

Showing 3 changed files, with 44 additions and 11 deletions:

mmdet/nn/parallel/distributed.py  (+42, -2)
tools/dist_train.sh               (+1, -5)
tools/train.py                    (+1, -4)
mmdet/nn/parallel/distributed.py  (+42, -2)

Removed (-2):

-from torch.nn.parallel import DistributedDataParallel
-class MMDistributedDataParallel(DistributedDataParallel):

The file at b7968de7:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch._utils import (_flatten_dense_tensors, _unflatten_dense_tensors,
                          _take_tensors)

from .scatter_gather import scatter_kwargs


class MMDistributedDataParallel(nn.Module):

    def __init__(self, module, dim=0, broadcast_buffers=True):
        super(MMDistributedDataParallel, self).__init__()
        self.module = module
        self.dim = dim
        self.broadcast_buffers = broadcast_buffers

        self.first_synced = False
        self.broadcast_bucket_size = 32 * 1024 * 1024

    def _dist_broadcast_coalesced(self, tensors, buffer_size):
        for tensors in _take_tensors(tensors, buffer_size):
            flat_tensors = _flatten_dense_tensors(tensors)
            dist.broadcast(flat_tensors, 0)
            for tensor, synced in zip(
                    tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
                tensor.copy_(synced)

    def sync_params(self):
        module_states = list(self.module.state_dict().values())
        if len(module_states) > 0:
            self._dist_broadcast_coalesced(module_states,
                                           self.broadcast_bucket_size)
        if self.broadcast_buffers:
            buffers = [b.data for b in self.module._all_buffers()]
            if len(buffers) > 0:
                self._dist_broadcast_coalesced(buffers,
                                               self.broadcast_bucket_size)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        if not self.first_synced:
            self.sync_params()
            self.first_synced = True
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        return self.module(*inputs[0], **kwargs[0])
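For orientation, here is a hedged usage sketch of the new wrapper; it is not code from this commit. build_some_model() and data_batch are placeholders, and the import path simply mirrors the file location shown above. The point it illustrates is that rank-0 parameters and buffers are broadcast on the first forward call only; later calls just scatter the inputs to the current device.

# Hedged usage sketch, not part of this commit.
# Placeholders: build_some_model(), data_batch.
import torch
import torch.distributed as dist

from mmdet.nn.parallel.distributed import MMDistributedDataParallel

# Assumes RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are already set,
# e.g. by torch.distributed.launch (one process per GPU).
dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = MMDistributedDataParallel(build_some_model()).cuda()

# First call triggers sync_params(): rank-0 state is broadcast to all ranks.
# Subsequent calls only scatter inputs and run the wrapped module.
losses = model(**data_batch)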
tools/dist_train.sh  (+1, -5)

...
@@ -2,8 +2,4 @@
 PYTHON=${PYTHON:-"python"}

-$PYTHON train.py $1 --dist --world-size $2 --rank 0 &
-let MAX_RANK=$2-1
-for i in `seq 1 $MAX_RANK`; do
-    $PYTHON train.py $1 --dist --world-size $2 --rank $i > /dev/null 2>&1 &
-done
+$PYTHON -m torch.distributed.launch --nproc_per_node=$2 train.py $1 --launcher pytorch
\ No newline at end of file
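The script now delegates process spawning to torch.distributed.launch, which starts --nproc_per_node processes and exports RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT for each of them. A hedged sketch of what a "pytorch" launcher path in train.py then typically has to do; this is an assumption for illustration, not code taken from this repository:

# Hedged sketch of a "pytorch"-launcher init; assumed, not from this repo.
import os
import torch
import torch.distributed as dist

def init_dist_pytorch(backend='nccl'):
    # torch.distributed.launch exports RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT
    rank = int(os.environ['RANK'])
    # Pin this process to one GPU before joining the group.
    torch.cuda.set_device(rank % torch.cuda.device_count())
    dist.init_process_group(backend=backend)  # reads the env:// variables above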
tools/train.py  (+1, -4)

...
@@ -95,10 +95,7 @@ def main():
     model = build_detector(
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
     if dist:
-        model = MMDistributedDataParallel(
-            model,
-            device_ids=[torch.cuda.current_device()],
-            broadcast_buffers=False).cuda()
+        model = MMDistributedDataParallel(model).cuda()
     else:
         model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
...
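Because MMDistributedDataParallel is now a plain nn.Module wrapper rather than a DistributedDataParallel subclass, it no longer averages gradients during backward; presumably that step happens explicitly elsewhere in the training loop. A hedged illustration of one way to do it, not code from this commit:

# Hedged illustration: explicit gradient averaging across ranks,
# to be called between loss.backward() and optimizer.step().
import torch.distributed as dist

def allreduce_grads(model):
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.requires_grad and param.grad is not None:
            dist.all_reduce(param.grad.data)   # sum gradients over all ranks
            param.grad.data.div_(world_size)   # turn the sum into an average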