# llavaguard/text_safety_patch.py
# (Hugging Face upload-page residue removed: "Ahren09 / Upload 227 files / 5ca4e86 verified")
import argparse
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from torchvision.utils import save_image
from llava_utils import prompt_wrapper, text_defender
def parse_args():
    """Build and parse the command-line arguments for the demo.

    Returns:
        argparse.Namespace with model paths, GPU id, attack iteration
        count, output directory, and candidate count.
    """
    arg_parser = argparse.ArgumentParser(description="Demo")

    # Model checkpoint locations.
    arg_parser.add_argument("--model-path", type=str, default="ckpts/llava_llama_2_13b_chat_freeze")
    arg_parser.add_argument("--model-base", type=str, default=None)

    # Runtime / attack configuration.
    arg_parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    arg_parser.add_argument("--n_iters", type=int, default=50, help="specify the number of iterations for attack.")
    arg_parser.add_argument("--save_dir", type=str, default='outputs',
                            help="save directory")
    arg_parser.add_argument("--n_candidates", type=int, default=100,
                            help="n_candidates")

    # Free-form key=value overrides (kept for backward compatibility).
    arg_parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
             "in xxx=yyy format will be merged into config file (deprecate), "
             "change to --cfg-options instead.",
    )

    return arg_parser.parse_args()
def setup_seeds(config):
    """Seed all RNGs (Python, NumPy, Torch) for reproducible runs.

    The seed is offset by the process rank so distributed workers draw
    different random streams.

    Args:
        config: object exposing ``config.run_cfg.seed`` (an int base seed).
    """
    # NOTE(review): get_rank is neither defined nor imported in this file —
    # presumably a distributed-rank helper (e.g. from llava.utils or
    # torch.distributed); confirm the missing import before calling this.
    seed = config.run_cfg.seed + get_rank()
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Disable cuDNN autotuning and force deterministic kernels so results
    # are reproducible across runs.
    cudnn.benchmark = False
    cudnn.deterministic = True
# ========================================
# Model Initialization
# ========================================
# Loads the LLaVA model/tokenizer, prepares the output directory, reads the
# harmful target strings, and constructs the text attacker used below.
print('>>> Initializing Models')

from llava.utils import get_model

args = parse_args()
print('model = ', args.model_path)

tokenizer, model, image_processor, model_name = get_model(args)
print(model.base_model)
model.eval()
print('[Initialization Finished]\n')

# makedirs(exist_ok=True) replaces the exists()/mkdir() pair: it handles
# nested save_dir paths and avoids the race where the directory is created
# between the check and the mkdir.
os.makedirs(args.save_dir, exist_ok=True)

# Load attack target strings (one per line), dropping empty lines.
# Use a context manager so the file handle is closed promptly.
with open('harmful_corpus/harmful_strings.csv') as f:
    lines = f.read().split("\n")
targets = [li for li in lines if len(li) > 0]
print(targets[0])

my_attacker = text_defender.Attacker(args, model, tokenizer, targets, device=model.device)
from llava_utils import prompt_wrapper
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.mm_utils import tokenizer_image_token

# Build the text prompt template with an empty user message; the template
# still contains the '<image>' placeholder token.
text_prompt_template = prompt_wrapper.prepare_text_prompt('')
print(text_prompt_template)

prompt_segs = text_prompt_template.split('<image>') # each '<image>' corresponds to one image
print(prompt_segs)

# Tokenize each text segment; only the first segment gets the BOS token
# (add_special_tokens is True only when i == 0).
seg_tokens = [
tokenizer(
seg, return_tensors="pt", add_special_tokens=i == 0).to(model.device).input_ids
# only add bos to the first seg
for i, seg in enumerate(prompt_segs)
]

embs = [model.model.embed_tokens(seg_t) for seg_t in seg_tokens] # text to embeddings
# Concatenate segment embeddings along the sequence dimension; the total
# sequence length is used as the token offset passed to the attacker.
# NOTE(review): assumes dim 1 is the sequence axis of the embedding
# tensors — consistent with the (batch, seq, dim) layout embed_tokens
# typically produces, but confirm against the model implementation.
mixed_embs = torch.cat(embs, dim=1)
offset = mixed_embs.shape[1]
print(offset)

# Run the attack for args.n_iters iterations with a fixed batch size of 8.
adv_prompt = my_attacker.attack(text_prompt_template=text_prompt_template, offset=offset,
num_iter=args.n_iters, batch_size=8)