SoggyKiwi committed
Commit f93986b
1 Parent(s): 8c65b05

remove tv loss, using BCE loss now

Files changed (1)
app.py +14 -17
app.py CHANGED
@@ -1,5 +1,6 @@
 import gradio as gr
 import torch
+from torch.nn import BCEWithLogitsLoss
 import numpy as np
 from transformers import ViTImageProcessor, ViTForImageClassification
 from PIL import Image
@@ -16,12 +17,7 @@ def get_encoder_activations(x):
     final_activations = encoder_output.last_hidden_state[:,0,:]
     return final_activations
 
-def total_variation_loss(img):
-    pixel_dif1 = img[:, :, 1:, :] - img[:, :, :-1, :]
-    pixel_dif2 = img[:, :, :, 1:] - img[:, :, :, :-1]
-    return (torch.sum(torch.abs(pixel_dif1)) + torch.sum(torch.abs(pixel_dif2)))
-
-def process_image(input_image, learning_rate, tv_weight, iterations, n_targets, seed):
+def process_image(input_image, learning_rate, iterations, n_targets, seed):
     if input_image is None:
         return None
 
@@ -32,20 +28,22 @@ def process_image(input_image, learning_rate, tv_weight, iterations, n_targets,
 
 
     torch.manual_seed(int(seed))
-    random_indices = torch.randperm(1000)[:int(n_targets)].to(pixel_values.device)
+    random_one_logits = torch.zeros(1000)
+    random_one_logits[torch.randperm(1000)[:n_targets]] = 1
+    random_one_logits = random_one_logits.to(pixel_values.device)
 
     for iteration in range(int(iterations)):
         model.zero_grad()
         if pixel_values.grad is not None:
             pixel_values.grad.data.zero_()
 
-        final_activations = get_encoder_activations(pixel_values)
-        logits = model.classifier(final_activations[0])
+        final_activations = get_encoder_activations(pixel_values.to('cuda'))
+
+        logits = model.classifier(final_activations[0]).to(pixel_values.device)
+
+        original_loss = BCEWithLogitsLoss(reduction='sum')(logits,random_one_logits)
 
-        original_loss = logits[random_indices].sum()
-        tv_loss = total_variation_loss(pixel_values)
-        total_loss = original_loss - tv_weight * tv_loss
-        total_loss.backward()
+        original_loss.backward()
 
         with torch.no_grad():
             pixel_values.data += learning_rate * pixel_values.grad.data
@@ -60,11 +58,10 @@ iface = gr.Interface(
     fn=process_image,
     inputs=[
         gr.Image(type="pil"),
-        gr.Number(value=16.0, minimum=0, label="Learning Rate"),
-        gr.Number(value=0.0001, label="Total Variation Loss"),
-        gr.Number(value=4, minimum=1, label="Iterations"),
+        gr.Number(value=1.0, minimum=0, label="Learning Rate"),
+        gr.Number(value=2, minimum=1, label="Iterations"),
         gr.Number(value=420, minimum=0, label="Seed"),
-        gr.Number(value=500, minimum=1, maximum=1000, label="Number of Random Target Class Activations to Maximise"),
+        gr.Number(value=250, minimum=1, maximum=1000, label="Number of Random Target Class Activations to Maximise"),
     ],
     outputs=[gr.Image(type="numpy", label="ViT-Dreamed Image")]
 )
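
For reference, the new objective scores the classifier logits with binary cross-entropy against a random multi-hot target vector (1 for each sampled class, 0 elsewhere) instead of summing a random subset of logits and subtracting a total-variation penalty. The sketch below isolates that objective with the commit's new defaults (250 targets, seed 420, learning rate 1.0, 2 iterations); stand_in_head and the 3x32x32 input are hypothetical placeholders for the Space's ViT encoder and classifier head, used only so the snippet runs without downloading model weights.

import torch
from torch.nn import BCEWithLogitsLoss

torch.manual_seed(420)

n_classes = 1000
n_targets = 250
learning_rate = 1.0
iterations = 2

# Multi-hot target vector, mirroring random_one_logits in the commit:
# 1.0 for a random subset of classes, 0.0 everywhere else.
targets = torch.zeros(n_classes)
targets[torch.randperm(n_classes)[:n_targets]] = 1.0

# Hypothetical stand-in for model.classifier(get_encoder_activations(x));
# the real Space runs 224x224 pixel_values through a ViT encoder and its head.
stand_in_head = torch.nn.Linear(3 * 32 * 32, n_classes)

pixel_values = torch.rand(1, 3, 32, 32, requires_grad=True)

for _ in range(iterations):
    if pixel_values.grad is not None:
        pixel_values.grad.zero_()

    logits = stand_in_head(pixel_values.flatten(start_dim=1))[0]  # shape (1000,)

    # BCE-with-logits against the multi-hot targets, summed over classes.
    loss = BCEWithLogitsLoss(reduction='sum')(logits, targets)
    loss.backward()

    # Same pixel update rule as the committed loop.
    with torch.no_grad():
        pixel_values += learning_rate * pixel_values.grad

Note that the commit keeps the += pixel update from the previous objective, so each iteration still moves the pixels up the gradient of the loss being computed.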