-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathtrain.py
68 lines (54 loc) · 2.55 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# python3.7
"""Main function for model training."""
import click
from configs import CONFIG_POOL
from configs import build_config
from runners import build_runner
from utils.dist_utils import init_dist
from utils.dist_utils import exit_dist
@click.group(
    name='Distributed Training',
    help='Train a deep model by choosing a command (configuration).',
    context_settings={'max_content_width': 180, 'show_default': True},
)
@click.option(
    '--launcher',
    type=click.Choice(['pytorch', 'slurm']),
    default='pytorch',
    help='Distributed launcher.',
)
@click.option(
    '--backend',
    type=click.Choice(['nccl', 'gloo', 'mpi']),
    default='nccl',
    help='Distributed backend.',
)
@click.option(
    '--local_rank',
    type=int,
    default=0,
    hidden=True,
    help=('Replica rank on the current node. This field is required '
          'by `torch.distributed.launch`.'),
)
def command_group(launcher, backend, local_rank):  # pylint: disable=unused-argument
    """Declare the options shared by every training command.

    The body is intentionally empty. This group only parses the launcher,
    backend and local-rank options. The actual work is done by `main()`,
    which is attached through `command_group.result_callback()` and so
    receives these options together with whatever the invoked subcommand
    produced from the command line. How each subcommand (one per
    configuration) packs its own arguments is defined by
    `BaseConfig.get_command()` in `configs/base_config.py`.
    """


@command_group.result_callback()
@click.pass_context
def main(ctx, kwargs, launcher, backend, local_rank):
    """Run one distributed training job for the selected configuration.

    Called by click after the chosen subcommand has returned. It performs
    three steps in order:

      1. Initialize the distributed environment with `init_dist()`.
      2. Build the configuration named by the invoked subcommand from
         `kwargs`, then build a runner from that configuration.
      3. Train, close the runner, and finally tear down the distributed
         environment with `exit_dist()`.

    Args:
        ctx: Click context; its `invoked_subcommand` names the configuration
            that `build_config()` should construct.
        kwargs: Value produced by the invoked subcommand, forwarded as-is to
            `build_config()`.
        launcher: Distributed launcher, one of `pytorch` / `slurm`.
        backend: Distributed backend, one of `nccl` / `gloo` / `mpi`.
        local_rank: Required on the command line by
            `torch.distributed.launch` (per the option's help text); not
            used here.
    """
    del local_rank  # Only present to satisfy the launcher's CLI; unused.

    init_dist(launcher=launcher, backend=backend)

    config_name = ctx.invoked_subcommand
    config = build_config(config_name, kwargs).get_config()
    runner = build_runner(config)

    # NOTE(review): if `runner.train()` raises, neither `runner.close()` nor
    # `exit_dist()` runs. A `try`/`finally` may be worth adding, but first
    # confirm that both are safe to call after a failed run (e.g. that
    # neither blocks on a collective the other ranks never reach).
    runner.train()
    runner.close()

    exit_dist()


if __name__ == '__main__':
    # Expose every configuration from `configs/` as its own subcommand.
    for pool_entry in CONFIG_POOL:
        subcommand = pool_entry.get_command()
        command_group.add_command(subcommand)

    # Parse the command line and dispatch. Click fills in the group's
    # parameters itself, which pylint cannot see.
    command_group()  # pylint: disable=no-value-for-parameter