dzy7e
init
49d1787
import torch
from torch import nn
from copy import deepcopy
from .base import Attacker, Empty
from torch.cuda import amp
from tqdm import tqdm
class PGD(Attacker):
    """Projected Gradient Descent (PGD) adversarial attacker.

    Each step moves the images by ``alpha()`` times the sign of the input
    gradient of the loss, clamps the per-element perturbation with respect to
    the original images to ``[-eps, eps]`` (an L-infinity projection), and
    clamps the result to ``[0, 1]`` in the space produced by
    ``img_transform[1]`` before mapping it back with ``img_transform[0]``.

    The attack can be driven in two ways:

    * ``attack(images, labels)`` runs all ``iters`` steps and returns the final
      adversarial images.
    * ``set_data(images, labels)`` followed by iterating over the attacker
      yields ``(ori_images, adv_images, labels)`` after every step.

    One step is executed through ``self.forward(self, images, labels)``, which
    is configured outside this class (presumably wrapping :meth:`step` with a
    loss — TODO confirm in ``Attacker``).
    """

    def __init__(self, model, img_transform=(lambda x: x, lambda x: x), use_amp=False):
        """Create the attacker.

        Args:
            model: Network under attack; :meth:`step` reads ``model(x).logits``.
                May be wrapped in ``nn.parallel.DistributedDataParallel``.
            img_transform: Pair of callables. ``img_transform[1]`` is applied
                before the ``[0, 1]`` clamp and ``img_transform[0]`` after it
                (presumably denormalize / normalize). Identity by default.
            use_amp: Run the forward pass under CUDA autocast and scale the
                loss with a ``GradScaler``.
        """
        super().__init__(model, img_transform)
        self.use_amp = use_amp
        # Optional fn(ori_images, adv_images, labels), called after each attack() step.
        self.call_back = None
        # Stored by set_img_loader(); not read anywhere in this class.
        self.img_loader = None
        # Optional fn(ori_images, adv_images) -> images, applied after each attack() step.
        self.img_hook = None
        self.scaler = amp.GradScaler(enabled=use_amp)

    def set_para(self, eps=8, alpha=lambda: 8, iters=20, **kwargs):
        """Set attack hyper-parameters via ``Attacker.set_para``.

        Args:
            eps: Per-element perturbation bound. Units are whatever
                ``Attacker.set_para`` stores (the default of 8 suggests a
                0-255 scale — TODO confirm).
            alpha: Zero-argument callable returning the step size; it is
                called once per step.
            iters: Number of attack steps.
            **kwargs: Forwarded unchanged to ``Attacker.set_para``.
        """
        super().set_para(eps=eps, alpha=alpha, iters=iters, **kwargs)

    def set_call_back(self, call_back):
        """Register ``call_back(ori_images, adv_images, labels)`` for :meth:`attack`."""
        self.call_back = call_back

    def set_img_loader(self, img_loader):
        """Store an image loader (not used by this class itself)."""
        self.img_loader = img_loader

    def step(self, images, labels, loss):
        """Perform one signed-gradient ascent step and project the result.

        Requires ``self.ori_images`` to be set (by :meth:`set_data` or
        :meth:`attack`).

        Args:
            images: Current adversarial images (model-input space).
            labels: Targets passed to ``loss``.
            loss: Callable ``loss(logits, labels)`` returning a scalar; the
                step moves the images in the direction that increases it.

        Returns:
            The updated adversarial images. They are detached before
            ``img_transform[0]`` is applied.
        """
        # Work on a fresh leaf that shares storage with ``images``. The caller's
        # tensor is never flagged requires_grad or given a ``.grad``, and a
        # stale ``.grad`` can never be accumulated into.
        images = images.detach().requires_grad_(True)
        with amp.autocast(enabled=self.use_amp):
            outputs = self.model(images).logits
            cost = loss(outputs, labels)
        self.model.zero_grad()
        # Backward runs outside autocast, as the AMP docs recommend. Scaling
        # does not change the gradient's sign, but it keeps small fp16
        # gradients from flushing to zero.
        self.scaler.scale(cost).backward()
        adv_images = (images + self.alpha() * images.grad.sign()).detach_()
        eta = torch.clamp(adv_images - self.ori_images, min=-self.eps, max=self.eps)
        pixels = torch.clamp(self.img_transform[1](self.ori_images + eta), min=0, max=1).detach_()
        return self.img_transform[0](pixels)

    def set_data(self, images, labels):
        """Load a batch for step-wise iteration via ``iter(attacker)``."""
        # A detached copy is the fixed centre of the eps-ball. Unlike deepcopy,
        # it also works for non-leaf tensors and carries no autograd state.
        self.ori_images = images.detach().clone()
        self.images = images
        self.labels = labels

    def __iter__(self):
        """Restart the step counter; ``self.images`` keeps its current value."""
        self.atk_step = 0
        return self

    def __next__(self):
        """Run one attack step on the batch given to :meth:`set_data`.

        Returns:
            ``(ori_images, adv_images, labels)`` with ``adv_images`` detached.

        Raises:
            StopIteration: once ``iters`` steps have been taken.
        """
        self.atk_step += 1
        if self.atk_step > self.iters:
            raise StopIteration
        self.images = self._run_step(self.images, self.labels)
        return self.ori_images, self.images.detach(), self.labels

    def attack(self, images, labels):
        """Run the full ``iters``-step attack on a batch.

        After every step, ``call_back(ori_images, adv_images, labels)`` is
        invoked if set. If ``img_hook`` is set, its return value replaces the
        current adversarial images.

        Args:
            images: Clean input images (model-input space).
            labels: Targets for the loss.

        Returns:
            The adversarial images after the last step (and hook).
        """
        self.ori_images = images.detach().clone()
        for _ in tqdm(range(self.iters)):
            images = self._run_step(images, labels)
            if self.call_back is not None:
                self.call_back(self.ori_images, images.detach(), labels)
            if self.img_hook is not None:
                images = self.img_hook(self.ori_images, images.detach())
        return images

    def _sync_context(self):
        """Return a context that disables DDP gradient sync during attack backward passes."""
        if isinstance(self.model, nn.parallel.DistributedDataParallel):
            return self.model.no_sync()
        return Empty()

    def _run_step(self, images, labels):
        """Run ``self.forward`` once with the model in eval mode, then restore its mode."""
        was_training = self.model.training
        with self._sync_context():
            self.model.eval()
            try:
                images = self.forward(self, images, labels)
            finally:
                # Drop the parameter gradients left by the attack's backward
                # pass. Return the model to the mode the caller left it in,
                # instead of forcing train mode (e.g. during evaluation).
                self.model.zero_grad()
                self.model.train(was_training)
        return images