"""Gradio demo: Lord-of-the-Rings story generation plus an illustration.

The text is produced by a fine-tuned distilgpt2 checkpoint, condensed by a
summarization pipeline, and the summary is then used as the prompt for a
Stable Diffusion pipeline. While the image is being generated, intermediate
decodings of the latent state are shown in a gallery.
"""
import os
import time

import PIL
import gradio as gr
import torch
import transformers
from diffusers import StableDiffusionPipeline
from transformers import pipeline

READ_TOKEN = os.environ.get('HF_ACCESS_TOKEN', None)

# Alternative checkpoint: "CompVis/stable-diffusion-v1-4"
model_id = "runwayml/stable-diffusion-v1-5"

has_cuda = torch.cuda.is_available()

if has_cuda:
    # Half precision weights are only requested when a GPU is available.
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        revision="fp16",
        use_auth_token=READ_TOKEN,
    )
    device = "cuda"
else:
    pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=READ_TOKEN)
    device = "cpu"

pipe.to(device)


def safety_checker(images, clip_input):
    """Pass-through replacement for the pipeline's safety checker.

    NOTE(review): this returns the images untouched and reports ``False``,
    i.e. it disables the pipeline's built-in content filter. Keep only if
    that is intended.
    """
    return images, False


pipe.safety_checker = safety_checker

SAVED_CHECKPOINT = 'mikegarts/distilgpt2-lotr'
text_generation_pipe = pipeline("text-generation", model=SAVED_CHECKPOINT)
summarizer = pipeline("summarization")

# Upper bound on the number of single-token extensions tried by `generate`.
MAX_EXTENSION_STEPS = 30

# Style hint appended to every image prompt.
IMAGE_STYLE_SUFFIX = ' masterpiece charcoal pencil art lord of the rings illustration'

# Factor the latents are divided by before being passed to the VAE decoder.
LATENT_SCALE = 0.18215


def generate(prompt):
    """Generate a story continuation for *prompt* that tries to end on a period.

    A first pass generates up to 140 tokens (prompt included). If the text
    does not end with ``'.'``, it is extended one new token at a time, for at
    most ``MAX_EXTENSION_STEPS`` attempts, until it does.

    Args:
        prompt: Seed text for the language model.

    Returns:
        The generated text (which starts with the prompt).
    """
    res = text_generation_pipe(prompt, max_length=140, repetition_penalty=1.1)[0]['generated_text']

    attempts = 0
    # `endswith` is safe on an empty string, unlike indexing with `res[-1]`.
    while not res.endswith('.') and attempts < MAX_EXTENSION_STEPS:
        # `max_length` counts the prompt tokens as well, so `max_length=1` can
        # never append anything to an existing text; `max_new_tokens=1` asks
        # for exactly one additional token.
        res = text_generation_pipe(res, max_new_tokens=1)[0]['generated_text']
        attempts += 1

    return res


def generate_story(prompt):
    """Create a story and a short summary for *prompt*.

    Args:
        prompt: Seed text typed by the user.

    Returns:
        A tuple ``(story, summary, update)`` where ``summary`` is the first
        sentence of the summarizer output and ``update`` makes the
        "generate image" button visible.
    """
    print(f'Prompt={prompt}')
    story = generate(prompt=prompt)
    summary_text = summarizer(story, min_length=5, max_length=15)[0]['summary_text']
    summary = summary_text.split('.')[0]
    return story, summary, gr.update(visible=True)


def on_change_event(app_state):
    """Refresh the gallery and the status line from the shared state.

    While an image generation is running and at least one intermediate image
    exists, the gallery shows all intermediate images and the status message
    gets a trailing run of ``' *'`` markers whose count cycles through 0..9
    to show progress. Otherwise the gallery label and status are reset.

    Args:
        app_state: The session state dict (see the ``gr.State`` definition).

    Returns:
        A pair of updates: ``(gallery_update, status_message_update)``.
    """
    if app_state and app_state['running'] and app_state['img']:
        app_state['dots'] += 1
        app_state['dots'] = app_state['dots'] % 10
        message = app_state['status_msg'] + ' *' * app_state['dots']
        return (
            gr.update(value=app_state['img_list'], label='intermediate steps'),
            gr.update(value=message),
        )

    return gr.update(label='images list'), gr.update(value='')


def generate_image(prompt, inference_steps, app_state):
    """Run the diffusion pipeline and publish intermediate images in *app_state*.

    Args:
        prompt: Text prompt (the story summary); a fixed style suffix is appended.
        inference_steps: Number of denoising steps, as set on the slider.
        app_state: The session state dict; it is mutated in place so that
            `on_change_event` can display progress.

    Returns:
        An update that puts the final image into the output image component.

    Raises:
        Whatever the diffusion pipeline raises; the progress state is reset
        before the exception propagates.
    """
    app_state['running'] = True
    app_state['img_list'] = []
    app_state['status_msg'] = 'Starting'

    def callback(step, ts, latents):
        """Decode the current latents into a PIL image and record it."""
        app_state['status_msg'] = f'Reconstructing an image from the latent state on step {step}'

        latents = 1 / LATENT_SCALE * latents
        res = pipe.vae.decode(latents).sample
        # Map the decoder output from [-1, 1] to [0, 1].
        res = (res / 2 + 0.5).clamp(0, 1)
        # Channels-first tensor -> channels-last numpy array for PIL conversion.
        res = res.cpu().permute(0, 2, 3, 1).detach().numpy()
        res = pipe.numpy_to_pil(res)[0]

        app_state['img'] = res
        app_state['step'] = step
        app_state['img_list'].append(res)
        app_state['status_msg'] = f'Generating step ({step + 1})'

    prompt = prompt + IMAGE_STYLE_SUFFIX
    try:
        img = pipe(
            prompt,
            height=512,
            width=512,
            # The slider may deliver a float; the pipeline expects an int.
            num_inference_steps=int(inference_steps),
            callback=callback,
            callback_steps=2,
        )
    finally:
        # Always leave the "running" state, even when generation fails;
        # otherwise the status line would keep animating forever.
        app_state['running'] = False
        app_state['img'] = None
        app_state['status_msg'] = ''
        app_state['dots'] = 0

    return gr.update(value=img.images[0], label='Generated image')


with gr.Blocks() as demo:
    # Per-session state shared between `generate_image` and `on_change_event`.
    app_state = gr.State({
        'img': None,
        'step': 0,
        'running': False,
        'status_msg': '',
        'img_list': [],
        'dots': 0,
    })

    title = gr.Markdown('## Lord of the rings app')
    description = gr.Markdown(
        f'#### A Lord of the rings inspired app that combines text and image generation.'
        f' The language modeling is done by fine tuning distilgpt2 on the LOTR trilogy {SAVED_CHECKPOINT}.'
        f' The text2img model is {model_id}. The summarization is done using distilbart.'
    )
    prompt = gr.Textbox(label="Your prompt", value="Frodo took the sword and")
    story = gr.Textbox(label="Your story")
    summary = gr.Textbox(label="Summary")

    bt_make_text = gr.Button("Generate text")
    # Hidden until a story has been generated (see `generate_story`).
    bt_make_image = gr.Button("Generate an image (takes about 10-15 minutes on CPU).", visible=False)

    img_description = gr.Markdown(
        'Image generation takes some time'
        ' but here you can see what is generated from the latent state of the diffuser every few steps.'
        ' Usually there is a significant improvement around step 12 that yields a much better image'
    )
    status_msg = gr.Markdown('')
    gallery = gr.Gallery()
    image = gr.Image(label='Illustration for your story', show_label=True)
    gallery.style(grid=[4])

    inference_steps = gr.Slider(
        5, 30,
        value=20,
        step=1,
        visible=True,
        label="Num inference steps (more steps yields a better image but takes more time)",
    )

    bt_make_text.click(fn=generate_story, inputs=prompt, outputs=[story, summary, bt_make_image])
    bt_make_image.click(fn=generate_image, inputs=[summary, inference_steps, app_state], outputs=image)

    # Poll the shared state every 5 seconds to refresh gallery and status.
    eventslider = gr.Slider(visible=False)
    dep = demo.load(on_change_event, app_state, [gallery, status_msg], every=5)
    eventslider.change(
        fn=on_change_event,
        inputs=[app_state],
        outputs=[gallery, status_msg],
        every=5,
        cancels=[dep],
    )

if READ_TOKEN:
    demo.queue().launch()
else:
    demo.queue().launch(share=True, debug=True)