File size: 5,783 Bytes
d523bde
34d353a
d523bde
05d7e61
 
 
 
7ecd55b
 
05d7e61
 
 
79c9a48
 
3e4033b
 
 
 
34d353a
3e4033b
79c9a48
3e4033b
 
d523bde
3e4033b
 
 
 
 
 
05d7e61
3e4033b
7ecd55b
05d7e61
3e4033b
05d7e61
d523bde
 
7ecd55b
 
3e4033b
 
05d7e61
3e4033b
7ecd55b
 
05d7e61
 
3e4033b
05d7e61
3e4033b
 
 
 
d523bde
 
 
 
 
 
 
 
 
 
 
 
 
 
3e4033b
d523bde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53bdc8a
d523bde
 
 
 
 
 
 
 
 
 
 
 
 
3e4033b
af893f9
7ecd55b
 
7e9b8ea
3e4033b
 
 
 
815cc40
3e4033b
d523bde
 
 
7ecd55b
d523bde
 
 
 
 
 
 
 
 
 
 
 
3e4033b
 
d523bde
 
 
 
 
 
3e4033b
34d353a
d523bde
34d353a
7ecd55b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import time
import os
import PIL
import gradio as gr

import torch
import transformers
# from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline
from transformers import pipeline
from diffusers import StableDiffusionPipeline

# Hugging Face access token; None when HF_ACCESS_TOKEN is not set.
READ_TOKEN = os.environ.get('HF_ACCESS_TOKEN')

model_id = "runwayml/stable-diffusion-v1-5"
# model_id = "CompVis/stable-diffusion-v1-4"

has_cuda = torch.cuda.is_available()
device = "cuda" if has_cuda else "cpu"

# The fp16 weights/revision are requested only when a GPU is available;
# on CPU the pipeline is loaded with its default arguments.
_load_kwargs = {"use_auth_token": READ_TOKEN}
if has_cuda:
    _load_kwargs.update(torch_dtype=torch.float16, revision="fp16")

pipe = StableDiffusionPipeline.from_pretrained(model_id, **_load_kwargs)
pipe.to(device)
def safety_checker(images, clip_input):
    """Pass-through replacement for the pipeline's safety checker.

    Returns ``images`` untouched together with ``False`` as the flag;
    ``clip_input`` is accepted but ignored.

    NOTE(review): once assigned to ``pipe.safety_checker`` this bypasses
    content filtering of generated images. Confirm this is intended, and
    that the installed diffusers version accepts a bare ``False`` flag.
    """
    flagged = False
    return images, flagged
# Swap the pipeline's safety checker for the pass-through function above.
pipe.safety_checker = safety_checker

# Checkpoint used for story generation.
SAVED_CHECKPOINT = 'mikegarts/distilgpt2-lotr'
text_generation_pipe = pipeline(task="text-generation", model=SAVED_CHECKPOINT)

# No model is named here, so transformers selects its default for the task.
summarizer = pipeline(task="summarization")

#######################################################

#######################################################

def break_until_dot(txt):
    """Cut ``txt`` after its last full sentence.

    Everything following the final ``'.'`` is dropped and a single ``'.'``
    is appended. When ``txt`` contains no dot at all, the whole text is kept
    and a ``'.'`` is appended to it.
    """
    head, dot, _tail = txt.rpartition('.')
    # rpartition yields an empty separator when no dot was found.
    kept = head if dot else txt
    return kept + '.'

def generate(prompt):
    """Continue ``prompt`` with the text-generation pipeline.

    The first generated candidate (``max_length=140``) is taken and trimmed
    with ``break_until_dot`` so that it ends on a full stop.
    """
    candidates = text_generation_pipe(prompt, max_length=140)
    first_candidate = candidates[0]
    return break_until_dot(first_candidate['generated_text'])


def generate_story(prompt):
    """Produce a story and a short summary of it for the given prompt.

    Returns a 3-tuple: the generated story, its summary (trimmed to end on a
    full stop), and a ``gr.update(visible=True)`` for the third output
    component wired to this handler.
    """
    story = generate(prompt=prompt)
    summaries = summarizer(story, min_length=5, max_length=15)
    summary = break_until_dot(summaries[0]['summary_text'])
    reveal = gr.update(visible=True)
    return story, summary, reveal

def on_change_event(app_state):
    """Build the gallery and status-message updates for a UI refresh.

    Args:
        app_state: State dict with the keys ``running``, ``img``,
            ``img_list``, ``status_msg`` and ``dots``. ``dots`` is advanced
            in place on every refresh that happens while a generation runs.

    Returns:
        A pair of ``gr.update`` objects: the first for the gallery, the
        second for the status message.
    """
    print(f'on_change_event {app_state}')

    # Nothing to show unless a generation is running and has already
    # produced at least one intermediate image.
    if not (app_state and app_state['running'] and app_state['img']):
        return gr.update(label='images list'), gr.update(value='')

    print(f'Updating the image:! {app_state}')
    # Cycle 0..9 so the trailing " *" markers grow and then start over.
    app_state['dots'] = (app_state['dots'] + 1) % 10
    message = app_state['status_msg'] + ' *' * app_state['dots']
    print (f'message={message}')
    return gr.update(value=app_state['img_list'], label='intermediate steps'), gr.update(value=message)

with gr.Blocks() as demo:
    
    def generate_image(prompt, inference_steps, app_state):
        """Generate an illustration for ``prompt`` with the diffusion pipeline.

        While the pipeline runs, a callback (invoked with ``callback_steps=2``)
        decodes the current latents into a preview image and records it in
        ``app_state`` so that ``on_change_event`` can display progress.

        Args:
            prompt: Text to illustrate; a fixed style suffix is appended.
            inference_steps: Number of inference steps for the pipeline.
            app_state: State dict (keys ``running``, ``img``, ``step``,
                ``img_list``, ``status_msg``, ``dots``), mutated in place.

        Returns:
            A ``gr.update`` carrying the first generated image and its label.

        Raises:
            Whatever the pipeline raises; the progress state is reset first.
        """
        app_state['running'] = True
        app_state['img_list'] = []
        app_state['status_msg'] = 'Starting'

        def callback(step, ts, latents):
            """Decode the intermediate latents and publish them as a preview."""
            app_state['status_msg'] = f'Reconstructing an image from the latent state on step {step}'
            # Undo the latent scaling before decoding
            # (presumably the SD v1 VAE scaling factor -- TODO confirm).
            latents = 1 / 0.18215 * latents
            res = pipe.vae.decode(latents).sample
            # Map the decoder output from [-1, 1] to [0, 1].
            res = (res / 2 + 0.5).clamp(0, 1)
            res = res.cpu().permute(0, 2, 3, 1).detach().numpy()
            res = pipe.numpy_to_pil(res)[0]
            app_state['img'] = res
            app_state['step'] = step
            app_state['img_list'].append(res)
            app_state['status_msg'] = f'Generating step ({step + 1})'

        prompt = prompt + ' masterpiece charcoal pencil art lord of the rings illustration'
        try:
            # int(): guard against the slider delivering a float step count.
            img = pipe(prompt, height=512, width=512, num_inference_steps=int(inference_steps), callback=callback, callback_steps=2)
        finally:
            # Always clear the progress state, even when the pipeline fails,
            # so the UI does not stay stuck in the "running" branch.
            app_state['running'] = False
            app_state['img'] = None
            app_state['status_msg'] = ''
            app_state['dots'] = 0
        return gr.update(value=img.images[0], label='Generated image')
    
    # State dict shared by the handlers below: generate_image mutates it
    # while running and on_change_event reads it to render progress.
    app_state = gr.State({'img': None, 
                          'step':0, 
                          'running':False,
                          'status_msg': '',
                          'img_list': [],
                          'dots': 0
                         })
    # Static header and description; the description names the checkpoint
    # and diffusion model configured at module level.
    title = gr.Markdown('## Lord of the rings app')
    description = gr.Markdown(f'#### A Lord of the rings inspired app that combines text and image generation.'
                              f' The language modeling is done by fine tuning distilgpt2 on the LOTR trilogy ({SAVED_CHECKPOINT}).'
                              f' The text2img model is {model_id}. The summarization is done using the default distilbart.')
    # Text input plus the two text outputs filled by generate_story.
    prompt = gr.Textbox(label="Your prompt", value="Frodo took the sword and")
    story = gr.Textbox(label="Your story")
    summary = gr.Textbox(label="Summary")
    
    bt_make_text = gr.Button("Generate text")
    # Created hidden; generate_story returns gr.update(visible=True) for it.
    bt_make_image = gr.Button(f"Generate an image (takes about 10-15 minutes on CPU).", visible=False)
    
    img_description = gr.Markdown('Image generation takes some time'
                                  ' but here you can see what is generated from the latent state of the diffuser every few steps.'
                                  ' Usually there is a significant improvement around step 12 that yields a much better image')
    # Progress line; its value is supplied by on_change_event.
    status_msg = gr.Markdown('')
    
    # Gallery receives the intermediate previews (via on_change_event);
    # image receives the final result of generate_image.
    gallery = gr.Gallery()
    image = gr.Image(label='Illustration for your story', show_label=True)

    # NOTE(review): presumably lays the gallery out four items per row;
    # `.style()` availability depends on the installed Gradio version.
    gallery.style(grid=[4])
    
    # Step count handed to generate_image as `inference_steps`.
    inference_steps = gr.Slider(5, 30, 
                                value=20, 
                                step=1, 
                                visible=True,
                                label=f"Num inference steps (more steps yields a better image but takes more time)")
    
    
    # Button wiring: text generation fills story/summary and updates the
    # image button; image generation illustrates the *summary*, not the story.
    bt_make_text.click(fn=generate_story, inputs=prompt, outputs=[story, summary, bt_make_image])
    bt_make_image.click(fn=generate_image, inputs=[summary, inference_steps, app_state], outputs=image)
    
    # Periodic refresh: on_change_event is registered on demo.load with
    # every=5 to keep the gallery and status line updated.
    # NOTE(review): nothing visible in this file changes the hidden slider's
    # value, so its change handler (which would cancel `dep`) may never
    # fire -- confirm whether it is still needed.
    eventslider = gr.Slider(visible=False)
    dep = demo.load(on_change_event, app_state, [gallery, status_msg], every=5)
    eventslider.change(fn=on_change_event, inputs=[app_state], outputs=[gallery, status_msg], every=5, cancels=[dep])
    

# With an access token configured the app launches with default options;
# without one it launches with a share link and debug mode enabled.
_launch_options = {} if READ_TOKEN else {"share": True, "debug": True}
demo.queue().launch(**_launch_options)