# lotr / app.py — by mikegarts
# Last commit 61d1651: "Update app.py" (1.92 kB)
import functools
import os

import gradio as gr
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline
from diffusers import StableDiffusionPipeline
# Hub identifiers for the two checkpoints loaded below.
model_id = "runwayml/stable-diffusion-v1-5"  # text-to-image pipeline used by get_image_pipe
SAVED_CHECKPOINT = 'mikegarts/distilgpt2-lotr'  # causal LM used to write the story
MIN_WORDS = 120  # NOTE(review): not referenced in the code shown here
# Access token forwarded to StableDiffusionPipeline.from_pretrained;
# None when HF_ACCESS_TOKEN is not set in the environment.
READ_TOKEN = os.environ.get('HF_ACCESS_TOKEN')
# Task-default summarization pipeline, constructed once at import time and
# shared by every request.
summarizer = pipeline(task="summarization")
@functools.lru_cache(maxsize=None)
def get_image_pipe():
    """Return the Stable Diffusion pipeline, loading it onto CUDA on first use.

    The result is cached, so the weights are read and moved to the GPU only
    once per process. Previously every request rebuilt the pipeline, which
    re-read the weights and allocated another copy on the GPU.

    Returns:
        The ``StableDiffusionPipeline`` for ``model_id`` in float16 on 'cuda'.
    """
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16,  # half precision to reduce GPU memory
        revision="fp16",  # presumably the repo's fp16-weights branch
        use_auth_token=READ_TOKEN,
    )
    pipe.to('cuda')
    return pipe
@functools.lru_cache(maxsize=None)
def get_model():
    """Return the story-generation ``(model, tokenizer)`` pair.

    Both are loaded from ``SAVED_CHECKPOINT`` on the first call and cached.
    Later calls reuse them instead of reloading the checkpoint on every
    generation request.

    Returns:
        A ``(model, tokenizer)`` tuple from ``AutoModelForCausalLM`` and
        ``AutoTokenizer``.
    """
    model = AutoModelForCausalLM.from_pretrained(SAVED_CHECKPOINT)
    tokenizer = AutoTokenizer.from_pretrained(SAVED_CHECKPOINT)
    return model, tokenizer
def generate(prompt):
    """Continue *prompt* with the story model and trim to the last full stop.

    Args:
        prompt: Text the generated story should start from.

    Returns:
        The prompt plus one sampled continuation (at most 100 tokens in
        total), cut after the last ``'.'``. If the text contains no ``'.'``,
        a period is appended to the whole text.
    """
    model, tokenizer = get_model()
    # Calling the tokenizer (instead of ``encode``) also produces the
    # attention mask, which is passed on to ``generate``.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_length=100,  # includes the prompt tokens
        temperature=0.7,
        # Only the first sequence is used below, so sample just one.
        num_return_sequences=1,
        do_sample=True,
    )
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Drop the trailing (possibly unfinished) sentence and re-terminate it.
    return text.rsplit('.', 1)[0] + '.'
def make_image(prompt):
    """Render an image for *prompt* with the Stable Diffusion pipeline.

    Args:
        prompt: Text description to illustrate.

    Returns:
        The first entry of the pipeline output's ``images`` (presumably a PIL
        image). The previous version computed this value but never returned
        it, so callers always received ``None``.
    """
    pipe = get_image_pipe()
    return pipe(prompt).images[0]
def predict(prompt):
    """Gradio handler: write a story, summarise it, and illustrate the summary.

    Args:
        prompt: Opening text entered by the user.

    Returns:
        A ``(story, summary, image)`` tuple, where ``summary`` is the plain
        summary text (not the summarizer's raw list-of-dicts output).
    """
    story = generate(prompt=prompt)
    # Summarise once; the same text prompts the image model and is shown to
    # the user. The old code ran the summarizer a second time and returned
    # its raw list output instead of this string.
    summary = summarizer(story, min_length=5, max_length=20)[0]['summary_text']
    image = make_image(summary)
    return story, summary, image
title = "Lord of the rings app"
description = """A Lord of the rings inspired app that combines text and image generation"""

# Bind the interface to a name so it can be referenced (e.g. by tooling that
# looks for a module-level ``demo``), then launch it as before.
demo = gr.Interface(
    fn=predict,
    inputs="textbox",
    # One output per element of predict()'s (story, summary, image) tuple.
    outputs=["text", "text", "image"],
    title=title,
    description=description,
    examples=[["My new adventure would be"], ["Then I a hobbit appeared"], ["Frodo told me"]],
)
demo.launch(share=True)