|
from header import * |
|
from dataset import load_dataset |
|
from model import * |
|
from config import * |
|
|
|
|
|
def parser_args():
    """Parse command-line arguments for training/testing.

    Returns:
        argparse.Namespace with model/mode/paths/stage/modality settings.
    """
    parser = argparse.ArgumentParser(description='train parameters')
    parser.add_argument('--model', type=str, default='nextgpt')
    parser.add_argument('--mode', type=str, default='train', help='train or test or validation')
    # Kept for compatibility with torch.distributed launchers that inject --local_rank.
    parser.add_argument('--local_rank', default=0, type=int)
    parser.add_argument('--save_path', type=str, default='../ckpt/delta_ckpt/nextgpt/7b_tiva_v0/')
    parser.add_argument('--log_path', type=str, default='../ckpt/delta_ckpt/nextgpt/7b_tiva_v0/log/')
    parser.add_argument('--assets_path', type=str, default='./assets/')

    parser.add_argument('--max_length', type=int, default=512)
    parser.add_argument('--stage', type=int, default=1)
    # BUGFIX: the original used `type=list`, which splits a CLI value such as
    # "image" into single characters (['i','m','a','g','e']). `nargs='+'`
    # collects one or more string tokens instead; the default is unchanged.
    parser.add_argument('--modality', type=str, nargs='+',
                        default=['image', 'video', 'audio', 'text'])
    return parser.parse_args()
|
|
|
|
|
def initialize_distributed(args):
    """Fill *args* with distributed-run settings from the environment and
    bring up the NCCL backend via DeepSpeed.

    Side effects: sets the current CUDA device and calls
    deepspeed.init_distributed.
    """
    args['master_ip'] = os.getenv('MASTER_ADDR', 'localhost')
    args['master_port'] = os.getenv('MASTER_PORT', '6000')
    args['world_size'] = int(os.getenv('WORLD_SIZE', '1'))

    # Map the global RANK onto a local GPU index.
    n_gpus = torch.cuda.device_count()
    args['local_rank'] = int(os.getenv('RANK', '0')) % n_gpus
    torch.cuda.set_device(args['local_rank'] % n_gpus)

    deepspeed.init_distributed(dist_backend='nccl')
|
|
|
|
|
def set_random_seed(seed):
    """Seed every RNG used in training (python, numpy, torch CPU + CUDA).

    A seed of ``None`` or any value <= 0 is treated as "do not seed".
    """
    if seed is None or seed <= 0:
        return
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
|
|
|
|
|
def config_env(args):
    """Load the project config into *args*, set up distributed training,
    and seed all RNGs.

    Mutates *args* in place (adds 'root_dir' plus everything returned by
    load_config, then the distributed-run keys).
    """
    args['root_dir'] = '../'

    # Merge the file-based configuration on top of the CLI arguments.
    args.update(load_config(args))

    initialize_distributed(args)
    set_random_seed(args['seed'])
|
|
|
|
|
def build_directory(path):
    """Create *path* (including parents) if it does not already exist.

    The previous check-then-create pattern was redundant and race-prone;
    os.makedirs with exist_ok=True handles both cases atomically.
    """
    os.makedirs(path, exist_ok=True)
|
|
|
|
|
def main(**args):
    """Train the model: set up env/config, build dataset + DeepSpeed agent,
    then iterate batches for args['epochs'] epochs and save a checkpoint.

    Keyword args come from parser_args() (see __main__ guard) and are
    extended in place by config_env / load_config.
    """
    config_env(args)

    print(args)

    # Select the DeepSpeed JSON config matching the training stage.
    args['ds_config_path'] = f'dsconfig/stage_{args["stage"]}.json'

    dschf = HfDeepSpeedConfig(args['ds_config_path'])

    args['dschf'] = dschf



    build_directory(args['save_path'])

    build_directory(args['log_path'])



    # Route DEBUG-level logs to a per-run file named by wall-clock time.
    if args['log_path']:

        logging.basicConfig(

            format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',

            level=logging.DEBUG,

            filename=f'{args["log_path"]}/train_{time.asctime()}.log',

            filemode='w'

        )

    train_data, train_iter, sampler = load_dataset(args, args['dataset_name_list'])



    # Effective sample count: largest sub-dataset length times the number of
    # sub-datasets (all datasets are padded/up-sampled to the longest one).
    # NOTE(review): assumes train_data.datasets is a ConcatDataset-like object
    # exposing .datasets — confirm against the dataset module.
    train_num = max([_cur_dataset.__len__() for _cur_dataset in train_data.datasets.datasets]) * len(train_data.datasets.datasets)

    # Per-rank optimizer-step count for the progress bar.
    length = args['epochs'] * train_num // args['world_size'] // dschf.config[

        'train_micro_batch_size_per_gpu']

    # Global step count (used by the LR scheduler inside the agent, presumably).
    total_steps = args['epochs'] * train_num // dschf.config['train_batch_size']

    args['total_steps'] = total_steps

    agent = load_model(args)

    # Wait for every rank to finish model construction before training.
    torch.distributed.barrier()





    pbar = tqdm(total=length)

    current_step = 0

    for epoch_i in tqdm(range(args['epochs'])):



        for batch in train_iter:

            agent.train_model(

                batch,

                current_step=current_step,

                pbar=pbar

            )

            current_step += 1









    # Sync all ranks before writing the final checkpoint.
    torch.distributed.barrier()

    agent.save_model(args['save_path'], current_step)
|
|
|
|
|
if __name__ == "__main__": |
|
args = parser_args() |
|
args = vars(args) |
|
main(**args) |
|
|