# llavaguard/llava_utils/visual_defender.py
import torch
from tqdm import tqdm
import random
from llava_utils import prompt_wrapper
from torchvision.utils import save_image
import copy
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
def normalize(images):
    # Map [0, 1] pixel values into CLIP's normalized input space
    # (these are the standard CLIP image-preprocessing statistics).
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cuda()
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cuda()
    return (images - mean[None, :, None, None]) / std[None, :, None, None]

def denormalize(images):
    # Map CLIP-normalized tensors back to [0, 1] pixel values.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cuda()
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cuda()
    return images * std[None, :, None, None] + mean[None, :, None, None]
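# Defender optimizes an L_inf-bounded additive image perturbation (a "safety
# patch") with a PGD-style loop: each step samples a batch of undesirable
# (negative) target strings and takes a signed-gradient step that increases the
# language-modeling loss on them, so the patched image is less likely to elicit
# those responses. `pos_targets` is stored but not used by `defense_constrained`.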
class Defender:
def __init__(self, args, model, tokenizer, pos_targets, neg_targets, device='cuda:0', is_rtp=False, image_processor=None):
self.args = args
self.model = model
        self.tokenizer = tokenizer
self.device = device
self.is_rtp = is_rtp
self.pos_targets = pos_targets
self.neg_targets = neg_targets
self.num_targets = len(pos_targets)
self.loss_buffer = []
        # Freeze the model and put it in eval mode; only the noise is optimized.
self.model.eval()
self.model.requires_grad_(False)
self.image_processor = image_processor
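    # defense_constrained: projected gradient descent under an L_inf budget.
    # `alpha` is the per-step size and `epsilon` the perturbation bound, both in
    # [0, 1] pixel units; the noise is re-projected every step so that the
    # patched image x + adv_noise stays a valid image in [0, 1].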
    def defense_constrained(self, text_prompt, img, batch_size=4, num_iter=2000, alpha=1/255, epsilon=128/255):
        print('>>> batch_size:', batch_size)

        # Start from uniform noise in [-epsilon, epsilon].
        adv_noise = torch.rand_like(img[0]).cuda() * 2 * epsilon - epsilon
        # img is expected to be CLIP-normalized; work in [0, 1] pixel space.
        x = denormalize(img).clone().cuda()
        # Project the noise so that x + adv_noise remains a valid image.
        adv_noise.data = (adv_noise.data + x.data).clamp(0, 1) - x.data
        adv_noise = adv_noise.cuda()
neg_prompt = prompt_wrapper.Prompt(self.model, self.tokenizer, text_prompts=text_prompt, device=self.device)
adv_noise.requires_grad_(True)
adv_noise.retain_grad()
        for t in tqdm(range(num_iter + 1)):
            neg_batch_targets = random.sample(self.neg_targets, batch_size)

            # Negate the LM loss on the negative targets: descending on
            # target_loss is ascending on the loss of those harmful responses.
            x_adv = x + adv_noise
            x_adv = normalize(x_adv)
            target_loss = -self.attack_loss(neg_prompt, x_adv, neg_batch_targets)
            target_loss.backward()

            # Signed-gradient step, then project back into the L_inf ball and
            # keep x + adv_noise inside the valid image range [0, 1].
            adv_noise.data = (adv_noise.data - alpha * adv_noise.grad.detach().sign()).clamp(-epsilon, epsilon)
            adv_noise.data = (adv_noise.data + x.data).clamp(0, 1) - x.data
            adv_noise.grad.zero_()
            self.model.zero_grad()

            self.loss_buffer.append(target_loss.item())
            print("target_loss: %f" % target_loss.item())

            if t % 20 == 0:
                self.plot_loss()

            if t % 100 == 0:
                safety_patch = adv_noise.detach().cpu().squeeze(0)
                # Uncomment to save intermediate safety patches as images:
                # save_image(safety_patch, '%s/safety_patch_temp_%d.bmp' % (self.args.save_dir, t))

        return safety_patch
def plot_loss(self):
sns.set_theme()
num_iters = len(self.loss_buffer)
x_ticks = list(range(0, num_iters))
# Plot and label the training and validation loss values
plt.plot(x_ticks, self.loss_buffer, label='Target Loss')
# Add in a title and axes labels
plt.title('Loss Plot')
plt.xlabel('Iters')
plt.ylabel('Loss')
# Display the plot
plt.legend(loc='best')
plt.savefig('%s/loss_curve.png' % (self.args.save_dir))
plt.clf()
torch.save(self.loss_buffer, '%s/loss' % (self.args.save_dir))
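    # attack_loss: teacher-forced language-modeling loss of the target strings
    # given the (image, prompt) context. Context positions are labeled -100 so
    # they are ignored by the cross-entropy; sequences are right-padded to a
    # common length and the padding is masked out via attention_mask.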
def attack_loss(self, prompts, images, targets):
context_length = prompts.context_length
context_input_ids = prompts.input_ids
batch_size = len(targets)
images = images.repeat(batch_size, 1, 1, 1)
if len(context_input_ids) == 1:
context_length = context_length * batch_size
context_input_ids = context_input_ids * batch_size
        assert len(context_input_ids) == len(targets), f"Unmatched batch size of prompts and targets {len(context_input_ids)} != {len(targets)}"
tokens = [ torch.as_tensor([item[1:]]).cuda() for item in self.tokenizer(targets).input_ids] # get rid of the default <bos> in targets tokenization.
seq_tokens_length = []
labels = []
input_ids = []
for i, item in enumerate(tokens):
L = item.shape[1] + context_length[i]
seq_tokens_length.append(L)
context_mask = torch.full([1, context_length[i]], -100,
dtype=tokens[0].dtype,
device=tokens[0].device)
labels.append( torch.cat( [context_mask, item], dim=1 ) )
input_ids.append( torch.cat( [context_input_ids[i], item], dim=1 ) )
        # Padding token id (the value does not matter: padded positions are
        # masked out of the attention and labeled -100).
        pad = torch.full([1, 1], 0,
                         dtype=tokens[0].dtype,
                         device=tokens[0].device).cuda()
max_length = max(seq_tokens_length)
attention_mask = []
for i in range(batch_size):
# padding to align the length
num_to_pad = max_length - seq_tokens_length[i]
padding_mask = (
torch.full([1, num_to_pad], -100,
dtype=torch.long,
device=self.device)
)
labels[i] = torch.cat( [labels[i], padding_mask], dim=1 )
input_ids[i] = torch.cat( [input_ids[i],
pad.repeat(1, num_to_pad)], dim=1 )
attention_mask.append( torch.LongTensor( [ [1]* (seq_tokens_length[i]) + [0]*num_to_pad ] ) )
labels = torch.cat( labels, dim=0 ).cuda()
input_ids = torch.cat( input_ids, dim=0 ).cuda()
attention_mask = torch.cat(attention_mask, dim=0).cuda()
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True,
labels=labels,
images=images.half(),
)
loss = outputs.loss
return loss
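
# Minimal usage sketch (hypothetical; `args`, the model, tokenizer, image
# processor, target lists, and `img` as a CLIP-preprocessed image tensor are
# assumed to come from the surrounding pipeline, not from this file):
#
#     defender = Defender(args, model, tokenizer,
#                         pos_targets=pos_targets, neg_targets=neg_targets,
#                         image_processor=image_processor)
#     safety_patch = defender.defense_constrained(text_prompt, img,
#                                                 batch_size=4, num_iter=2000)
#     save_image(safety_patch, '%s/safety_patch.bmp' % args.save_dir)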