SoggyKiwi committed on
Commit
03f7bd7
1 Parent(s): 68fa56c

lmao lets goooo

Browse files
Files changed (2) hide show
  1. app.py +40 -5
  2. requirements.txt +5 -1
app.py CHANGED
@@ -1,18 +1,53 @@
1
  import gradio as gr
 
 
 
 
2
 
3
def process_image(input_image, learning_rate, iterations):
    """Identity placeholder: return the input image untouched.

    `learning_rate` and `iterations` are accepted for interface
    compatibility but are not used yet.
    """
    # TODO: replace with the real image-processing pipeline.
    return input_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
# UI wiring: image in, image out, plus two numeric hyper-parameter controls.
input_widgets = [
    gr.Image(),
    gr.Number(value=0.01, label="Learning Rate"),
    gr.Number(value=1, label="Iterations"),
]
iface = gr.Interface(fn=process_image, inputs=input_widgets, outputs=gr.Image())

iface.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from transformers import ViTImageProcessor, ViTForImageClassification
5
+ from PIL import Image
6
 
7
def process_image(input_image, learning_rate, iterations):
    """Gradient-ascent ("deep dream"-style) visualization on a ViT encoder.

    Repeatedly nudges the input pixels in the direction that maximizes the
    sum of the ViT encoder's final hidden activations.

    Args:
        input_image: PIL image from the Gradio `gr.Image(type="pil")` input.
        learning_rate: step size for the pixel-space gradient ascent (number).
        iterations: number of ascent steps (number).

    Returns:
        HxWx3 uint8 numpy array suitable for `gr.Image(type="numpy")`.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    feature_extractor = ViTImageProcessor.from_pretrained('google/vit-large-patch32-384')
    model = ViTForImageClassification.from_pretrained('google/vit-large-patch32-384')
    model.to(device)
    model.eval()

    def get_encoder_activations(x):
        # Sum of the encoder's last hidden state is the ascent objective.
        encoder_output = model.vit(x)
        return encoder_output.last_hidden_state

    image = input_image.convert('RGB')
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    # BUG FIX: Tensor.to() is not in-place — the original discarded the result,
    # leaving the tensor on CPU even when a GPU was selected.
    pixel_values = pixel_values.to(device)
    pixel_values.requires_grad_(True)

    # BUG FIX: Gradio passes raw numbers to fn, not component objects, so the
    # original `iterations.value` / `learning_rate.value` raised AttributeError.
    steps = int(iterations)
    step_size = float(learning_rate)

    for _ in range(steps):
        model.zero_grad()
        if pixel_values.grad is not None:
            pixel_values.grad.data.zero_()

        # BUG FIX: use the selected `device`, not a hard-coded 'cuda' —
        # the original crashed on CPU-only hosts.
        final_activations = get_encoder_activations(pixel_values)
        target_sum = final_activations.sum()
        target_sum.backward()

        with torch.no_grad():
            # Gradient ASCENT on the pixels, clamped to the model's
            # normalized input range [-1, 1].
            pixel_values.data += step_size * pixel_values.grad.data
            pixel_values.data = torch.clamp(pixel_values.data, -1, 1)

    # Map normalized [-1, 1] CHW back to uint8 [0, 255] HWC for display.
    updated = 127.5 + pixel_values.squeeze().permute(1, 2, 0).detach().cpu() * 127.5
    return updated.numpy().astype(np.uint8)
42
 
43
# Gradio front-end: PIL image in, processed uint8 numpy image out.
controls = [
    gr.Image(type="pil"),
    gr.Number(value=0.01, label="Learning Rate"),
    gr.Number(value=1, label="Iterations"),
]
iface = gr.Interface(
    fn=process_image,
    inputs=controls,
    outputs=gr.Image(type="numpy", label="Processed Image"),
)

iface.launch()
requirements.txt CHANGED
@@ -1 +1,5 @@
1
- gradio
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ numpy
4
+ transformers
5
+ Pillow