# NOTE: page residue from the hosting site ("Spaces: Runtime error"); not part of the module.
import torch | |
from torch import nn | |
from copy import deepcopy | |
from .base import Attacker, Empty | |
from torch.cuda import amp | |
from tqdm import tqdm | |
class PGD(Attacker):
    """Iterative signed-gradient (PGD-style) attack bounded by ``eps``.

    Each iteration moves the current images by ``alpha()`` along the sign of
    the loss gradient, clips the total perturbation relative to the original
    images to ``[-eps, eps]``, applies ``img_transform[1]``, clamps the result
    to ``[0, 1]`` and finally applies ``img_transform[0]``.

    The attack can be run in one shot with :meth:`attack`, or step by step by
    calling :meth:`set_data` and then iterating over the instance.

    NOTE(review): ``self.forward``, ``self.eps``, ``self.alpha`` and
    ``self.iters`` are not defined in this class; they are presumably provided
    by ``Attacker`` / ``Attacker.set_para`` -- confirm in ``.base``.
    """

    def __init__(self, model, img_transform=(lambda x:x, lambda x:x), use_amp=False):
        """Create the attacker.

        Args:
            model: Module under attack. Its output must expose ``.logits``,
                and it must provide ``eval()``, ``train()`` and ``zero_grad()``.
            img_transform: Pair of callables. ``img_transform[1]`` is applied
                before clamping to ``[0, 1]`` and ``img_transform[0]`` after it
                (presumably normalise / de-normalise -- confirm with callers).
            use_amp: Enable CUDA autocast and gradient scaling.
        """
        super().__init__(model, img_transform)
        self.use_amp = use_amp
        self.call_back = None
        self.img_loader = None  # stored by set_img_loader; not read in this class
        self.img_hook = None
        self.scaler = amp.GradScaler(enabled=use_amp)

    def set_para(self, eps=8, alpha=lambda:8, iters=20, **kwargs):
        """Forward the attack hyper-parameters to ``Attacker.set_para``.

        Args:
            eps: Bound on the absolute perturbation (used in ``step``).
            alpha: Zero-argument callable returning the step size; it is
                invoked as ``self.alpha()`` on every step.
            iters: Number of attack iterations.
            **kwargs: Passed through unchanged to the base class.
        """
        super().set_para(eps=eps, alpha=alpha, iters=iters, **kwargs)

    def set_call_back(self, call_back):
        """Register ``call_back(ori_images, adv_images, labels)``, run after each iteration of ``attack``."""
        self.call_back = call_back

    def set_img_loader(self, img_loader):
        """Store ``img_loader`` on the instance."""
        self.img_loader = img_loader

    @staticmethod
    def _snapshot(images):
        """Return an independent copy of ``images`` that is safe to keep.

        ``deepcopy`` raises for non-leaf tensors that require grad, so tensors
        are copied with ``detach().clone()``; anything else falls back to
        ``deepcopy`` as before.
        """
        if torch.is_tensor(images):
            return images.detach().clone()
        return deepcopy(images)

    def _sync_context(self):
        """Return ``model.no_sync()`` for a DDP-wrapped model, else a no-op context."""
        if isinstance(self.model, nn.parallel.DistributedDataParallel):
            return self.model.no_sync()
        return Empty()

    def _run_iteration(self, images, labels):
        """Run one attack iteration and restore the model's previous mode.

        The model is switched to eval mode for the step. Afterwards the
        parameter gradients produced by the attack are cleared and the model is
        put back into whichever mode it was in before, instead of being forced
        into training mode.
        """
        # Fall back to the historical behaviour (train afterwards) when the
        # model does not expose ``training``.
        was_training = getattr(self.model, "training", True)
        with self._sync_context():
            self.model.eval()
            try:
                # ``forward`` is called with the attacker passed explicitly,
                # exactly as the original code did.
                return self.forward(self, images, labels)
            finally:
                self.model.zero_grad()
                if was_training:
                    self.model.train()
                else:
                    self.model.eval()

    def step(self, images, labels, loss):
        """Perform a single signed-gradient update.

        Args:
            images: Current adversarial images (tensor). Not modified.
            labels: Targets passed to ``loss``.
            loss: Callable ``loss(logits, labels)`` returning a scalar tensor.

        Returns:
            The updated images, detached from the autograd graph before
            ``img_transform[0]`` is applied.
        """
        # Differentiate w.r.t. a fresh leaf so the caller's tensor is neither
        # flagged as requiring grad nor polluted with an accumulated ``.grad``.
        adv_input = images.detach().requires_grad_(True)

        with amp.autocast(enabled=self.use_amp):
            outputs = self.model(adv_input).logits
            cost = loss(outputs, labels)

        self.model.zero_grad()
        # Backward runs outside autocast, as recommended for AMP.
        self.scaler.scale(cost).backward()

        stepped = (adv_input + self.alpha() * adv_input.grad.sign()).detach()
        # Project the total perturbation back into the eps-ball around the originals.
        eta = torch.clamp(stepped - self.ori_images, min=-self.eps, max=self.eps)
        bounded = torch.clamp(self.img_transform[1](self.ori_images + eta), min=0, max=1).detach()
        return self.img_transform[0](bounded)

    def set_data(self, images, labels):
        """Load a batch for step-by-step (iterator) use."""
        self.ori_images = self._snapshot(images)
        self.images = images
        self.labels = labels

    def __iter__(self):
        """Reset the step counter and return the attacker as its own iterator."""
        self.atk_step = 0
        return self

    def __next__(self):
        """Run one attack iteration on the data given to ``set_data``.

        Returns:
            Tuple ``(ori_images, adv_images, labels)`` with ``adv_images`` detached.

        Raises:
            StopIteration: Once more than ``self.iters`` steps were requested.
        """
        self.atk_step += 1
        if self.atk_step > self.iters:
            raise StopIteration

        self.images = self._run_iteration(self.images, self.labels)
        return self.ori_images, self.images.detach(), self.labels

    def attack(self, images, labels):
        """Run the full attack for ``self.iters`` iterations.

        After every iteration the optional ``call_back`` is invoked with
        ``(ori_images, adv_images, labels)``, and the optional ``img_hook`` may
        replace the working images with ``img_hook(ori_images, adv_images)``.

        Returns:
            The final adversarial images.
        """
        self.ori_images = self._snapshot(images)

        for _ in tqdm(range(self.iters)):
            images = self._run_iteration(images, labels)

            if self.call_back:
                self.call_back(self.ori_images, images.detach(), labels)

            if self.img_hook is not None:
                images = self.img_hook(self.ori_images, images.detach())
        return images