lotr / app.py
mikegarts's picture
change to openjourney model
4a7427e
import time
import os
import PIL
import gradio as gr
import torch
import transformers
# from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline
from transformers import pipeline
from diffusers import StableDiffusionPipeline
# Optional Hugging Face access token taken from the environment (None when unset).
READ_TOKEN = os.environ.get('HF_ACCESS_TOKEN', None)

# Text-to-image checkpoint. Alternatives that can be swapped in:
#   "runwayml/stable-diffusion-v1-5", "CompVis/stable-diffusion-v1-4"
model_id = "prompthero/openjourney"

has_cuda = torch.cuda.is_available()
device = "cuda" if has_cuda else "cpu"
# On a GPU, request half-precision weights from the "fp16" revision;
# on CPU, load the checkpoint's default revision and dtype.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    use_auth_token=READ_TOKEN,
    **({"torch_dtype": torch.float16, "revision": "fp16"} if has_cuda else {}),
)
pipe.to(device)
def safety_checker(images, clip_input):
    """No-op replacement for the diffusers NSFW safety checker.

    The images are passed through untouched and none of them is flagged.

    Args:
        images: Batch of decoded images handed over by the pipeline.
        clip_input: Features the real checker would inspect; ignored here.

    Returns:
        ``(images, flags)`` where ``flags`` holds one ``False`` per image.
    """
    # The pipeline reports this value as ``nsfw_content_detected`` (one bool per
    # image); newer diffusers releases iterate over it, so a bare ``False``
    # would raise there.
    return images, [False] * len(images)
# Disable NSFW filtering by installing the no-op `safety_checker`.
pipe.safety_checker = safety_checker

# Fine-tuned distilgpt2 checkpoint used for story generation.
SAVED_CHECKPOINT = 'mikegarts/distilgpt2-lotr'
text_generation_pipe = pipeline("text-generation", model=SAVED_CHECKPOINT)
# Name the summarization model explicitly (the distilbart checkpoint the task
# otherwise falls back to) instead of relying on the library default, which
# transformers warns against and which may change between releases.
summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
#######################################################
#######################################################
def generate(prompt, max_length=140, max_extra_tokens=30):
    """Generate a continuation of *prompt* that preferably ends on a full stop.

    A first pass produces up to ``max_length`` tokens (prompt included) with a
    mild repetition penalty. If the text does not end with ``'.'``, it is
    extended one new token at a time, at most ``max_extra_tokens`` times, so the
    story is less likely to stop mid-sentence.

    Args:
        prompt: Seed text for the story.
        max_length: Total token budget (prompt included) for the first pass.
        max_extra_tokens: Upper bound on single-token extensions used to reach '.'.

    Returns:
        The pipeline's ``generated_text`` (by default the prompt followed by
        the continuation).
    """
    res = text_generation_pipe(prompt,
                               max_length=max_length,
                               repetition_penalty=1.1)[0]['generated_text']
    for _ in range(max_extra_tokens):
        # endswith() is also safe on an empty string, unlike res[-1].
        if res.endswith('.'):
            break
        # max_new_tokens counts only freshly generated tokens; the former
        # max_length=1 was shorter than the prompt and only worked because
        # generation overran its own length limit.
        res = text_generation_pipe(res, max_new_tokens=1)[0]['generated_text']
    return res
def generate_story(prompt):
    """Write a story from *prompt* and summarize it in one sentence.

    Returns a triple matching the Gradio outputs ``[story, summary, bt_make_image]``:
    the generated story, the first sentence of its summary, and an update that
    makes the image-generation button visible.
    """
    print(f'Prompt={prompt}')
    story = generate(prompt=prompt)
    summaries = summarizer(story, min_length=5, max_length=15)
    summary_text = summaries[0]['summary_text']
    # split() always yields at least one piece, so this is the text up to the first '.'.
    first_sentence, *_ = summary_text.split('.')
    reveal_button = gr.update(visible=True)
    return story, first_sentence, reveal_button
def on_change_event(app_state):
    """Polling handler that refreshes the progress gallery and status line.

    While image generation is running and at least one preview has been
    decoded, return the accumulated previews and the current status message
    followed by a cycling run of ``' *'`` markers (0-9 repetitions). Otherwise
    relabel the gallery and clear the status text.

    Args:
        app_state: Per-session state dict (keys ``running``, ``img``,
            ``img_list``, ``status_msg``, ``dots``); ``dots`` is advanced in place.

    Returns:
        A pair of ``gr.update`` objects for ``(gallery, status_msg)``.
    """
    if app_state and app_state['running'] and app_state['img']:
        # Advance the marker count modulo 10 so the status visibly changes between polls.
        app_state['dots'] = (app_state['dots'] + 1) % 10
        message = app_state['status_msg'] + ' *' * app_state['dots']
        return gr.update(value=app_state['img_list'], label='intermediate steps'), gr.update(value=message)
    return gr.update(label='images list'), gr.update(value='')
with gr.Blocks() as demo:
    def generate_image(prompt, inference_steps, app_state):
        """Render an illustration for *prompt*, streaming intermediate previews.

        Every couple of denoising steps the current latents are decoded and
        stored in ``app_state`` (``img``, ``step``, ``img_list``, ``status_msg``)
        so that the polling ``on_change_event`` handler can display progress.
        The "running" state is always reset afterwards, even if generation fails.

        Args:
            prompt: Text to illustrate (wired to the summary box); a fixed style
                suffix is appended.
            inference_steps: Number of denoising steps from the slider.
            app_state: Per-session state dict shared with ``on_change_event``;
                mutated in place.

        Returns:
            A ``gr.update`` that shows the final image in the image component.
        """
        app_state['running'] = True
        app_state['img_list'] = []
        app_state['status_msg'] = 'Starting'

        def callback(step, ts, latents):
            # Invoked by the pipeline every `callback_steps` steps (see the call below).
            app_state['status_msg'] = f'Reconstructing an image from the latent state on step {step}'
            # Undo the latent scaling before decoding; 0.18215 is presumably the
            # Stable Diffusion VAE scaling factor for this checkpoint.
            latents = 1 / 0.18215 * latents
            res = pipe.vae.decode(latents).sample
            # Map the decoder output into [0, 1] (assumes it is roughly in [-1, 1]).
            res = (res / 2 + 0.5).clamp(0, 1)
            res = res.cpu().permute(0, 2, 3, 1).detach().numpy()
            res = pipe.numpy_to_pil(res)[0]
            app_state['img'] = res
            app_state['step'] = step
            app_state['img_list'].append(res)
            app_state['status_msg'] = f'Generating step ({step + 1})'

        prompt = prompt + ' masterpiece charcoal pencil art lord of the rings illustration mdjrny-v4 style'
        try:
            img = pipe(prompt, height=512, width=512,
                       num_inference_steps=int(inference_steps),
                       callback=callback, callback_steps=2)
        finally:
            # Reset progress state even when the pipeline raises; otherwise the
            # poller would keep reporting a run that is no longer happening.
            app_state['running'] = False
            app_state['img'] = None
            app_state['status_msg'] = ''
            app_state['dots'] = 0
        return gr.update(value=img.images[0], label='Generated image')

    # Per-session progress state shared between generate_image and on_change_event.
    app_state = gr.State({'img': None,
                          'step': 0,
                          'running': False,
                          'status_msg': '',
                          'img_list': [],
                          'dots': 0,
                          })
    title = gr.Markdown('## Lord of the rings app')
    description = gr.Markdown('#### A Lord of the rings inspired app that combines text and image generation.'
                              f' The language modeling is done by fine tuning distilgpt2 on the LOTR trilogy {SAVED_CHECKPOINT}.'
                              f' The text2img model is {model_id}. The summarization is done using distilbart.')
    prompt = gr.Textbox(label="Your prompt", value="Frodo took the sword and")
    story = gr.Textbox(label="Your story")
    summary = gr.Textbox(label="Summary")
    bt_make_text = gr.Button("Generate text")
    # Hidden until generate_story returns gr.update(visible=True) for it.
    bt_make_image = gr.Button("Generate an image (takes about 10-15 minutes on CPU).", visible=False)
    img_description = gr.Markdown('Image generation takes some time'
                                  ' but here you can see what is generated from the latent state of the diffuser every few steps.'
                                  ' Usually there is a significant improvement around step 12 that yields a much better image')
    status_msg = gr.Markdown('')
    gallery = gr.Gallery()
    image = gr.Image(label='Illustration for your story', show_label=True)
    gallery.style(grid=[4])
    inference_steps = gr.Slider(5, 30,
                                value=20,
                                step=1,
                                visible=True,
                                label="Num inference steps (more steps yields a better image but takes more time)")
    bt_make_text.click(fn=generate_story, inputs=prompt, outputs=[story, summary, bt_make_image])
    bt_make_image.click(fn=generate_image, inputs=[summary, inference_steps, app_state], outputs=image)
    # Poll on_change_event every 5 seconds to refresh the gallery and status line.
    # NOTE(review): eventslider is hidden and never changed by visible code; its
    # change handler appears to exist only to register `cancels=[dep]` — confirm intent.
    eventslider = gr.Slider(visible=False)
    dep = demo.load(on_change_event, app_state, [gallery, status_msg], every=5)
    eventslider.change(fn=on_change_event, inputs=[app_state], outputs=[gallery, status_msg], every=5, cancels=[dep])
# READ_TOKEN doubles as a deployment switch: with a token configured, launch
# plainly (presumably the hosted deployment); without one, also request a
# public share link and debug mode.
demo.queue().launch(**({} if READ_TOKEN else {'share': True, 'debug': True}))