File size: 1,171 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
from ding.entry import serial_pipeline_offline
from ding.config import read_config
from ding.utils import dist_init
from pathlib import Path
import torch
import torch.multiprocessing as mp


def offline_worker(rank, config, args):
    """Per-process entry point launched by ``mp.spawn`` for multi-GPU training.

    Args:
        rank: index of this process, passed positionally by
            ``torch.multiprocessing.spawn``.
        config: the config object produced by ``read_config`` in ``train``.
        args: parsed CLI namespace; only ``args.seed`` is read here.
    """
    # The world size matches the number of processes ``train`` spawns
    # (one per visible CUDA device).
    n_devices = torch.cuda.device_count()
    dist_init(rank=rank, world_size=n_devices)
    serial_pipeline_offline(config, seed=args.seed)


def train(args):
    """Run offline training for the config file named by ``args.config``.

    The config is resolved relative to this script
    (``<script dir>/../config/<args.config>``), so the entry can be launched
    from any working directory.

    Args:
        args: parsed CLI namespace; reads ``args.config`` (config file name)
            and ``args.seed`` (integer seed).

    Raises:
        RuntimeError: if the config enables ``multi_gpu`` but no CUDA device
            is visible (otherwise ``mp.spawn`` would start zero workers and
            silently do nothing).
    """
    # Resolve the config relative to this file so the script works from any cwd.
    config_path = Path(__file__).absolute().parent.parent / 'config' / args.config
    config = read_config(str(config_path))
    main_config = config[0]

    # The '0' in exp_name is treated as the seed placeholder (presumably a
    # ``seed0``-style suffix -- confirm against the configs). Only the LAST
    # '0' is substituted, so other '0' digits in the name are not rewritten.
    head, sep, tail = main_config.exp_name.rpartition('0')
    if sep:
        main_config.exp_name = head + str(args.seed) + tail

    # Treat a missing ``multi_gpu`` key as single-process training.
    if not main_config.policy.get('multi_gpu', False):
        serial_pipeline_offline(config, seed=args.seed)
        return

    world_size = torch.cuda.device_count()
    if world_size < 1:
        raise RuntimeError("multi_gpu is enabled in the config but no CUDA device is available")
    # Rendezvous settings for the single-node process group the workers join.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29600"
    mp.spawn(offline_worker, nprocs=world_size, args=(config, args))


if __name__ == "__main__":
    import argparse

    # CLI: pick the random seed and the config file name (looked up under ../config/).
    cli = argparse.ArgumentParser()
    cli.add_argument('--seed', '-s', type=int, default=0)
    cli.add_argument('--config', '-c', type=str, default='hopper_medium_expert_ibc_config.py')
    train(cli.parse_args())