|
import os |
|
import gradio as gr |
|
|
|
import torch |
|
import transformers |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
from transformers import pipeline |
|
from diffusers import StableDiffusionPipeline |
|
|
|
# Hugging Face access token, read from the environment (None when unset).
READ_TOKEN = os.environ.get("HF_ACCESS_TOKEN")

model_id = "runwayml/stable-diffusion-v1-5"

has_cuda = torch.cuda.is_available()
device = "cuda" if has_cuda else "cpu"

# Both code paths request the "fp16" revision; only the CUDA path also asks
# for float16 tensors.
_load_kwargs = {"revision": "fp16", "use_auth_token": READ_TOKEN}
if has_cuda:
    _load_kwargs["torch_dtype"] = torch.float16

pipe = StableDiffusionPipeline.from_pretrained(model_id, **_load_kwargs)
pipe.to(device)
|
def safety_checker(images, clip_input):
    """Pass-through stand-in for the pipeline's safety checker.

    Args:
        images: Batch of generated images; returned unchanged.
        clip_input: Accepted so the call signature matches; not used.

    Returns:
        A ``(images, flags)`` tuple where ``flags`` holds one ``False`` per
        image, i.e. nothing is ever reported as flagged.
    """
    # The flags are returned as a per-image list rather than a single bool so
    # that callers which iterate over them (one entry per image) keep working.
    return images, [False] * len(images)
|
# Swap the pipeline's safety checker for the module-level `safety_checker`.
pipe.safety_checker = safety_checker

# Causal language model and its tokenizer, both loaded from one checkpoint.
SAVED_CHECKPOINT = "mikegarts/distilgpt2-lotr"
model = AutoModelForCausalLM.from_pretrained(SAVED_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(SAVED_CHECKPOINT)

# Summarization pipeline; no model is named, so the library default is used.
summarizer = pipeline(task="summarization")
|
|
|
def break_until_dot(txt):
    """Trim *txt* so that it ends at its last period.

    Everything after the final ``'.'`` is dropped. When *txt* contains no
    period at all, one is appended to the unmodified text instead.
    """
    last_dot = txt.rfind('.')
    if last_dot == -1:
        return txt + '.'
    return txt[:last_dot + 1]
|
|
|
def generate(prompt):
    """Continue *prompt* with the language model and return the result.

    Args:
        prompt: Text the generated continuation starts from.

    Returns:
        The decoded text (prompt plus continuation) of one sampled sequence,
        cut back to its last period by ``break_until_dot``.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    # Only the first returned sequence is ever decoded, so sample exactly one
    # instead of generating extra sequences that would be thrown away.
    outputs = model.generate(
        input_ids=input_ids,
        max_length=180,
        temperature=0.7,
        num_return_sequences=1,
        do_sample=True,
    )

    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return break_until_dot(decoded)
|
|
|
def generate_image(prompt, inference_steps):
    """Render a 512x512 illustration for *prompt*.

    A fixed style suffix is appended to the prompt before it is handed to the
    image pipeline. ``inference_steps`` is forwarded as the pipeline's
    ``num_inference_steps``. The first generated image is returned.
    """
    style_suffix = ' masterpiece charcoal pencil art lord of the rings illustration'
    result = pipe(
        prompt + style_suffix,
        height=512,
        width=512,
        num_inference_steps=inference_steps,
    )
    return result.images[0]
|
|
|
def generate_story(prompt):
    """Produce a story for *prompt* together with a short summary of it.

    Returns:
        A 3-tuple ``(story, summary, update)``: the generated story text, its
        summary cut back to the last period, and a Gradio update that makes
        the receiving component visible.
    """
    story_text = generate(prompt=prompt)

    summaries = summarizer(story_text, min_length=5, max_length=15)
    short_summary = break_until_dot(summaries[0]['summary_text'])

    reveal = gr.update(visible=True)
    return story_text, short_summary, reveal
|
|
|
# Build the Gradio interface and wire the two buttons to their handlers.
with gr.Blocks() as demo:
    # Header and description.
    title = gr.Markdown(value='## Lord of the rings app')
    description = gr.Markdown(
        value=(
            '#### A Lord of the rings inspired app that combines text and image generation.'
            ' The language modeling is done by fine tuning distilgpt2 on the LOTR trilogy.'
            f' The text2img model is {model_id}. The summarization is done using distilbart.'
        )
    )

    # Text input and the two text outputs.
    prompt = gr.Textbox(label="Your prompt", value="Frodo took the sword and")
    story = gr.Textbox(label="Your story")
    summary = gr.Textbox(label="Summary")

    # The image button starts hidden; generate_story returns an update that
    # makes it visible.
    bt_make_text = gr.Button(value="Generate text")
    bt_make_image = gr.Button(
        value="Generate an image (takes about 10-15 minutes on CPU).",
        visible=False,
    )

    image = gr.Image(label='Illustration for your story', shape=(512, 512))
    inference_steps = gr.Slider(
        minimum=5,
        maximum=30,
        value=15,
        step=1,
        label="Num inference steps (more steps makes a better image but takes more time)",
    )

    # Text button: prompt -> story, summary, and the image button update.
    bt_make_text.click(
        fn=generate_story,
        inputs=prompt,
        outputs=[story, summary, bt_make_image],
    )
    # Image button: summary + step count -> illustration.
    bt_make_image.click(
        fn=generate_image,
        inputs=[summary, inference_steps],
        outputs=image,
    )
|
|
|
# With an access token configured, launch with default options; otherwise
# launch with a share link and debug mode enabled.
_launch_options = {} if READ_TOKEN else {"share": True, "debug": True}
demo.launch(**_launch_options)