'''Image Completion Demo (ImageGPT)
- Paper: "Generative Pretraining from Pixels" (Chen et al., ICML 2020): https://openai.com/blog/image-gpt
- Code: https://huggingface.co/spaces/nielsr/imagegpt-completion
---
- 2021-12-10 first created
- examples changed
'''
from PIL import Image
import matplotlib.pyplot as plt
import os
import numpy as np
import requests
from glob import glob
import gradio as gr
from loguru import logger
import torch
from transformers import ImageGPTFeatureExtractor, ImageGPTForCausalImageModeling
# ========== Settings ==========
EXAMPLE_DIR = 'examples'
examples = sorted(glob(os.path.join(EXAMPLE_DIR, '*.jpg')))
# ========== Logger ==========
logger.add('app.log', mode='a')
logger.info('===== APP RESTARTED =====')
# ========== Models ==========
feature_extractor = ImageGPTFeatureExtractor.from_pretrained(
"openai/imagegpt-medium")
model = ImageGPTForCausalImageModeling.from_pretrained(
"openai/imagegpt-medium")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
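# openai/imagegpt-medium models 32x32 images whose pixels are quantized to a
# 512-color palette ("clusters"); its vocabulary has 513 entries, the extra
# one being the start-of-sequence token used below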
def process_image(image):
    logger.info('--- image file received')
    # prepare a batch of 7 identical copies of the input, shape (7, 1024)
    batch_size = 7
    encoding = feature_extractor([image for _ in range(batch_size)], return_tensors="pt")
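    # NOTE: under the transformers release this demo was written against, the
    # extractor returns the quantized color-cluster token ids as `pixel_values`;
    # newer releases expose the same tensor as `input_ids`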
    # create primers
    samples = encoding.pixel_values.numpy()
    n_px = feature_extractor.size  # 32 for the openai/imagegpt-* checkpoints
    clusters = feature_extractor.clusters  # (512, 3) palette colors, scaled to [-1, 1]
    n_px_crop = 16
    primers = samples.reshape(-1, n_px * n_px)[:, :n_px_crop * n_px]  # crop the top n_px_crop rows; these become the conditioning tokens
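    # with n_px = 32 and n_px_crop = 16, the top half (512 tokens) is kept as
    # context and the bottom half (512 tokens) is left for the model to fill in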
    # render the conditioned image (from the first primer), padded with black
    # pixels to the full 32x32 canvas
    primers_img = np.reshape(np.rint(127.5 * (clusters[primers[0]] + 1.0)), [n_px_crop, n_px, 3]).astype(np.uint8)
    primers_img = np.pad(primers_img, pad_width=((0, 16), (0, 0), (0, 0)), mode="constant")
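    # since the palette is scaled to [-1, 1], 127.5 * (c + 1.0) maps each
    # cluster color back to the usual [0, 255] pixel range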
    # generate the missing bottom half (sampling, no beam search)
    context = np.concatenate((np.full((batch_size, 1), model.config.vocab_size - 1), primers), axis=1)
    context = torch.tensor(context).to(device)
    output = model.generate(input_ids=context, max_length=n_px * n_px + 1, temperature=1.0, do_sample=True, top_k=40)
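    # vocab_size - 1 is the start-of-sequence token id, so max_length covers
    # 1 SOS token + 1024 pixel tokens; sampling with top_k=40 lets each of the
    # 7 copies end with a different plausible completion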
    # decode back to images (convert color-cluster tokens back to pixels)
    samples = output[:, 1:].cpu().detach().numpy()  # drop the SOS token
    samples_img = [np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [n_px, n_px, 3]).astype(np.uint8) for s in samples]
    samples_img = [primers_img] + samples_img
    # stack the 8 images (primer + 7 completions) into a 2x4 grid
    row1 = np.hstack(samples_img[:4])
    row2 = np.hstack(samples_img[4:])
    result = np.vstack([row1, row2])
    # return as PIL Image
    completion = Image.fromarray(result)
    return completion
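# quick local smoke test without the web UI (a minimal sketch: it assumes at
# least one example image exists under EXAMPLE_DIR, and 'completion_preview.png'
# is an arbitrary output path):
# if examples:
#     preview = process_image(Image.open(examples[0]).convert('RGB'))
#     preview.save('completion_preview.png')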
iface = gr.Interface(
    process_image,
    title="Image Completion demo that erases half of the image and fills it back in (ImageGPT)",
    description='The AI fills in the bottom half of the given image',
    inputs=gr.inputs.Image(type="pil", label='Input image'),
    outputs=gr.outputs.Image(type="pil", label='Result drawn by the AI'),
    examples=examples,
    enable_queue=True,
    article='<p style="text-align:center">i-Scream AI</p>',
)
iface.launch()