import os
import gradio as gr

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from diffusers import StableDiffusionPipeline

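# Optional Hugging Face access token; gated checkpoints such as Stable Diffusion v1-5 may require one.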
READ_TOKEN = os.environ.get('HF_ACCESS_TOKEN', None)

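# Stable Diffusion checkpoint used to illustrate the generated story.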
model_id = "runwayml/stable-diffusion-v1-5"
# model_id = "CompVis/stable-diffusion-v1-4"

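# Prefer the GPU with half-precision weights; fall back to full precision on CPU.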
has_cuda = torch.cuda.is_available()

if has_cuda:
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, revision="fp16", use_auth_token=READ_TOKEN)
    device = "cuda"
else:
    # float16 ops are not supported on CPU, so load the default full-precision weights.
    pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=READ_TOKEN)
    device = "cpu"
    
pipe.to(device)
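
# Disable the built-in NSFW filter so generations are never replaced by a black image.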
def safety_checker(images, clip_input):
    # Report every image as safe; the checker must return one flag per image.
    return images, [False] * len(images)
pipe.safety_checker = safety_checker

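# distilgpt2 fine-tuned on Lord of the Rings text, used to continue the user's prompt.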
SAVED_CHECKPOINT = 'mikegarts/distilgpt2-lotr'
model = AutoModelForCausalLM.from_pretrained(SAVED_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(SAVED_CHECKPOINT)

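# Default summarization pipeline; it condenses the generated story into a short image prompt.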
summarizer = pipeline("summarization")

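# Cut text back to its last complete sentence.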
def break_until_dot(txt):
    return txt.rsplit('.', 1)[0] + '.'

def generate(prompt):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    outputs = model.generate(
        input_ids=input_ids,
        max_length=180,
        temperature=0.7,
        num_return_sequences=1,  # only the first sequence is used, so generate just one
        do_sample=True,
    )
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Drop any trailing unfinished sentence.
    return break_until_dot(decoded)

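# A fixed style suffix keeps the illustrations visually consistent.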
def generate_image(prompt, inference_steps):
    prompt = prompt + ', masterpiece charcoal pencil art lord of the rings illustration'
    result = pipe(prompt, height=512, width=512, num_inference_steps=inference_steps)
    return result.images[0]

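# Handler for the "Generate text" button: write the story, summarize it, and reveal the image button.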
def generate_story(prompt):
    story = generate(prompt=prompt)
    summary = summarizer(story, min_length=5, max_length=15)[0]['summary_text']
    summary = break_until_dot(summary)
    return story, summary, gr.update(visible=True)

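# Gradio UI: enter a prompt, generate a story and summary, then optionally illustrate it.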
with gr.Blocks() as demo:
    title = gr.Markdown('## Lord of the Rings app')
    description = gr.Markdown('### A Lord of the Rings inspired app that combines text and image generation')
    prompt = gr.Textbox(label="Your prompt", value="Frodo took the sword and")
    story = gr.Textbox(label="Your story")
    summary = gr.Textbox(label="Summary")
    
    bt_make_text = gr.Button("Generate text")
    bt_make_image = gr.Button("Generate an image (takes about 10-15 minutes on CPU)", visible=False)
    
    image = gr.Image(label='Illustration for your story')
    inference_steps = gr.Slider(5, 30, value=15, step=1, label="Num inference steps (more steps make a better image but take more time)")
    
    bt_make_text.click(fn=generate_story, inputs=prompt, outputs=[story, summary, bt_make_image])
    bt_make_image.click(fn=generate_image, inputs=[summary, inference_steps], outputs=image)

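# Launch the app; without an access token, expose a public share link and debug logging.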
if READ_TOKEN:
    demo.launch()
else:
    demo.launch(share=True, debug=True)