NEXTGPT / code /inference.py
osamaifti's picture
Upload 83 files
7cdf421 verified
raw
history blame
6.06 kB
import os
from model.anyToImageVideoAudio import NextGPTModel
import torch
import json
from config import *
import matplotlib.pyplot as plt
from diffusers.utils import export_to_video
import scipy
def predict(
input,
image_path=None,
audio_path=None,
video_path=None,
thermal_path=None,
max_tgt_len=200,
top_p=10.0,
temperature=0.1,
history=None,
modality_cache=None,
filter_value=-float('Inf'), min_word_tokens=0,
gen_scale_factor=10.0, max_num_imgs=1,
stops_id=None,
load_sd=True,
generator=None,
guidance_scale_for_img=7.5,
num_inference_steps_for_img=50,
guidance_scale_for_vid=7.5,
num_inference_steps_for_vid=50,
max_num_vids=1,
height=320,
width=576,
num_frames=24,
guidance_scale_for_aud=7.5,
num_inference_steps_for_aud=50,
max_num_auds=1,
audio_length_in_s=9,
ENCOUNTERS=1,
):
if image_path is None and audio_path is None and video_path is None and thermal_path is None:
# return [(input, "There is no input data provided! Please upload your data and start the conversation.")]
print('no image, audio, video, and thermal are input')
else:
print(
f'[!] image path: {image_path}\n[!] audio path: {audio_path}\n[!] video path: {video_path}\n[!] thermal path: {thermal_path}')
# prepare the prompt
prompt_text = ''
if history != None:
for idx, (q, a) in enumerate(history):
if idx == 0:
prompt_text += f'{q}\n### Assistant: {a}\n###'
else:
prompt_text += f' Human: {q}\n### Assistant: {a}\n###'
prompt_text += f'### Human: {input}'
else:
prompt_text += f'### Human: {input}'
print('prompt_text: ', prompt_text)
response = model.generate({
'prompt': prompt_text,
'image_paths': [image_path] if image_path else [],
'audio_paths': [audio_path] if audio_path else [],
'video_paths': [video_path] if video_path else [],
'thermal_paths': [thermal_path] if thermal_path else [],
'top_p': top_p,
'temperature': temperature,
'max_tgt_len': max_tgt_len,
'modality_embeds': modality_cache,
'filter_value': filter_value, 'min_word_tokens': min_word_tokens,
'gen_scale_factor': gen_scale_factor, 'max_num_imgs': max_num_imgs,
'stops_id': stops_id,
'load_sd': load_sd,
'generator': generator,
'guidance_scale_for_img': guidance_scale_for_img,
'num_inference_steps_for_img': num_inference_steps_for_img,
'guidance_scale_for_vid': guidance_scale_for_vid,
'num_inference_steps_for_vid': num_inference_steps_for_vid,
'max_num_vids': max_num_vids,
'height': height,
'width': width,
'num_frames': num_frames,
'guidance_scale_for_aud': guidance_scale_for_aud,
'num_inference_steps_for_aud': num_inference_steps_for_aud,
'max_num_auds': max_num_auds,
'audio_length_in_s': audio_length_in_s,
'ENCOUNTERS': ENCOUNTERS,
})
return response
if __name__ == '__main__':
# init the model
g_cuda = torch.Generator(device='cuda').manual_seed(1337)
args = {'model': 'nextgpt',
'nextgpt_ckpt_path': '../ckpt/delta_ckpt/nextgpt/7b_tiva_v0/',
'max_length': 128,
'stage': 3,
'root_dir': '../',
'mode': 'validate',
}
args.update(load_config(args))
model = NextGPTModel(**args)
delta_ckpt = torch.load(os.path.join(args['nextgpt_ckpt_path'], 'pytorch_model.pt'), map_location=torch.device('cuda'))
# print(delta_ckpt)
model.load_state_dict(delta_ckpt, strict=False)
model = model.eval().half().cuda()
# model = model.eval().cuda()
print(f'[!] init the 7b model over ...')
"""Override Chatbot.postprocess"""
max_tgt_length = 150
top_p = 1.0
temperature = 0.4
modality_cache = None
prompt = 'show me a video. a woman walk a dop in the park.'
history = []
output = predict(input=prompt, history=history,
max_tgt_len=max_tgt_length, top_p=top_p,
temperature=temperature, modality_cache=modality_cache,
filter_value=-float('Inf'), min_word_tokens=10,
gen_scale_factor=4.0, max_num_imgs=1,
stops_id=[[835]],
load_sd=True,
generator=g_cuda,
guidance_scale_for_img=7.5,
num_inference_steps_for_img=50,
guidance_scale_for_vid=7.5,
num_inference_steps_for_vid=50,
max_num_vids=1,
height=320,
width=576,
num_frames=24,
ENCOUNTERS=1
)
# print("output: ", output)
for i in output:
if isinstance(i, str):
print(i)
elif 'img' in i.keys():
for m in i['img']:
if isinstance(m, str):
print(m)
else:
m[0].save(f'./assets/images/{prompt}.jpg')
elif 'vid' in i.keys():
for idx, m in enumerate(i['vid']):
if isinstance(m, str):
print(m)
else:
video_path = export_to_video(video_frames=m, output_video_path=f'./assets/videos/{prompt}.mp4')
print("video_path: ", video_path)
elif 'aud' in i.keys():
for idx, m in enumerate(i['aud']):
if isinstance(m, str):
print(m)
else:
audio_path = f'./assets/audios/{prompt}.wav'
scipy.io.wavfile.write(audio_path, rate=16000, data=m)
print("video_path: ", audio_path)
else:
pass