# -*- encoding: utf-8 -*-
'''
@File    :   pretrain_cogvideo.py
@Time    :   2021/10/06 00:58:32
@Author  :   Wenyi Hong
@Contact :   hwy22@mails.tsinghua.edu.cn
'''

# here put the import lib
import os
import sys
import math
import random
import torch
import argparse
import numpy as np

from icetk import icetk as tokenizer
tokenizer.add_special_tokens(['<start_of_image>', '<start_of_english>', '<start_of_chinese>'])

from models.cogvideo_model import CogVideoModel
from SwissArmyTransformer import mpu, get_args
from SwissArmyTransformer.training.deepspeed_training import training_main
from SwissArmyTransformer.data_utils import BinaryDataset
def get_masks_and_position_ids_video(data, attention_mask_totxt=None, args=None):
    """Build position ids for a batch of padded-text + frame token sequences.

    For each sample, the first ``n_pads`` text slots keep position 0 and the
    remaining text slots are numbered ``0, 1, 2, ...``. Every slot from
    ``layout[0]`` up to ``layout[2]`` is numbered consecutively starting at
    512, identically for all samples.

    Args:
        data: 2-D token tensor ``(batch_size, seq_length)``; only its shape
            and device are used.
        attention_mask_totxt: 2-D text mask ``(batch_size, n_text_slots)``.
            Only its per-row sum is used, as the count of non-pad text
            tokens; ``n_pads = layout[0] - row_sum``.
        args: object exposing ``layout``, an indexable of ints read at
            positions 0, 2 and -1.

    Returns:
        A ``torch.long`` tensor of shape ``(batch_size, layout[2])`` on
        ``data.device``.

    Raises:
        ValueError: if ``attention_mask_totxt`` is missing or ``seq_length``
            differs from ``layout[-1]``.
    """
    batch_size, seq_length = data.size()
    if attention_mask_totxt is None:
        raise ValueError('attention_mask_totxt must be provided')
    layout = args.layout
    if seq_length != layout[-1]:
        raise ValueError(
            f'sequence length {seq_length} does not match layout[-1]={layout[-1]}')

    text_len, total_len = layout[0], layout[2]
    # First position id assigned to the non-text part of the sequence.
    frame_position_start = 512
    device = data.device

    n_pads = (text_len - attention_mask_totxt.sum(dim=-1).long()).to(device)

    # Slot j of a sample gets position j - n_pads; pad slots (j < n_pads)
    # would go negative and are clamped to 0, matching a zero-initialised row.
    text_positions = (
        torch.arange(text_len, dtype=torch.long, device=device).unsqueeze(0)
        - n_pads.unsqueeze(1)
    ).clamp_(min=0)

    frame_positions = torch.arange(
        frame_position_start,
        frame_position_start + total_len - text_len,
        dtype=torch.long,
        device=device,
    )

    # torch.cat materialises the broadcast frame row into a fresh tensor.
    return torch.cat(
        [text_positions, frame_positions.unsqueeze(0).expand(batch_size, -1)],
        dim=1,
    )
def get_batch(data_iterator, args, timers):
    """Fetch one batch, broadcast it, and derive model inputs and targets.

    Args:
        data_iterator: iterator yielding batches with the keys ``text``,
            ``loss_mask`` and ``attention_mask_totxt``, or ``None`` (in which
            case ``None`` is handed to ``mpu.broadcast_data``).
        args: run configuration; reads ``args.layout`` and ``args.fp16``.
        timers: callable returning a timer (with ``start``/``stop``) by name.

    Returns:
        Tuple ``(tokens, labels, loss_mask, attention_mask_totxt,
        position_ids)``. ``tokens`` is the sequence without its last element,
        ``labels`` and ``loss_mask`` are the sequence / mask without their
        first element, and ``attention_mask_totxt`` has two singleton
        dimensions inserted after the batch dimension.
    """
    field_names = ['text', 'loss_mask', 'attention_mask_totxt']

    timers('data loader').start()
    raw_batch = next(data_iterator) if data_iterator is not None else None
    timers('data loader').stop()

    broadcast = mpu.broadcast_data(field_names, raw_batch, torch.int64)

    full_tokens = broadcast['text'].long()
    text_mask = broadcast['attention_mask_totxt'].float()

    # Next-token setup: inputs drop the final token, while targets and their
    # loss mask drop the first one.
    tokens = full_tokens[:, :-1].clone().contiguous()
    labels = full_tokens[:, 1:].clone().contiguous()
    loss_mask = broadcast['loss_mask'].float()[:, 1:].contiguous()

    # Overwrite the input token at every 400th slot of the non-text part with
    # the <start_of_image> token.
    for slot in range(args.layout[0], args.layout[2], 400):
        tokens[:, slot] = tokenizer['<start_of_image>']

    position_ids = get_masks_and_position_ids_video(
        tokens, attention_mask_totxt=text_mask, args=args)

    # (batch, n) -> (batch, 1, 1, n)
    text_mask = text_mask.unsqueeze(1).unsqueeze(1)
    if args.fp16:
        text_mask = text_mask.half()

    return tokens, labels, loss_mask, text_mask, position_ids
def forward_step(data_iterator, model, args, timers):
    """Run one forward pass and compute the masked image-token loss.

    Args:
        data_iterator: passed through to ``get_batch``.
        model: callable invoked as ``model(tokens, position_ids, mask)``; its
            first output is used as the logits.
        args: run configuration; reads ``args.cogvideo_stage`` here.
        timers: callable returning a timer (with ``start``/``stop``) by name.

    Returns:
        Tuple ``(loss, log_loss_dict)`` where ``loss`` is the loss-mask
        weighted mean of the per-token losses after the text part, and
        ``log_loss_dict`` maps per-frame log names to detached values.
    """
    timers('batch generator').start()
    batch = get_batch(data_iterator, args, timers)
    tokens, labels, loss_mask, attention_mask_totxt, position_ids = batch
    timers('batch generator').stop()

    logits, *_ = model(tokens, position_ids, attention_mask_totxt)

    # ======= hyper params =======#
    text_len = 64
    perframe_len = 400
    frame_num = 5

    # Only positions after the text part, restricted to the image-token slice
    # of the vocabulary, contribute to the loss.
    image_logits = logits[:, text_len:, :tokenizer.num_image_tokens].float().contiguous()
    losses = mpu.vocab_parallel_cross_entropy(image_logits, labels[:, text_len:])

    flat_mask = loss_mask[:, text_len:].reshape(-1)
    loss = torch.sum(losses.reshape(-1) * flat_mask) / flat_mask.sum()

    # ===================== Log partial losses ======================== #
    bs = losses.shape[0]
    tokens_per_frame_batch = max((perframe_len * bs), 1)

    def frame_mean(frame_idx):
        # Plain mean over one frame's slice; the loss mask is NOT applied here.
        lo, hi = frame_idx * perframe_len, (frame_idx + 1) * perframe_len
        frame_losses = losses[:, lo:hi].contiguous().reshape(-1).detach()
        return frame_losses.sum() / tokens_per_frame_batch

    if args.cogvideo_stage == 1:
        log_loss_dict = {f'AR_f{i}_loss': frame_mean(i) for i in range(frame_num)}
    else:
        log_loss_dict = {f'ITP_f{i}_loss': frame_mean(i) for i in range(1, frame_num - 1)}
    # ===================== END OF BLOCK ======================= #

    return loss, log_loss_dict
def create_dataset_function(path, args):
    """Create a ``BinaryDataset`` whose rows are turned into training samples.

    Args:
        path: dataset location, forwarded to ``BinaryDataset``.
        args: run configuration; ``args.cogvideo_stage`` selects which helper
            builds the loss mask (stage 1 vs. any other value).

    Returns:
        ``BinaryDataset(path, process_fn, length_per_sample=2064)``.
    """
    dataset_layout = [64, 464, 2064]
    input_layout = [64, 464, 2064]
    text_slots = input_layout[0]

    def process_fn(row):
        """Convert one raw row into ``text`` / ``loss_mask`` / ``attention_mask_totxt``."""
        row = row.astype(np.int64)
        split_at = dataset_layout[0]
        text, frames = row[:split_at], row[split_at:]

        if text[0] == tokenizer['<pad>']:
            # due to our way of data processing
            text = text[1:]

        if args.cogvideo_stage == 1:
            build_sample = make_text_video_generation
        else:
            build_sample = mask_video_frame_interpolation
        text, loss_mask, frames = build_sample(text, frames)

        # Pad on the left so the text always fills exactly `text_slots` slots.
        n_pad = text_slots - len(text)
        left_pad = np.array([tokenizer['<pad>']] * n_pad, dtype=np.int64)
        image_start = np.array([tokenizer['<start_of_image>']], dtype=np.int64)
        sample_tokens = np.concatenate([left_pad, text, image_start, frames], axis=0)

        # 0 over the pad slots, 1 over the real text slots.
        attention_mask_totxt = np.array(
            [int(slot >= n_pad) for slot in range(text_slots)])

        return {
            'text': sample_tokens,
            'loss_mask': loss_mask,
            'attention_mask_totxt': attention_mask_totxt,
        }

    return BinaryDataset(path, process_fn, length_per_sample=dataset_layout[-1])
def make_text_video_generation(text, frames):
    """Prepare text and loss mask for the text-to-video generation stage.

    Args:
        text: 1-D NumPy array of text token ids (boolean-mask indexed here).
        frames: frame token ids; returned unchanged.

    Returns:
        Tuple ``(text, loss_mask, frames)``: ``text`` with all ``<pad>``
        tokens removed and truncated to 64 entries; ``loss_mask`` with 465
        zeros followed by 1600 ones.
    """
    input_layout = [64, 464, 2064]
    text_len, first_frame_end, total_len = input_layout

    # dataset format: <duration, e.g. "1.0 seconds"><n>{text}<pad><pad> ...
    not_pad = text != tokenizer['<pad>']
    kept_text = text[not_pad][:text_len]

    # Indexed by input position; the loss mask is shifted left by one later.
    n_ignored = first_frame_end + 1
    n_scored = total_len - first_frame_end
    loss_mask = np.array([0] * n_ignored + [1] * n_scored)

    return kept_text, loss_mask, frames
def mask_video_frame_interpolation(text, frames):
    """Prepare text and loss mask for the frame-interpolation stage.

    Args:
        text: 1-D NumPy array of text token ids (boolean-mask indexed here).
        frames: frame token ids; returned unchanged.

    Returns:
        Tuple ``(text, loss_mask, frames)``: ``text`` with all ``<pad>``
        tokens removed and truncated to 64 entries; ``loss_mask`` made of 465
        zeros followed by four 400-long runs valued 1, 0, 1, 0.
    """
    input_layout = [64, 464, 2064]
    frame_len = input_layout[1] - input_layout[0]

    # text format: <pad> <duration, e.g. "1.0 seconds"> <n> {text} <pad> <pad>
    not_pad = text != tokenizer['<pad>']
    kept_text = text[not_pad][:input_layout[0]]

    # Indexed by input position; the loss mask is shifted left by one later.
    run_values = np.array([0, 1, 0, 1, 0])
    run_lengths = [input_layout[1] + 1, frame_len, frame_len, frame_len, frame_len]
    loss_mask = np.repeat(run_values, run_lengths)

    return kept_text, loss_mask, frames
if __name__ == '__main__':
    # Options declared here (plus the model-specific ones) are parsed first;
    # whatever this parser does not recognise is forwarded to get_args().
    script_parser = argparse.ArgumentParser(add_help=False)
    script_parser.add_argument('--txt-loss-scale', type=float, default=1)
    CogVideoModel.add_model_specific_args(script_parser)
    script_args, remaining_argv = script_parser.parse_known_args()

    framework_args = get_args(remaining_argv)
    # Merge both namespaces into one; a name present in both raises TypeError.
    args = argparse.Namespace(**vars(framework_args), **vars(script_args))

    # --layout arrives as a comma-separated string; convert it to a list of ints.
    args.layout = list(map(int, args.layout.split(',')))

    training_main(
        args,
        model_cls=CogVideoModel,
        forward_step_function=forward_step,
        create_dataset_function=create_dataset_function,
    )