import json
import os.path as osp

import gradio as gr
import numpy as np
import torch

from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model
from llava_utils import prompt_wrapper, generator
from utils import normalize, denormalize, load_image
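
# Helpers imported from the local `utils` module; the behavior sketched here is
# inferred from how they are used below, not from the module itself:
#   load_image(path)  -> PIL image loaded from disk
#   normalize(t)      -> maps a CLIP-preprocessed tensor into (0, 1) pixel space
#   denormalize(t)    -> inverse of normalize(); restores the encoder's value range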
# BASE_DIR = "/workingdir/soh337/llavaguard"
BASE_DIR = "/Users/ahren/Workspace/Multimodal/llavaguard"

IMAGE_PATH = 'unconstrained_attack_images/adversarial_'
MODEL_PATH = "/workingdir/models_hf/liuhaotian/llava-v1.5-13b"

TEXT_SAFETY_PATCHES = {
    "optimized": "text_patch_optimized",
    "heuristic": "text_patch_heuristic"
}
IMAGE_SAFETY_PATCHES = {
    "default": "safety_patch.bmp"
}


def rtp_read(text_file):
    dataset = []
    with open(text_file) as f:
        for line in f:
            line = line.strip()
            if not line:  # Skip blank lines (e.g. a trailing newline).
                continue
            obj = json.loads(line)
            if obj['challenging']:
                dataset.append(obj['prompt']['text'])
    return dataset
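
# Each line of the corpus is expected to follow the RealToxicityPrompts JSONL
# schema (only the fields read above are required), e.g.:
#   {"challenging": true, "prompt": {"text": "..."}}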

model = tokenizer = image_processor = loaded_model_name = None
print('>>> Initializing Models')

prompts = rtp_read(osp.join(BASE_DIR, 'harmful_corpus/rtp_prompts.jsonl'))
out = []


def generate_answer(user_message: str, image_path: str, requested_model_name: str,
                    image_safety_patch_type: str, text_safety_patch_type: str,
                    ):
    # Model state is cached at module level so repeated calls reuse the loaded
    # checkpoint instead of re-reading it (and so `model`/`tokenizer` are
    # defined on calls after the first).
    global loaded_model_name, model, tokenizer, image_processor

    text_safety_patch = TEXT_SAFETY_PATCHES[text_safety_patch_type]
    image_safety_patch = IMAGE_SAFETY_PATCHES[image_safety_patch_type]
    if requested_model_name == "LLaVA":
        if requested_model_name == loaded_model_name:
            print(f"{requested_model_name} model already loaded.")
        else:
            print(f"Loading {requested_model_name} model ... ")
            model_name = get_model_name_from_path(MODEL_PATH)
            tokenizer, model, image_processor, context_len = load_pretrained_model(MODEL_PATH, None,
                                                                                   model_name)
            loaded_model_name = requested_model_name

    my_generator = generator.Generator(model=model, tokenizer=tokenizer)
    # Load a randomly-sampled unconstrained attack image as a PIL Image object.
    image = load_image(image_path)
    # Preprocess the image with the visual encoder (CLIP) of LLaVA 1.5; the
    # result is a PyTorch tensor of shape (1, 3, 336, 336).
    image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].cuda()
    if image_safety_patch is not None:
        # Map the image pixel values into (0, 1).
        image = normalize(image)
        # Load the safety patch tensor, whose values lie in (0, 1).
        safety_patch = torch.load(image_safety_patch).cuda()
        # Apply the safety patch to the input image, clamp to (0, 1), and
        # denormalize back to the original pixel value range.
        safe_image = denormalize((image + safety_patch).clamp(0, 1))
        # Sanity check: both tensors should remain within (0, 1).
        print(torch.min(image), torch.max(image), torch.min(safe_image), torch.max(safe_image))
    else:
        safe_image = image
    model.eval()
    if text_safety_patch is not None:
        if text_safety_patch_type == "optimized":
            # The optimized text safety patch is prepended to the user message.
            user_message = text_safety_patch + '\n' + user_message
        else:
            # The heuristic text safety patch is appended to the user message.
            user_message += '\n' + text_safety_patch
    text_prompt_template = prompt_wrapper.prepare_text_prompt(user_message)
    print(text_prompt_template)

    prompt = prompt_wrapper.Prompt(model, tokenizer, text_prompts=text_prompt_template, device=model.device)
    response = my_generator.generate(prompt, safe_image).replace("[INST]", "").replace("[/INST]", "").replace(
        "[SYS]", "").replace("[/SYS]", "").strip()
    if text_safety_patch is not None:
        response = response.replace(text_safety_patch, "")
print(" -- continuation: ---") | |
print(response) | |
out.append({'prompt': user_message, 'continuation': response}) | |
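

# Example invocation (hypothetical image path; argument order matches the
# signature above):
#   generate_answer("Describe this image.",
#                   "unconstrained_attack_images/adversarial_0.bmp",
#                   "LLaVA", "default", "heuristic")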


def get_list_of_examples():
    examples = []
    for prompt in prompts[:3]:  # Use the first 3 prompts for simplicity.
        image_num = np.random.randint(25)  # Randomly select an image number.
        image_path = f'{IMAGE_PATH}{image_num}.bmp'
        examples.append(
            [image_path, prompt]
        )
    return examples
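
# Each entry pairs a randomly chosen attack image with a prompt, e.g.:
#   ["unconstrained_attack_images/adversarial_7.bmp", "<challenging RTP prompt>"]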
css = """#col-container {max-width: 90%; margin-left: auto; margin-right: auto; display: flex; flex-direction: column;} | |
#header {text-align: center;} | |
#col-chatbox {flex: 1; max-height: min(750px, 100%);} | |
#label {font-size: 2em; padding: 0.5em; margin: 0;} | |
.message {font-size: 1.2em;} | |
.message-wrap {max-height: min(700px, 100vh);} | |
""" | |


def get_empty_state():
    # TODO: Not sure what this means
    return gr.State({"arena": None})


examples = get_list_of_examples()


# Update the two Interface inputs based on the selected example.
def update_inputs(example_id):
    # Each example is an [image_path, prompt] pair, matching the uploader inputs.
    selected_example = examples[int(example_id)]
    return selected_example[0], selected_example[1]


model_selector, image_patch_selector, text_patch_selector = None, None, None


def process_text_and_image(image_path: str, user_message: str):
    # Gradio passes the Interface inputs in declaration order: the image first,
    # then the question textbox.
    global model_selector, image_patch_selector, text_patch_selector
    print(f"User Message: {user_message}")
    print(f"Image Path: {image_path}")
    # Note: `.value` holds each dropdown's initial value, not the live user
    # selection; tracking the selection would require wiring the dropdowns in
    # as event inputs.
    print(model_selector.value)
    return generate_answer(user_message, image_path, model_selector.value,
                           image_patch_selector.value, text_patch_selector.value)


with gr.Blocks(css=css) as demo:
    state = get_empty_state()
    all_components = []

    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """# 🦙LLaVAGuard🔥<br>
            Safeguarding your Multimodal LLM
            **[Project Homepage](#)**""",
            elem_id="header",
        )
        # example_selector = gr.Dropdown(choices=[f"Example {i}" for i, e in enumerate(examples)],
        #                                label="Select an Example")
        with gr.Row():
            model_selector = gr.Dropdown(choices=["LLaVA"], label="Model",
                                         info="Select Model", value="LLaVA")
            image_patch_selector = gr.Dropdown(choices=["default"], label="Image Patch",
                                               info="Select Image Safety Patch", value="default")
            text_patch_selector = gr.Dropdown(choices=["heuristic", "optimized"], label="Text Patch",
                                              info="Select Text Safety Patch", value="heuristic")
        image_and_text_uploader = gr.Interface(
            fn=process_text_and_image,
            # type="filepath" hands the function a path on disk, which is what
            # load_image() downstream expects.
            inputs=[gr.Image(type="filepath", label="Upload your image", interactive=True),
                    gr.Textbox(placeholder="Input a question", label="Your Question"),
                    ],
            examples=examples,
            outputs=['text'])
    # # Set the action for the generate button
    # @demo.events(generate_button)
    # def handle_generation(image, question, model, image_patch, text_patch):
    #     generate_answer(question, image, model, text_patch, image_patch)

# Launch the demo
demo.launch()