"""Driver script: load a (frozen) LLaVA chat model, read harmful target
strings from a corpus CSV, and run a text-based attack via
``text_defender.Attacker`` against the model's prompt template."""

import argparse
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from torchvision.utils import save_image

from llava_utils import prompt_wrapper, text_defender


def parse_args():
    """Parse command-line arguments for the attack run.

    Returns:
        argparse.Namespace with model path/base, GPU id, iteration count,
        output directory, candidate count, and optional config overrides.
    """
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--model-path", type=str, default="ckpts/llava_llama_2_13b_chat_freeze")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument("--n_iters", type=int, default=50, help="specify the number of iterations for attack.")
    parser.add_argument("--save_dir", type=str, default='outputs', help="save directory")
    parser.add_argument("--n_candidates", type=int, default=100, help="n_candidates")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    return parser.parse_args()


def setup_seeds(config):
    """Seed Python/NumPy/Torch RNGs from the run config for reproducibility.

    NOTE(review): ``get_rank`` is not imported or defined anywhere in this
    file, so calling this function raises NameError as-is. It is never
    invoked by this script; confirm the intended import (likely a
    distributed-training rank helper) before using it.
    """
    seed = config.run_cfg.seed + get_rank()

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Deterministic cuDNN kernels (at some speed cost).
    cudnn.benchmark = False
    cudnn.deterministic = True


# ========================================
#            Model Initialization
# ========================================

print('>>> Initializing Models')

from llava.utils import get_model
args = parse_args()

print('model = ', args.model_path)
tokenizer, model, image_processor, model_name = get_model(args)
print(model.base_model)
model.eval()
print('[Initialization Finished]\n')

# makedirs + exist_ok: creates missing parent directories and does not
# fail if the directory already exists (os.mkdir did both).
os.makedirs(args.save_dir, exist_ok=True)

# Load attack target strings, dropping empty lines; `with` closes the
# file handle promptly (the original open(...).read() leaked it).
with open('harmful_corpus/harmful_strings.csv') as f:
    targets = [line for line in f.read().split("\n") if len(line) > 0]
print(targets[0])

my_attacker = text_defender.Attacker(args, model, tokenizer, targets, device=model.device)

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.mm_utils import tokenizer_image_token

text_prompt_template = prompt_wrapper.prepare_text_prompt('')
print(text_prompt_template)

# BUG FIX: str.split('') raises ValueError (empty separator is illegal).
# The original comment says each segment "corresponds to one image", i.e.
# the prompt is split on the image placeholder token, so split on
# DEFAULT_IMAGE_TOKEN instead.
# NOTE(review): confirm DEFAULT_IMAGE_TOKEN matches the placeholder that
# prompt_wrapper.prepare_text_prompt emits.
prompt_segs = text_prompt_template.split(DEFAULT_IMAGE_TOKEN)  # each corresponds to one image
print(prompt_segs)

# Tokenize each text segment; only the first segment gets the BOS token.
seg_tokens = [
    tokenizer(
        seg, return_tensors="pt", add_special_tokens=(i == 0)
    ).to(model.device).input_ids
    for i, seg in enumerate(prompt_segs)
]
embs = [model.model.embed_tokens(seg_t) for seg_t in seg_tokens]  # text to embeddings
mixed_embs = torch.cat(embs, dim=1)

# Number of prompt-template tokens preceding the attack payload.
offset = mixed_embs.shape[1]
print(offset)

adv_prompt = my_attacker.attack(
    text_prompt_template=text_prompt_template,
    offset=offset,
    num_iter=args.n_iters,
    batch_size=8,
)