File size: 3,205 Bytes
34d353a 05d7e61 79c9a48 3e4033b 34d353a 3e4033b 79c9a48 3e4033b 05d7e61 3e4033b 05d7e61 3e4033b 05d7e61 3e4033b 05d7e61 3e4033b 05d7e61 3e4033b 05d7e61 3e4033b 05d7e61 3e4033b 815cc40 3e4033b 05d7e61 3e4033b 05d7e61 3e4033b af893f9 be76889 7e9b8ea 3e4033b 815cc40 3e4033b 815cc40 3e4033b 34d353a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import os
import gradio as gr
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline
from diffusers import StableDiffusionPipeline
# Optional Hugging Face access token from the environment; forwarded to
# from_pretrained for authenticated downloads (None when unset).
READ_TOKEN = os.environ.get('HF_ACCESS_TOKEN', None)
model_id = "runwayml/stable-diffusion-v1-5"
# Alternative checkpoint:
# model_id = "CompVis/stable-diffusion-v1-4"
has_cuda = torch.cuda.is_available()
if has_cuda:
    # On GPU use the half-precision weights and dtype to reduce memory use.
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        revision="fp16",
        use_auth_token=READ_TOKEN,
    )
    device = "cuda"
else:
    # Many CPU kernels do not implement float16, so load the default
    # full-precision weights rather than the half-precision "fp16" revision.
    pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=READ_TOKEN)
    device = "cpu"
pipe.to(device)
def safety_checker(images, clip_input):
    """Pass-through replacement for the Stable Diffusion safety checker.

    Leaves the images untouched and reports every image as not flagged.
    One flag is returned per image, not a single bool: recent diffusers
    releases iterate over the flags (e.g. to decide per-image
    post-processing), so a bare ``False`` raises ``TypeError`` there.

    Args:
        images: Batch of generated images; only ``len(images)`` is used.
        clip_input: CLIP features the real checker would inspect; ignored.

    Returns:
        Tuple ``(images, [False] * len(images))``.
    """
    return images, [False] * len(images)
# Swap the pipeline's NSFW filter for the pass-through `safety_checker` function.
pipe.safety_checker = safety_checker
# Fine-tuned distilgpt2 checkpoint used for story generation.
SAVED_CHECKPOINT = 'mikegarts/distilgpt2-lotr'
model = AutoModelForCausalLM.from_pretrained(SAVED_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(SAVED_CHECKPOINT)
# Pin the summarization model explicitly: the task's default model is an
# implementation detail of transformers (and warns when implied), while the
# UI description promises distilbart.
SUMMARIZATION_MODEL = "sshleifer/distilbart-cnn-12-6"
summarizer = pipeline("summarization", model=SUMMARIZATION_MODEL)
def break_until_dot(txt):
    """Cut *txt* back to its last '.', keeping that period.

    If *txt* contains no '.', it is returned with a '.' appended
    (so an empty string becomes ".").
    """
    head, dot, _tail = txt.rpartition('.')
    if not dot:
        # No period anywhere: terminate the whole text.
        return txt + '.'
    return head + '.'
def generate(prompt, max_length=180, temperature=0.7):
    """Continue *prompt* with the fine-tuned language model.

    Samples a single continuation from the module-level ``model`` and
    ``tokenizer`` and trims the result back to its last '.' using
    ``break_until_dot``.

    Args:
        prompt: Text the story starts with.
        max_length: Maximum total length (prompt + continuation) in tokens.
        temperature: Sampling temperature; lower is more conservative.

    Returns:
        The decoded text (prompt included), ending at its last '.'.
    """
    # Tokenizing via __call__ also yields the attention mask for generate().
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **encoded,
        max_length=max_length,
        temperature=temperature,
        # Only the first sequence is used, so sampling more is wasted work.
        num_return_sequences=1,
        do_sample=True,
        # distilgpt2-style tokenizers define no pad token (presumably this
        # checkpoint too); reuse EOS so generate() does not warn.
        pad_token_id=tokenizer.eos_token_id,
    )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return break_until_dot(decoded)
def generate_image(prompt, inference_steps, height=512, width=512):
    """Render an illustration for *prompt* with the Stable Diffusion pipeline.

    A fixed style suffix is appended so every image gets the same
    charcoal / Lord-of-the-Rings illustration look.

    Args:
        prompt: Text to illustrate (the UI passes the story summary).
        inference_steps: Number of denoising steps. Coerced to ``int``
            because it comes from a numeric Gradio slider.
        height: Output height in pixels (Stable Diffusion expects a
            multiple of 8).
        width: Output width in pixels (same constraint as *height*).

    Returns:
        The first image of the pipeline output.
    """
    styled_prompt = prompt + ' masterpiece charcoal pencil art lord of the rings illustration'
    result = pipe(
        styled_prompt,
        height=height,
        width=width,
        num_inference_steps=int(inference_steps),
    )
    return result.images[0]
def generate_story(prompt):
    """Write a story from *prompt* and a short summary of it.

    Returns a 3-tuple matching the text button's outputs: the story, its
    summary trimmed to the last '.', and a Gradio update setting
    ``visible=True`` (used to reveal the image button).
    """
    text = generate(prompt=prompt)
    results = summarizer(text, min_length=5, max_length=15)
    short = break_until_dot(results[0]['summary_text'])
    reveal = gr.update(visible=True)
    return text, short, reveal
with gr.Blocks() as demo:
    # Page header and description.
    title = gr.Markdown('## Lord of the rings app')
    description = gr.Markdown('#### A Lord of the rings inspired app that combines text and image generation.'
                              ' The language modeling is done by fine tuning distilgpt2 on the LOTR trilogy.'
                              f' The text2img model is {model_id}. The summarization is done using distilbart.')
    # Text-generation inputs and outputs.
    prompt = gr.Textbox(label="Your prompt", value="Frodo took the sword and")
    story = gr.Textbox(label="Your story")
    summary = gr.Textbox(label="Summary")
    bt_make_text = gr.Button("Generate text")
    # Hidden until text is generated; generate_story's third output shows it.
    bt_make_image = gr.Button("Generate an image (takes about 10-15 minutes on CPU).", visible=False)
    image = gr.Image(label='Illustration for your story', shape=(512, 512))
    inference_steps = gr.Slider(5, 30, value=15, step=1, label="Num inference steps (more steps makes a better image but takes more time)")
    # Prompt -> story, summary, and visibility update for the image button.
    bt_make_text.click(fn=generate_story, inputs=prompt, outputs=[story, summary, bt_make_image])
    # Summary + step count -> illustration.
    bt_make_image.click(fn=generate_image, inputs=[summary, inference_steps], outputs=image)
# With READ_TOKEN set, launch with Gradio's defaults; without it, open a
# public share link and run in debug mode.
launch_options = {} if READ_TOKEN else {"share": True, "debug": True}
demo.launch(**launch_options)