File size: 3,058 Bytes
5ca4e86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import argparse
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from torchvision.utils import save_image
from llava_utils import prompt_wrapper, text_defender


def parse_args():
    """Build and parse the command-line arguments for this script.

    Returns:
        argparse.Namespace with the fields: model_path, model_base, gpu_id,
        n_iters, save_dir, n_candidates, options.
    """
    parser = argparse.ArgumentParser(description="Demo")

    # Model selection / placement.
    parser.add_argument("--model-path", type=str,
                        default="ckpts/llava_llama_2_13b_chat_freeze")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--gpu-id", type=int, default=0,
                        help="specify the gpu to load the model.")

    # Attack configuration.
    parser.add_argument("--n_iters", type=int, default=50,
                        help="specify the number of iterations for attack.")
    parser.add_argument("--save_dir", type=str, default='outputs',
                        help="save directory")
    parser.add_argument("--n_candidates", type=int, default=100,
                        help="n_candidates")

    # Free-form config overrides.
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
             "in xxx=yyy format will be merged into config file (deprecate), "
             "change to --cfg-options instead.",
    )

    return parser.parse_args()


def setup_seeds(config):
    """Seed every RNG (python, numpy, torch) for reproducible runs.

    The seed is offset by the distributed rank so that each worker in a
    multi-process job draws a distinct but deterministic random stream.

    Args:
        config: object with a ``run_cfg.seed`` integer attribute.

    Side effects:
        Seeds the global ``random``, ``numpy`` and ``torch`` generators and
        forces cuDNN into deterministic (non-benchmark) mode.
    """
    # BUG FIX: the original called a bare `get_rank()`, which is neither
    # defined nor imported in this file and would raise NameError. Use
    # torch.distributed directly, falling back to rank 0 when no process
    # group is initialized (the single-process case).
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
    else:
        rank = 0
    seed = config.run_cfg.seed + rank

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Trade cuDNN autotuning speed for bitwise-reproducible kernels.
    cudnn.benchmark = False
    cudnn.deterministic = True


# ========================================
#             Model Initialization
# ========================================
# NOTE: everything below runs at import time — this file is a script, not a
# reusable module. There is no `if __name__ == "__main__":` guard.


print('>>> Initializing Models')

from llava.utils import get_model
args = parse_args()

print('model = ', args.model_path)

# Load tokenizer + model + vision preprocessor from the path in --model-path.
tokenizer, model, image_processor, model_name = get_model(args)
print(model.base_model)
model.eval()  # inference mode: disable dropout etc.
print('[Initialization Finished]\n')


# NOTE(review): os.mkdir is not recursive; a nested --save_dir such as
# "a/b" would fail — os.makedirs(..., exist_ok=True) may be intended.
if not os.path.exists(args.save_dir):
    os.mkdir(args.save_dir)

# Load the target strings (one per line), dropping empty lines.
# NOTE(review): the file handle is never closed — consider a `with` block.
lines = open('harmful_corpus/harmful_strings.csv').read().split("\n")
targets = [li for li in lines if len(li)>0]
print(targets[0])

# Attacker that optimizes an adversarial text prompt against `targets`.
my_attacker = text_defender.Attacker(args, model,tokenizer, targets, device=model.device)

# NOTE(review): prompt_wrapper is already imported at the top of the file;
# this re-import is redundant, and the three llava imports below are unused
# in the visible code.
from llava_utils import prompt_wrapper
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.mm_utils import tokenizer_image_token

# Build the system/text prompt template (with an empty user message).
text_prompt_template = prompt_wrapper.prepare_text_prompt('')
print(text_prompt_template)

prompt_segs = text_prompt_template.split('<image>')  # each '<image>' placeholder corresponds to one image
print(prompt_segs)
# Tokenize each text segment separately so image embeddings can later be
# interleaved between them.
seg_tokens = [
     tokenizer(
         seg, return_tensors="pt", add_special_tokens=i == 0).to(model.device).input_ids
     # only add BOS to the first segment
     for i, seg in enumerate(prompt_segs)
 ]
embs = [model.model.embed_tokens(seg_t) for seg_t in seg_tokens] # text to embeddings
mixed_embs = torch.cat(embs, dim=1)
# `offset` = number of prompt tokens preceding the adversarial suffix;
# presumably used by the attacker to locate optimizable positions — TODO confirm.
offset = mixed_embs.shape[1]
print(offset)

# Run the attack; result (adversarial prompt) is assigned but not saved or
# printed here — presumably the Attacker writes outputs to --save_dir itself.
adv_prompt = my_attacker.attack(text_prompt_template=text_prompt_template, offset=offset,
                                    num_iter=args.n_iters, batch_size=8)