import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

MODEL_NAME = "google/vit-large-patch32-384"

# Load the model and image processor once, outside the request handler, so
# every call to process_image reuses the same objects.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model = ViTForImageClassification.from_pretrained(MODEL_NAME)
model.to(device)
model.eval()


def get_encoder_activations(x):
    """Return the first-token slice of the ViT encoder's last hidden state.

    Args:
        x: Pixel tensor passed unchanged to ``model.vit``.

    Returns:
        ``last_hidden_state[:, 0, :]`` of the encoder output, i.e. one vector
        per batch element (presumably the [CLS] token -- confirm against the
        model documentation).
    """
    encoder_output = model.vit(x)
    return encoder_output.last_hidden_state[:, 0, :]


def total_variation_loss(img):
    """Return the summed absolute difference between neighbouring elements.

    Args:
        img: 4-D tensor; differences are taken along dims 2 and 3.

    Returns:
        Scalar tensor: sum of ``|img[..., i+1, :] - img[..., i, :]|`` plus the
        sum of ``|img[..., :, j+1] - img[..., :, j]|``.
    """
    diff_dim2 = img[:, :, 1:, :] - img[:, :, :-1, :]
    diff_dim3 = img[:, :, :, 1:] - img[:, :, :, :-1]
    return torch.sum(torch.abs(diff_dim2)) + torch.sum(torch.abs(diff_dim3))


def process_image(input_image, learning_rate, tv_weight, iterations, n_targets, seed):
    """Run gradient ascent on an image to raise randomly chosen class logits.

    Args:
        input_image: PIL image from the UI, or ``None`` when nothing was
            uploaded.
        learning_rate: Step size applied to the pixel gradient.
        tv_weight: Weight of the total-variation penalty subtracted from the
            objective.
        iterations: Number of ascent steps (converted with ``int``).
        n_targets: How many randomly selected class logits to sum into the
            objective (converted with ``int``).
        seed: Seed for the random selection of target classes (converted
            with ``int``).

    Returns:
        ``None`` if ``input_image`` is ``None``; otherwise a ``uint8`` NumPy
        array in height x width x channel layout.
    """
    if input_image is None:
        return None

    image = input_image.convert("RGB")
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    pixel_values.requires_grad_(True)

    # Seed before drawing the permutation so the same seed always selects the
    # same target classes.
    torch.manual_seed(int(seed))
    num_classes = model.config.num_labels
    target_indices = torch.randperm(num_classes)[: int(n_targets)].to(pixel_values.device)

    for _ in range(int(iterations)):
        model.zero_grad()
        if pixel_values.grad is not None:
            pixel_values.grad.zero_()

        final_activations = get_encoder_activations(pixel_values)
        logits = model.classifier(final_activations[0])
        target_score = logits[target_indices].sum()
        tv_loss = total_variation_loss(pixel_values)
        # Maximise the target logits while penalising pixel-level noise.
        objective = target_score - tv_weight * tv_loss
        objective.backward()

        with torch.no_grad():
            # Gradient *ascent*: step along the gradient, then keep the
            # pixels inside the [-1, 1] range used for the conversion below.
            pixel_values.add_(learning_rate * pixel_values.grad)
            pixel_values.clamp_(-1, 1)

    # Drop the batch dimension explicitly (rather than squeeze()), move the
    # channel axis last, and map [-1, 1] to [0, 255].
    image_tensor = pixel_values.detach()[0].permute(1, 2, 0).cpu()
    return (127.5 + image_tensor * 127.5).numpy().astype(np.uint8)


iface = gr.Interface(
    fn=process_image,
    # Gradio hands these values to `fn` positionally, so the order here must
    # match process_image's signature:
    # (input_image, learning_rate, tv_weight, iterations, n_targets, seed).
    inputs=[
        gr.Image(type="pil"),
        gr.Number(value=16.0, minimum=0, label="Learning Rate"),
        gr.Number(value=0.0001, label="Total Variation Loss"),
        gr.Number(value=4, minimum=1, label="Iterations"),
        gr.Number(
            value=500,
            minimum=1,
            maximum=1000,
            label="Number of Random Target Class Activations to Maximise",
        ),
        gr.Number(value=420, minimum=0, label="Seed"),
    ],
    outputs=[gr.Image(type="numpy", label="ViT-Dreamed Image")],
)

iface.launch()