# NOTE: removed non-Python paste artifacts ("Spaces:" / "Runtime error") that
# preceded the imports — they were copy/paste residue, not source code.
import argparse | |
import os | |
import random | |
import numpy as np | |
import torch | |
import torch.backends.cudnn as cudnn | |
from PIL import Image | |
from torchvision.utils import save_image | |
from llava_utils import prompt_wrapper, text_defender | |
def parse_args():
    """Build and parse the command-line arguments for the demo attack script.

    Returns:
        argparse.Namespace with the model checkpoint, attack, and output
        settings (model_path, model_base, gpu_id, n_iters, save_dir,
        n_candidates, options).
    """
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument(
        "--model-path",
        type=str,
        default="ckpts/llava_llama_2_13b_chat_freeze",
    )
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument(
        "--gpu-id",
        type=int,
        default=0,
        help="specify the gpu to load the model.",
    )
    parser.add_argument(
        "--n_iters",
        type=int,
        default=50,
        help="specify the number of iterations for attack.",
    )
    parser.add_argument("--save_dir", type=str, default='outputs', help="save directory")
    parser.add_argument("--n_candidates", type=int, default=100, help="n_candidates")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    return parser.parse_args()
def setup_seeds(config):
    """Seed every RNG (python `random`, numpy, torch) for reproducibility.

    The base seed is read from ``config.run_cfg.seed`` and offset by the
    distributed rank so that each process draws a distinct-but-deterministic
    stream.

    Args:
        config: object exposing ``run_cfg.seed`` (an int).
    """
    # BUGFIX: the original called `get_rank()`, which is never imported in
    # this file and would raise NameError. Resolve the rank through
    # torch.distributed when a process group is up; fall back to rank 0.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
    else:
        rank = 0
    seed = config.run_cfg.seed + rank

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Trade autotuned cuDNN kernels for deterministic ones.
    cudnn.benchmark = False
    cudnn.deterministic = True
# ========================================
#             Model Initialization
# ========================================
print('>>> Initializing Models')

from llava.utils import get_model
args = parse_args()

print('model = ', args.model_path)
tokenizer, model, image_processor, model_name = get_model(args)
print(model.base_model)
model.eval()
print('[Initialization Finished]\n')

# Create the output directory. makedirs(exist_ok=True) handles nested paths
# and avoids the check-then-create race of the original
# `if not os.path.exists(...): os.mkdir(...)`.
os.makedirs(args.save_dir, exist_ok=True)

# Load the attack target strings (one per line, blanks dropped); the context
# manager closes the file handle that the original left open.
with open('harmful_corpus/harmful_strings.csv') as f:
    lines = f.read().split("\n")
targets = [li for li in lines if len(li) > 0]
print(targets[0])

my_attacker = text_defender.Attacker(args, model, tokenizer, targets, device=model.device)
# --- Build the text-only prompt prefix and launch the attack -------------
from llava_utils import prompt_wrapper  # NOTE(review): redundant — already imported at the top of the file
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.mm_utils import tokenizer_image_token

# Wrap an empty user message in the model's chat template.
text_prompt_template = prompt_wrapper.prepare_text_prompt('')
print(text_prompt_template)

prompt_segs = text_prompt_template.split('<image>') # each '<image>' placeholder corresponds to one image
print(prompt_segs)

# Tokenize each text segment separately; BOS is added only to the first
# segment so the concatenation matches a single tokenized prompt.
seg_tokens = [
    tokenizer(
        seg, return_tensors="pt", add_special_tokens=i == 0).to(model.device).input_ids
    # only add bos to the first seg
    for i, seg in enumerate(prompt_segs)
]
embs = [model.model.embed_tokens(seg_t) for seg_t in seg_tokens] # text to embeddings
mixed_embs = torch.cat(embs, dim=1)
# `offset`: number of prompt tokens preceding the adversarial suffix,
# i.e. the sequence length of the embedded template.
offset = mixed_embs.shape[1]
print(offset)

# Run the attack; returns the optimized adversarial prompt string.
adv_prompt = my_attacker.attack(text_prompt_template=text_prompt_template, offset=offset,
                                num_iter=args.n_iters, batch_size=8)