|
import os |
|
import gradio as gr |
|
|
|
import torch |
|
import transformers |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
from transformers import pipeline |
|
from diffusers import StableDiffusionPipeline |
|
|
|
# Hugging Face access token, read from the environment (None when the
# HF_ACCESS_TOKEN variable is not set).
READ_TOKEN = os.environ.get('HF_ACCESS_TOKEN')

# Checkpoint identifiers for the story generator and the image generator.
SAVED_CHECKPOINT = "mikegarts/distilgpt2-lotr"
model_id = "runwayml/stable-diffusion-v1-5"

MIN_WORDS = 120

# Summarization pipeline, created once at import time; no explicit model is
# requested, so the library picks its default for the task.
summarizer = pipeline(task="summarization")
|
|
|
# Stable Diffusion pipeline shared by all calls; populated on first use.
_IMAGE_PIPE = None


def get_image_pipe():
    """Return the Stable Diffusion pipeline, loading it on first use.

    The "fp16" revision of ``model_id`` is loaded in half precision, using
    ``READ_TOKEN`` as the auth token, and moved to the CUDA device. The loaded
    pipeline is kept in a module-level cache so later calls reuse it instead
    of reloading the weights on every request.

    Returns:
        The cached ``StableDiffusionPipeline`` instance.
    """
    global _IMAGE_PIPE
    if _IMAGE_PIPE is None:
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            revision="fp16",
            use_auth_token=READ_TOKEN,
        )
        pipe.to('cuda')
        # Publish only after the move to the GPU succeeded, so a failed load
        # is retried on the next call rather than leaving a half-ready cache.
        _IMAGE_PIPE = pipe
    return _IMAGE_PIPE
|
|
|
# (model, tokenizer) pair shared by all calls; populated on first use.
_MODEL_AND_TOKENIZER = None


def get_model():
    """Return the text-generation model and its tokenizer.

    Both are loaded from ``SAVED_CHECKPOINT`` the first time this function is
    called and cached afterwards, so repeated calls do not reload the
    checkpoint.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    global _MODEL_AND_TOKENIZER
    if _MODEL_AND_TOKENIZER is None:
        model = AutoModelForCausalLM.from_pretrained(SAVED_CHECKPOINT)
        tokenizer = AutoTokenizer.from_pretrained(SAVED_CHECKPOINT)
        _MODEL_AND_TOKENIZER = (model, tokenizer)
    return _MODEL_AND_TOKENIZER
|
|
|
|
|
def generate(prompt):
    """Continue ``prompt`` with sampled text from the story model.

    Args:
        prompt: Text the generated story should start with.

    Returns:
        The decoded text of the sampled sequence, cut at its last period so
        it does not end mid-sentence. If the text contains no period, the
        whole text is kept and a period is appended.
    """
    model, tokenizer = get_model()

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    outputs = model.generate(
        input_ids=input_ids,
        max_length=100,
        temperature=0.7,
        # Only the first sequence is ever used, so sampling more than one
        # would just multiply the generation cost.
        num_return_sequences=1,
        do_sample=True,
    )

    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Drop whatever follows the last period and restore the period itself.
    return text.rsplit('.', 1)[0] + '.'
|
|
|
def make_image(prompt):
    """Generate one image for ``prompt`` with the image pipeline.

    Args:
        prompt: Text passed to the pipeline returned by ``get_image_pipe``.

    Returns:
        The first entry of the pipeline result's ``images``.
    """
    pipe = get_image_pipe()
    # The image must be returned: callers forward it to the UI.
    return pipe(prompt).images[0]
|
|
|
def predict(prompt):
    """Produce a story, its short summary and an illustration for ``prompt``.

    Args:
        prompt: Opening text for the story.

    Returns:
        A ``(story, summary, image)`` tuple: the text from ``generate``, the
        ``summary_text`` of the first summarizer result for that story, and
        whatever ``make_image`` returns for the summary.
    """
    story = generate(prompt=prompt)
    summary = summarizer(story, min_length=5, max_length=20)[0]['summary_text']
    image = make_image(summary)
    # Return the summary string already extracted above, rather than running
    # the summarizer a second time and handing its raw list-of-dicts result
    # to a text output.
    return story, summary, image
|
|
|
|
|
# Heading and blurb displayed on the Gradio page.
title = "Lord of the rings app"
description = "A Lord of the rings insired app that combines text and image generation"

# One textbox in; story text, summary text and image out.
demo = gr.Interface(
    fn=predict,
    inputs="textbox",
    outputs=["text", "text", "image"],
    title=title,
    description=description,
    # Each example is a one-element row matching the single input field.
    examples=[
        [example_prompt]
        for example_prompt in (
            "My new adventure would be",
            "Then I a hobbit appeared",
            "Frodo told me",
        )
    ],
)
demo.launch(share=True)