File size: 2,502 Bytes
b7ebb88
03f7bd7
f93986b
03f7bd7
 
 
b7ebb88
37ebd45
 
 
 
 
 
03f7bd7
f93fa3d
 
 
 
 
f93986b
07b1c90
353541c
37ebd45
03f7bd7
 
37ebd45
03f7bd7
 
a4244e1
 
f93986b
3c13f2b
f93986b
a4244e1
5c39195
03f7bd7
 
 
 
fa12e38
f93986b
fa12e38
f93986b
 
f93fa3d
f93986b
03f7bd7
 
37ebd45
03f7bd7
 
 
 
 
 
b7ebb88
 
 
68fa56c
03f7bd7
f93986b
 
 
b222813
68fa56c
93f9b8b
b7ebb88
 
37ebd45
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import gradio as gr
import torch
from torch.nn import BCEWithLogitsLoss
import numpy as np
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image

# Load model and feature extractor outside the function
# so the (large) ViT checkpoint is downloaded/initialised once at startup,
# not on every Gradio request.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-large-patch32-384')
model = ViTForImageClassification.from_pretrained('google/vit-large-patch32-384')
model.to(device)
# eval() disables dropout etc.; we optimise the INPUT pixels, not the weights.
model.eval()

def get_encoder_activations(x):
    """Return the [CLS]-token hidden state from the ViT encoder.

    x: batch of preprocessed pixel tensors.
    Returns a tensor of shape (batch, hidden_dim).
    """
    # Run only the transformer encoder, bypassing the classification head.
    hidden_states = model.vit(x).last_hidden_state
    # Sequence position 0 holds the [CLS] token embedding.
    return hidden_states[:, 0, :]

def process_image(input_image, learning_rate, iterations, n_targets, seed):
    """Gradient-ascent "deep dream" on an input image.

    Repeatedly nudges the input pixels along the gradient that maximises
    the BCE-with-logits score of `n_targets` randomly chosen ImageNet
    classes, producing a hallucinated version of the input.

    Args:
        input_image: PIL image to dream on, or None.
        learning_rate: step size for each pixel update.
        iterations: number of gradient-ascent steps.
        n_targets: how many of the 1000 class logits to maximise.
        seed: RNG seed controlling which target classes are picked.

    Returns:
        The dreamed image as a uint8 HxWxC numpy array, or None when
        no input image was provided.
    """
    if input_image is None:
        return None

    image = input_image.convert('RGB')
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    pixel_values.requires_grad_(True)

    # Build a multi-hot target vector: n_targets random classes set to 1.
    torch.manual_seed(int(seed))
    random_one_logits = torch.zeros(1000, device=device)
    random_one_logits[torch.randperm(1000)[:int(n_targets)]] = 1

    # The loss module is stateless — construct it once, not per iteration.
    loss_fn = BCEWithLogitsLoss(reduction='sum')

    for _ in range(int(iterations)):
        model.zero_grad()
        if pixel_values.grad is not None:
            pixel_values.grad.data.zero_()

        # pixel_values already lives on `device`; no per-step transfer needed.
        final_activations = get_encoder_activations(pixel_values)

        # Classify the [CLS] embedding of the single image in the batch;
        # the model output is already on `device`.
        logits = model.classifier(final_activations[0])

        loss = loss_fn(logits, random_one_logits)
        loss.backward()

        # Gradient ASCENT: step towards higher target-class activation,
        # then keep pixels inside the processor's normalised [-1, 1] range.
        with torch.no_grad():
            pixel_values += learning_rate * pixel_values.grad
            pixel_values.clamp_(-1, 1)

    # Undo the [-1, 1] normalisation back to [0, 255] and reorder CHW -> HWC.
    dreamed = 127.5 + pixel_values.squeeze().permute(1, 2, 0).detach().cpu() * 127.5
    return dreamed.numpy().astype(np.uint8)

# Wire the dream function into a simple Gradio UI. Numeric inputs are
# passed to process_image in the order listed here.
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil"), 
        gr.Number(value=1.0, minimum=0, label="Learning Rate"),
        gr.Number(value=2, minimum=1, label="Iterations"),
        gr.Number(value=250, minimum=1, maximum=1000, label="Number of Random Target Class Activations to Maximise"),
        gr.Number(value=420, minimum=0, label="Seed"),
    ],
    outputs=[gr.Image(type="numpy", label="Dreamed Image")]
)

# Blocks here serving HTTP until interrupted.
iface.launch()