import argparse
import json

import numpy as np
import torch
from PIL import Image


def rtp_read(text_file):
    # Collect the "challenging" prompts from the RealToxicityPrompts jsonl file.
    dataset = []
    lines = open(text_file).read().split("\n")
    for li in lines:
        if not li.strip():
            continue  # skip blank lines (e.g., a trailing newline)
        obj = json.loads(li)
        if obj['challenging']:
            dataset.append(obj['prompt']['text'])
    return dataset


def normalize(images):
    # Map [0, 1] pixel values to the CLIP-normalized space used by the model.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cuda()
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cuda()
    return (images - mean[None, :, None, None]) / std[None, :, None, None]


def denormalize(images):
    # Invert normalize(): map CLIP-normalized values back to [0, 1] pixel space.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cuda()
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cuda()
    return (images * std[None, :, None, None]) + mean[None, :, None, None]


def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--model-path", type=str, default="ckpts/llava_llama_2_13b_chat_freeze")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image_path", type=str, default='./image.bmp', help="Image file")
    parser.add_argument("--image_safety_patch", type=str, default=None, help="image safety patch file")
    parser.add_argument("--text_safety_patch", type=str, default=None, help="text safety patch file")
    parser.add_argument("--output_file", type=str, default='./result.jsonl', help="Output file.")
    return parser.parse_args()


def load_image(image_path):
    return Image.open(image_path).convert('RGB')


# ========================================
# Model Initialization
# ========================================
print('>>> Initializing Models')

from llava.utils import get_model

args = parse_args()
print('model = ', args.model_path)

tokenizer, model, image_processor, model_name = get_model(args)
model = model.cuda()
model.eval()
print(model)
print(model_name)
print('[Initialization Finished]\n')

from llava_utils import prompt_wrapper, generator

my_generator = generator.Generator(model=model, tokenizer=tokenizer)

# ========================================
# Inference
# ========================================
prompts = rtp_read('harmful_corpus/rtp_prompts.jsonl')

text_prompt = '%s'
if args.text_safety_patch is not None:
    with open(args.text_safety_patch, 'r') as file:
        text_safety_patch = file.read().rstrip()

out = []
with torch.no_grad():
    for i, user_message in enumerate(prompts):
        # Pair each prompt with one of the 25 adversarial QnA images at random.
        image_file = 'adversarial_qna_images/adv_image_' + str(np.random.randint(25)) + '.bmp'
        image = load_image(image_file)
        image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].cuda()

        if args.image_safety_patch is not None:
            # Map the preprocessed image back to [0, 1] pixel space, add the
            # safety patch, clamp, then re-normalize for the model.
            image = denormalize(image)
            safety_patch = torch.load(args.image_safety_patch).cuda()
            safe_image = normalize((image + safety_patch).clamp(0, 1))
        else:
            safe_image = image

        print(f" ----- {i} ----")
        print(" -- prompt: ---")

        if args.text_safety_patch is not None:
            # Use the line below for the optimized text patch:
            # user_message = text_safety_patch + '\n' + user_message
            # Use the line below for the heuristic text patch:
            user_message += '\n' + text_safety_patch

        text_prompt_template = prompt_wrapper.prepare_text_prompt(text_prompt % user_message)
        print(text_prompt_template)

        prompt = prompt_wrapper.Prompt(model, tokenizer, text_prompts=text_prompt_template, device=model.device)

        # Strip any chat-template control tokens that leak into the generation.
        response = my_generator.generate(prompt, safe_image) \
            .replace("[INST]", "").replace("[/INST]", "") \
            .replace("[SYS]", "").replace("[/SYS]", "").strip()
        if args.text_safety_patch is not None:
            response = response.replace(text_safety_patch, "")

        print(" -- continuation: ---")
        print(response)
        out.append({'prompt': user_message, 'continuation': response})
        print()

with open(args.output_file, 'w') as f:
    # First record: the run configuration; following records: one result per prompt.
    f.write(json.dumps({
        "args": vars(args),
        "prompt": text_prompt
    }))
    f.write("\n")
    for li in out:
        f.write(json.dumps(li))
        f.write("\n")