# DeIT-Dreamer / app.py
# Commit 03f7bd7 by SoggyKiwi: "lmao lets goooo"
# (Page metadata from the Hugging Face file view: raw | history | blame |
#  "No virus" | 1.84 kB)
import functools

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import ViTImageProcessor, ViTForImageClassification
_MODEL_ID = 'google/vit-large-patch32-384'


@functools.lru_cache(maxsize=2)
def _load_model(device_type):
    """Load the ViT processor and classifier once per device type and cache them.

    Reading a ViT-Large checkpoint is slow, so it is done once instead of on
    every request. The model's parameters are frozen because only the input
    pixels are optimised; this stops backward() from computing and
    accumulating parameter gradients.

    Args:
        device_type: torch device string, "cuda" or "cpu".

    Returns:
        (processor, model) with the model on ``device_type`` in eval mode.
    """
    processor = ViTImageProcessor.from_pretrained(_MODEL_ID)
    model = ViTForImageClassification.from_pretrained(_MODEL_ID)
    model.to(device_type)
    model.eval()
    model.requires_grad_(False)
    return processor, model


def _pixels_to_uint8(pixel_values):
    """Convert a (1, 3, H, W) tensor with values in [-1, 1] to an (H, W, 3) uint8 array.

    The 127.5 affine map undoes the processor's normalisation. It assumes
    mean = std = 0.5 (the same assumption behind the [-1, 1] clamp);
    NOTE(review): confirm against the checkpoint's preprocessor config.
    """
    image = pixel_values.detach().squeeze(0).permute(1, 2, 0).cpu()
    return (127.5 + image * 127.5).numpy().astype(np.uint8)


def process_image(input_image, learning_rate, iterations):
    """Amplify a ViT's response to an image by gradient ascent on its pixels.

    Each step runs the ViT encoder and sums its final hidden states. It then
    moves the pixels along the gradient of that sum and clamps them back to
    the normalised range [-1, 1].

    Args:
        input_image: PIL image supplied by ``gr.Image(type="pil")``.
        learning_rate: step size for each ascent step (coerced to float).
        iterations: number of ascent steps (coerced to int; <= 0 means none).

    Returns:
        (H, W, 3) uint8 numpy array at the processor's output resolution
        (presumably 384x384 for this checkpoint).

    Raises:
        gr.Error: if no image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    feature_extractor, model = _load_model(device.type)

    image = input_image.convert('RGB')
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    # Tensor.to() is not in-place, so rebind. Require grad on the on-device
    # tensor so it is the leaf whose .grad gets populated.
    pixel_values = pixel_values.to(device).requires_grad_(True)

    # Gradio hands plain Python values to fn (not components); gr.Number
    # yields floats, so coerce explicitly.
    step = float(learning_rate)
    for _ in range(int(iterations)):
        pixel_values.grad = None
        final_activations = model.vit(pixel_values).last_hidden_state
        final_activations.sum().backward()
        with torch.no_grad():
            # Gradient *ascent*: push pixels toward larger summed activation.
            pixel_values += step * pixel_values.grad
            pixel_values.clamp_(-1, 1)

    return _pixels_to_uint8(pixel_values)
# Gradio UI: an image plus step size and step count in, the processed image out.
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Number(value=0.01, label="Learning Rate"),
        # The iteration count is used as a loop bound; precision=0 makes
        # Gradio round it and deliver it as an int rather than a float.
        gr.Number(value=1, label="Iterations", precision=0),
    ],
    outputs=gr.Image(type="numpy", label="Processed Image"),
)
iface.launch()