import os
import time

import PIL
import gradio as gr
import torch
import transformers
from diffusers import StableDiffusionPipeline
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import pipeline

READ_TOKEN = os.environ.get('HF_ACCESS_TOKEN', None)

model_id = "runwayml/stable-diffusion-v1-5"
# model_id = "CompVis/stable-diffusion-v1-4"

has_cuda = torch.cuda.is_available()

if has_cuda:
    # Half-precision weights on GPU to cut memory use.
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16, revision="fp16", use_auth_token=READ_TOKEN
    )
    device = "cuda"
else:
    pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=READ_TOKEN)
    device = "cpu"

pipe.to(device)


def safety_checker(images, clip_input):
    """Bypass the pipeline's NSFW filter.

    Returns the images untouched together with one ``False`` flag per image;
    the pipeline reports these flags as a per-image list
    (``nsfw_content_detected``), so a bare ``False`` does not match its contract.
    """
    return images, [False] * len(images)


pipe.safety_checker = safety_checker

SAVED_CHECKPOINT = 'mikegarts/distilgpt2-lotr'
model = AutoModelForCausalLM.from_pretrained(SAVED_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(SAVED_CHECKPOINT)

summarizer = pipeline("summarization")


#######################################################


def break_until_dot(txt):
    """Cut ``txt`` after its last '.' so it ends on a complete sentence.

    If ``txt`` contains no '.', the whole text is returned with a '.' appended.
    """
    return txt.rsplit('.', 1)[0] + '.'


def generate(prompt):
    """Sample one continuation of ``prompt`` from the fine-tuned LM.

    Returns the decoded sequence (which for a causal LM starts with the
    prompt itself) truncated to its last complete sentence.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        input_ids=input_ids,
        max_length=120,
        min_length=50,
        temperature=0.7,
        do_sample=True,
        # Only outputs[0] is used below, so sampling extra sequences was wasted work.
        num_return_sequences=1,
        # Explicit pad id: avoids transformers' "Setting pad_token_id to
        # eos_token_id" fallback warning for GPT-2 style tokenizers.
        pad_token_id=tokenizer.eos_token_id,
    )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return break_until_dot(decoded)


def generate_story(prompt):
    """Generate a story from ``prompt`` and a short summary of it.

    Returns ``(story, summary, update)`` where ``update`` reveals the
    "Generate an image" button.
    """
    story = generate(prompt=prompt)
    summary = summarizer(story, min_length=5, max_length=15)[0]['summary_text']
    summary = break_until_dot(summary)
    return story, summary, gr.update(visible=True)


def on_change_event(app_state):
    """Poll the shared state and refresh the progress widgets.

    While an image generation is running and at least one preview exists,
    returns updates showing every preview collected so far in the gallery and
    the current status message followed by a cycling run of 0-9 ' *' markers.
    Otherwise resets the gallery label and clears the status message.
    """
    print(f'on_change_event {app_state}')
    if app_state and app_state['running'] and app_state['img'] is not None:
        print(f'Updating the image:! {app_state}')
        app_state['dots'] = (app_state['dots'] + 1) % 10
        message = app_state['status_msg'] + ' *' * app_state['dots']
        print(f'message={message}')
        return (
            gr.update(value=app_state['img_list'], label='intermediate steps'),
            gr.update(value=message),
        )
    return gr.update(label='images list'), gr.update(value='')


with gr.Blocks() as demo:

    def generate_image(prompt, inference_steps, app_state):
        """Render an illustration for ``prompt`` while publishing per-step previews.

        Progress (status message, latest preview, list of all previews) is
        written into ``app_state`` so the polling ``on_change_event`` can show
        it. The progress fields are reset in a ``finally`` so a failed run does
        not leave the UI reporting a generation that is no longer happening.
        """
        app_state['running'] = True
        app_state['img_list'] = []
        app_state['status_msg'] = 'Starting'

        def callback(step, ts, latents):
            # Called by the pipeline after every denoising step (callback_steps=1):
            # decode the current latents into a preview image.
            app_state['status_msg'] = f'Reconstructing an image from the latent state on step {step}'
            # Undo the latent scaling before decoding; 0.18215 is the SD v1 VAE
            # scaling factor (revisit if model_id changes to another family).
            latents = 1 / 0.18215 * latents
            res = pipe.vae.decode(latents).sample
            # Map the decoder output (expected roughly in [-1, 1]) to [0, 1].
            res = (res / 2 + 0.5).clamp(0, 1)
            res = res.cpu().permute(0, 2, 3, 1).detach().numpy()
            res = pipe.numpy_to_pil(res)[0]
            app_state['img'] = res
            app_state['step'] = step
            app_state['img_list'].append(res)
            app_state['status_msg'] = f'Generating step ({step + 1})'

        prompt = prompt + ' masterpiece charcoal pencil art lord of the rings illustration'
        try:
            img = pipe(
                prompt,
                height=512,
                width=512,
                num_inference_steps=int(inference_steps),
                callback=callback,
                callback_steps=1,
            )
        finally:
            app_state['running'] = False
            app_state['img'] = None
            app_state['status_msg'] = ''
            app_state['dots'] = 0
        return gr.update(value=img.images[0], label='Generated image')

    app_state = gr.State({'img': None, 'step': 0, 'running': False, 'status_msg': '', 'img_list': [], 'dots': 0})

    title = gr.Markdown('## Lord of the rings app')
    description = gr.Markdown(
        '#### A Lord of the rings inspired app that combines text and image generation.'
        ' The language modeling is done by fine tuning distilgpt2 on the LOTR trilogy.'
        f' The text2img model is {model_id}. The summarization is done using distilbart.'
    )

    prompt = gr.Textbox(label="Your prompt", value="Frodo took the sword and")
    story = gr.Textbox(label="Your story")
    summary = gr.Textbox(label="Summary")

    bt_make_text = gr.Button("Generate text")
    bt_make_image = gr.Button("Generate an image (takes about 10-15 minutes on CPU).", visible=False)

    img_description = gr.Markdown(
        'Image generation takes some time'
        ' but here you can see what is generated from the latent state of the diffuser every few steps.'
        ' Usually there is a significant improvement around step 12 that yields a much better image'
    )
    status_msg = gr.Markdown()

    gallery = gr.Gallery()
    image = gr.Image(label='Illustration for your story', show_label=True)

    gallery.style(grid=[4])

    inference_steps = gr.Slider(
        5, 30, value=20, step=1, visible=True,
        label="Num inference steps (more steps yields a better image but takes more time)",
    )

    bt_make_text.click(fn=generate_story, inputs=prompt, outputs=[story, summary, bt_make_image])
    bt_make_image.click(fn=generate_image, inputs=[summary, inference_steps, app_state], outputs=image)

    # Poll app_state every 5 seconds to display generation progress. The hidden
    # slider registers a second periodic poller that cancels the load-time one.
    eventslider = gr.Slider(visible=False)
    dep = demo.load(on_change_event, app_state, [gallery, status_msg], every=5)
    eventslider.change(fn=on_change_event, inputs=[app_state], outputs=[gallery, status_msg], every=5, cancels=[dep])

if READ_TOKEN:
    demo.queue().launch()
else:
    demo.queue().launch(share=True, debug=True)