import os

import gradio as gr
import torch
import transformers
from diffusers import StableDiffusionPipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Hugging Face access token read from the environment; None when unset.
READ_TOKEN = os.environ.get('HF_ACCESS_TOKEN', None)

model_id = "runwayml/stable-diffusion-v1-5"
# Alternative checkpoint: "CompVis/stable-diffusion-v1-4"

has_cuda = torch.cuda.is_available()
if has_cuda:
    # On GPU the pipeline is loaded with an explicit float16 dtype.
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        revision="fp16",
        use_auth_token=READ_TOKEN,
    )
    device = "cuda"
else:
    # On CPU the same "fp16" revision is requested, without a dtype override.
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        revision="fp16",
        use_auth_token=READ_TOKEN,
    )
    device = "cpu"
pipe.to(device)


def safety_checker(images, clip_input):
    """Pass-through replacement for the pipeline's safety checker.

    Args:
        images: The batch of generated images handed over by the pipeline.
        clip_input: Accepted for signature compatibility; unused.

    Returns:
        A tuple ``(images, flags)`` where ``images`` is returned untouched and
        ``flags`` holds one ``False`` per image.
    """
    # A per-image list is returned instead of a bare ``False``: newer diffusers
    # releases iterate over this value, and a bool is not iterable.
    # NOTE(review): confirm against the installed diffusers version.
    return images, [False] * len(images)


# Replace the built-in checker so generated images are never filtered.
pipe.safety_checker = safety_checker

SAVED_CHECKPOINT = 'mikegarts/distilgpt2-lotr'
model = AutoModelForCausalLM.from_pretrained(SAVED_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(SAVED_CHECKPOINT)
summarizer = pipeline("summarization")


def break_until_dot(txt):
    """Cut *txt* after its last '.' so it does not end mid-sentence.

    Everything following the final period is dropped. When *txt* contains no
    period at all, the whole text is kept and a '.' is appended.
    """
    return txt.rsplit('.', 1)[0] + '.'
def generate(prompt):
    """Sample a text continuation for *prompt* from the language model.

    Args:
        prompt: Text used to seed generation.

    Returns:
        The decoded text of the generated sequence, trimmed by
        ``break_until_dot`` so it ends on a period.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        input_ids=input_ids,
        max_length=180,
        temperature=0.7,
        # Only the first sequence is ever decoded below, so sampling several
        # candidates would be discarded work.
        num_return_sequences=1,
        do_sample=True,
    )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return break_until_dot(decoded)


def generate_image(prompt, inference_steps):
    """Create an illustration for *prompt* with the Stable Diffusion pipeline.

    Args:
        prompt: Scene description; a fixed style suffix is appended to it.
        inference_steps: Number of denoising steps requested from the UI.

    Returns:
        The first image of the pipeline output.
    """
    styled_prompt = prompt + ', masterpiece charcoal pencil art lord of the rings illustration'
    result = pipe(
        styled_prompt,
        height=512,
        width=512,
        # Coerce to int in case the UI slider delivers a float.
        num_inference_steps=int(inference_steps),
    )
    return result.images[0]


def generate_story(prompt):
    """Generate a story and a short summary of it.

    Args:
        prompt: Text used to seed the story.

    Returns:
        A tuple ``(story, summary, update)`` where ``update`` makes the
        image-generation button visible.
    """
    story = generate(prompt=prompt)
    summary = summarizer(story, min_length=5, max_length=15)[0]['summary_text']
    summary = break_until_dot(summary)
    return story, summary, gr.update(visible=True)


with gr.Blocks() as demo:
    title = gr.Markdown('## Lord of the rings app')
    description = gr.Markdown(
        '### A Lord of the rings inspired app that combines text and image generation'
    )
    prompt = gr.Textbox(label="Your prompt", value="And then the hobbit said")
    story = gr.Textbox(label="Your story")
    summary = gr.Textbox(label="Summary")

    bt_make_text = gr.Button("Generate text")
    # Hidden until a story exists; generate_story reveals it via gr.update.
    bt_make_image = gr.Button(
        "Generate an image (takes about 10-15 minutes on CPU)",
        visible=False,
    )

    image = gr.Image(label='Illustration for your story')
    inference_steps = gr.Slider(
        5,
        30,
        value=15,
        step=1,
        label="Num inference steps (more steps makes a better image but takes more time)",
    )

    bt_make_text.click(
        fn=generate_story,
        inputs=prompt,
        outputs=[story, summary, bt_make_image],
    )
    # The image is drawn from the summary, not from the full story text.
    bt_make_image.click(
        fn=generate_image,
        inputs=[summary, inference_steps],
        outputs=image,
    )

# With an access token configured, launch with defaults; otherwise launch with
# a public share link and debug mode enabled.
if READ_TOKEN:
    demo.launch()
else:
    demo.launch(share=True, debug=True)