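"""Training launcher for ProPainter.

Loads a JSON experiment config, spawns one worker process per visible GPU
via torch.multiprocessing, initializes an NCCL process group, and hands
control to the Trainer class selected by config['trainer']['version'].
"""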
import os
import json
import argparse
from shutil import copyfile

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import core
import core.trainer
import core.trainer_flow_w_edge


# import warnings
# warnings.filterwarnings("ignore")

from core.dist import (
    get_world_size,
    get_local_rank,
    get_global_rank,
    get_master_ip,
)
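
# Note: get_world_size / get_local_rank / get_global_rank are only used by
# the commented-out OpenMPI launch path at the bottom of this file;
# get_master_ip builds the TCP rendezvous address for init_process_group.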

parser = argparse.ArgumentParser()
parser.add_argument('-c',
                    '--config',
                    default='configs/train_propainter.json',
                    type=str)
parser.add_argument('-p', '--port', default='23490', type=str)
args = parser.parse_args()
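
# Example launch (this script spawns one worker per visible GPU itself, so no
# torchrun wrapper is needed); the filename train.py is an assumption:
#   python train.py -c configs/train_propainter.json -p 23490
#
# Minimal shape of the JSON config, inferred from the keys accessed below
# (the real config carries many more training hyperparameters):
#   {
#     "save_dir": "experiments",
#     "model":   {"net": "propainter"},
#     "trainer": {"version": "trainer"}
#   }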


def main_worker(rank, config):
    # When launched via mp.spawn, `rank` indexes the spawned process and
    # serves as both the local and the global rank (single-node training).
    if 'local_rank' not in config:
        config['local_rank'] = config['global_rank'] = rank
    if config['distributed']:
        # Pin this worker to its GPU, then join the NCCL process group via
        # the TCP rendezvous address assembled in __main__.
        torch.cuda.set_device(int(config['local_rank']))
        dist.init_process_group(backend='nccl',
                                init_method=config['init_method'],
                                world_size=config['world_size'],
                                rank=config['global_rank'],
                                group_name='mtorch')
        print('using GPU {}-{} for training'.format(int(config['global_rank']),
                                                    int(config['local_rank'])))

    # Outputs are grouped by network name and config name: checkpoints go to
    # <save_dir>/<net>_<config>, evaluation metrics to ./scores/<net>_<config>.
    config['save_dir'] = os.path.join(
        config['save_dir'],
        '{}_{}'.format(config['model']['net'],
                       os.path.basename(args.config).split('.')[0]))

    config['save_metric_dir'] = os.path.join(
        './scores',
        '{}_{}'.format(config['model']['net'],
                       os.path.basename(args.config).split('.')[0]))

    # Fall back to the CPU when no GPU is visible.
    if torch.cuda.is_available():
        config['device'] = torch.device("cuda:{}".format(config['local_rank']))
    else:
        config['device'] = 'cpu'

    # Only the main process touches the filesystem: create the output folder
    # and archive a copy of the config next to the checkpoints.
    if (not config['distributed']) or config['global_rank'] == 0:
        os.makedirs(config['save_dir'], exist_ok=True)
        config_path = os.path.join(config['save_dir'],
                                   os.path.basename(args.config))
        if not os.path.isfile(config_path):
            copyfile(args.config, config_path)
        print('[**] created folder {}'.format(config['save_dir']))

    # Resolve the trainer module named in the config (e.g. 'trainer' or
    # 'trainer_flow_w_edge') and instantiate its Trainer class.
    trainer_version = config['trainer']['version']
    trainer = getattr(core, trainer_version).Trainer(config)
    trainer.train()


if __name__ == "__main__":

    # Let cuDNN benchmark convolution algorithms for fixed-size inputs.
    torch.backends.cudnn.benchmark = True

    # Share tensors via the filesystem to avoid "too many open files"
    # errors from DataLoader workers.
    mp.set_sharing_strategy('file_system')

    # loading configs
    with open(args.config, 'r') as f:
        config = json.load(f)

    # setting distributed configurations
    # config['world_size'] = get_world_size()
    # One worker per visible GPU; keep at least one so CPU-only runs still start.
    config['world_size'] = max(1, torch.cuda.device_count())
    config['init_method'] = f"tcp://{get_master_ip()}:{args.port}"
    config['distributed'] = config['world_size'] > 1
    print('world_size:', config['world_size'])
    # setup distributed parallel training environments

    # if get_master_ip() == "127.0.0.X":
    #     # manually launch distributed processes
    #     mp.spawn(main_worker, nprocs=config['world_size'], args=(config, ))
    # else:
    #     # multiple processes have been launched by openmpi
    #     config['local_rank'] = get_local_rank()
    #     config['global_rank'] = get_global_rank()
    #     main_worker(-1, config)

    mp.spawn(main_worker, nprocs=config['world_size'], args=(config, ))