class Attacker:
    """Base class for attack implementations operating on a model.

    The class only stores configuration (model, image transforms, loss and a
    pluggable ``forward`` callable). ``step`` and ``attack`` are no-op hooks
    here; presumably subclasses override them -- confirm against the concrete
    attackers in the project.
    """

    def __init__(self, model, img_transform=(lambda x: x, lambda x: x)):
        """Store the model and the image transforms.

        Args:
            model: The model under attack. Must be a PyTorch model.
            img_transform: A pair of callables, stored as-is and not used by
                this base class. Defaults to two identity functions.
                NOTE(review): presumably (forward transform, inverse
                transform) -- confirm against the subclasses that use it.
        """
        self.model = model  # must be a PyTorch model
        # NOTE: putting the model in eval mode and freezing its parameters is
        # disabled; this constructor leaves the model untouched. The disabled
        # snippet is kept here for reference:
        #     self.model.eval()
        #     for k, v in self.model.named_parameters():
        #         v.requires_grad = False
        self.img_transform = img_transform
        # ``forward`` is a plain callable stored on the instance, so it is NOT
        # bound: callers pass the attacker explicitly, e.g.
        # ``self.forward(self, images, labels)``. ``attacker.loss`` is looked
        # up at call time, so ``set_loss`` must have been called (or ``loss``
        # otherwise defined) before the default forward is used.
        self.forward = lambda attacker, images, labels: attacker.step(
            images, labels, attacker.loss
        )

    def set_para(self, **kwargs):
        """Set every keyword argument as an attribute on this instance."""
        for name, value in kwargs.items():
            setattr(self, name, value)

    def set_forward(self, forward):
        """Replace the ``forward`` callable.

        Args:
            forward: Callable invoked as ``forward(attacker, images, labels)``
                to match the default installed by ``__init__``.
        """
        self.forward = forward

    def step(self, images, labels, loss):
        """Hook for a single attack step; does nothing and returns ``None``."""
        pass

    def set_loss(self, loss):
        """Store ``loss`` on the instance; the default ``forward`` passes it to ``step``."""
        self.loss = loss

    def attack(self, images, labels):
        """Hook for running the attack; does nothing and returns ``None``."""
        pass


class Empty:
    """Context manager that does nothing on entry or exit.

    ``__enter__`` returns ``None`` and ``__exit__`` returns ``None``, so
    exceptions raised inside the ``with`` body propagate unchanged.
    """

    def __enter__(self):
        """Enter the context; nothing is acquired and ``None`` is returned."""
        pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the context; returns ``None`` so exceptions are not suppressed."""
        pass