import os
import shutil
import sys

import torch
from transformers import TextStreamer

# Make the Evaluation directory (which provides the llava package) importable.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "Evaluation"))

from llava.constants import IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates, SeparatorStyle
from llava.mm_utils import get_model_name_from_path, KeywordsStoppingCriteria, tokenizer_image_token
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init

cur_dir = os.path.dirname(os.path.abspath(__file__))
title_markdown = (""" | |
<div style="display: flex; justify-content: center; align-items: center; text-align: center;"> | |
<div> | |
<h1 >VLM-RLAIF: Tuning Large Multimodal Models for Videos using Reinforcement Learning from AI Feedback (ACL 2024 Oral) </h1> | |
<h5 style="margin: 0;">If you like our project, please give us a star ✨ on Github for the latest update.</h5> | |
</div> | |
</div> | |
<div align="center"> | |
<div style="display:flex; gap: 0.25rem;" align="center"> | |
<a href='https://github.com/yonseivnl/vlm-rlaif'><img src='https://img.shields.io/badge/Github-Code-blue'></a> | |
<a href="https://arxiv.org/abs/2402.03746"><img src="https://img.shields.io/badge/Paper-arxiv-green"></a> | |
</div> | |
</div> | |
""") | |
block_css = """ | |
#buttons button { | |
min-width: min(120px,100%); | |
} | |
""" | |
tos_markdown = ("""""") | |
learn_more_markdown = (""" | |
### License | |
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA. | |
""") | |


class Chat:
    def __init__(self, model_path, conv_mode, model_base=None, load_8bit=False, load_4bit=False, device='cuda', cache_dir=None):
        disable_torch_init()
        model_name = get_model_name_from_path(model_path)
        is_rlhf_checkpoint = 'rlhf' in model_path.lower()
        print("MODEL_PATH:", model_path)
        print("RLHF checkpoint:", is_rlhf_checkpoint)
        if not model_base or model_base == "none":
            model_base = None
        if is_rlhf_checkpoint:
            model_name = model_path
            # RLHF checkpoints may ship without a config.json; copy the SFT base
            # model's config into the RLHF folder so the loader can read it.
            if not os.path.exists(os.path.join(model_path, "config.json")):
                shutil.copy(os.path.join(model_base, "config.json"), os.path.join(model_path, "config.json"))
        # Pass the quantization flags through instead of hardcoding False, False,
        # so the load_8bit/load_4bit arguments actually take effect.
        self.tokenizer, self.model, image_processor, context_len = load_pretrained_model(
            model_path, model_base, model_name, load_8bit, load_4bit, device=device)
        self.image_processor = image_processor
        self.conv_mode = conv_mode
        self.conv = conv_templates[conv_mode].copy()
        self.device = self.model.device

    def get_prompt(self, qs, state):
        # Append the user's question and a placeholder for the assistant's reply.
        state.append_message(state.roles[0], qs)
        state.append_message(state.roles[1], None)
        return state

    def _get_latest_prompt(self, state):
        # Build a prompt that contains only the most recent user/assistant turn.
        new_state = state.copy()
        new_state.messages = state.messages[-2:]
        return new_state

    def generate(self, images_tensor: torch.Tensor, prompt: str, first_run: bool, state):
        tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
        state = self.get_prompt(prompt, state)
        # Condition generation only on the latest turn to keep the context short.
        latest_state = self._get_latest_prompt(state)
        prompt = latest_state.get_prompt()
        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
        temperature = 0.2
        max_new_tokens = 1024
        # Stop generation at the conversation template's separator string.
        stop_str = self.conv.sep if self.conv.sep_style != SeparatorStyle.TWO else self.conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        print(prompt, input_ids.shape, images_tensor.shape)
        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=images_tensor,
                do_sample=True,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                streamer=streamer,
                use_cache=True,
                stopping_criteria=[stopping_criteria])
        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        # Decode only the newly generated tokens, then strip the stop string and
        # a dataset-tag artifact that can leak into the output.
        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        outputs = outputs.replace("QA_GT_caption_based_noisy", "")
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()
        print('response', outputs)
        return outputs, state
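

# A minimal usage sketch, not part of the original script. It assumes a LLaVA-style
# video checkpoint; the checkpoint id, conversation mode, frame count, and random
# frames below are hypothetical placeholders for illustration only.
if __name__ == "__main__":
    import numpy as np

    chat = Chat(
        model_path="SNUMPR/vlm_rlaif_video_llava_7b",  # hypothetical checkpoint id
        conv_mode="llava_v1",                          # assumed conversation template
    )
    state = chat.conv.copy()

    # Stand-in for real video frames: 8 random 224x224 RGB images.
    frames = [np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8) for _ in range(8)]
    images_tensor = chat.image_processor.preprocess(frames, return_tensors='pt')['pixel_values']
    images_tensor = images_tensor.to(chat.model.device, dtype=torch.float16)

    question = "<image>\nDescribe what happens in this video."
    reply, state = chat.generate(images_tensor, question, first_run=True, state=state)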