mikegarts committed
Commit: 3e4033b
Parent: 4efa0ed

Rewrite the app as Gradio Blocks and use image generation

Files changed (1):
  app.py +51 -38
app.py CHANGED

@@ -7,60 +7,73 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 from transformers import pipeline
 from diffusers import StableDiffusionPipeline

-summarizer = pipeline("summarization")
-model_id = "runwayml/stable-diffusion-v1-5"
-
-SAVED_CHECKPOINT = 'mikegarts/distilgpt2-lotr'
-MIN_WORDS = 120
-
 READ_TOKEN = os.environ.get('HF_ACCESS_TOKEN', None)

-def get_image_pipe():
+model_id = "runwayml/stable-diffusion-v1-5"
+# model_id = "CompVis/stable-diffusion-v1-4"
+
+has_cuda = torch.cuda.is_available()
+device = "cpu"
+if has_cuda:
     pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, revision="fp16", use_auth_token=READ_TOKEN)
-    pipe.to('cuda')
-    return pipe
+    device = "cuda"
+else:
+    pipe = StableDiffusionPipeline.from_pretrained(model_id, revision="fp16", use_auth_token=READ_TOKEN)
+    device = "cpu"
+
+pipe.to(device)
+def safety_checker(images, clip_input):
+    return images, False
+pipe.safety_checker = safety_checker

-def get_model():
-    model = AutoModelForCausalLM.from_pretrained(SAVED_CHECKPOINT)
-    tokenizer = AutoTokenizer.from_pretrained(SAVED_CHECKPOINT)
-    return model, tokenizer
-
+SAVED_CHECKPOINT = 'mikegarts/distilgpt2-lotr'
+model = AutoModelForCausalLM.from_pretrained(SAVED_CHECKPOINT)
+tokenizer = AutoTokenizer.from_pretrained(SAVED_CHECKPOINT)

-def generate(prompt):
-    model, tokenizer = get_model()
-
+summarizer = pipeline("summarization")
+
+def break_until_dot(txt):
+    return txt.rsplit('.', 1)[0] + '.'
+
+def generate(prompt):
     input_context = prompt
     input_ids = tokenizer.encode(input_context, return_tensors="pt").to(model.device)

     outputs = model.generate(
         input_ids=input_ids,
-        max_length=100,
+        max_length=180,
         temperature=0.7,
         num_return_sequences=3,
         do_sample=True
     )
-
-    return tokenizer.decode(outputs[0], skip_special_tokens=True).rsplit('.', 1)[0] + '.'
-
-def make_image(prompt):
-    pipe = get_image_pipe()
-    image = pipe(prompt)["sample"][0]
+    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return break_until_dot(decoded)

-def predict(prompt):
+def generate_image(prompt, inference_steps):
+    prompt = prompt + ', masterpiece charcoal pencil art lord of the rings illustration'
+    img = pipe(prompt, height=512, width=512, num_inference_steps=inference_steps)
+    return img.images[0]
+
+def generate_story(prompt):
     story = generate(prompt=prompt)
-    summary = summarizer(story, min_length=5, max_length=20)[0]['summary_text']
-    image = make_image(summary)
-    return story, summarizer(story, min_length=5, max_length=20), image
-
-
-title = "Lord of the rings app"
-description = """A Lord of the rings insired app that combines text and image generation"""
-
-gr.Interface(
-    fn=predict,
-    inputs="textbox",
-    outputs=["text", "text", "image"],
-    title=title,
-    description=description,
-    examples=[["My new adventure would be"], ["Then I a hobbit appeared"], ["Frodo told me"]]
-).launch(debug=True)
+    summary = summarizer(story, min_length=5, max_length=15)[0]['summary_text']
+    summary = break_until_dot(summary)
+    return story, summary, gr.update(visible=True)
+
+with gr.Blocks() as demo:
+    title = gr.Markdown('## Lord of the Rings app')
+    description = gr.Markdown('### A Lord of the Rings inspired app that combines text and image generation')
+    prompt = gr.Textbox(label="Your prompt", value="And then the hobbit said")
+    story = gr.Textbox(label="Your story")
+    summary = gr.Textbox(label="Summary")
+
+    bt_make_text = gr.Button("Generate text")
+    bt_make_image = gr.Button("Generate an image (takes about 10-15 minutes on CPU)", visible=False)
+
+    image = gr.Image(label='Illustration for your story')
+    inference_steps = gr.Slider(5, 35, value=15, step=1, label="Num inference steps (more steps make a better image but take more time)")
+
+    bt_make_text.click(fn=generate_story, inputs=prompt, outputs=[story, summary, bt_make_image])
+    bt_make_image.click(fn=generate_image, inputs=[summary, inference_steps], outputs=image)
+
+demo.launch(share=True)
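One note on the `safety_checker` override in the new code: `StableDiffusionPipeline` invokes it as `safety_checker(images=..., clip_input=...)` and expects the images back along with per-image NSFW flags. Returning a bare `False` works in the diffusers releases this Space was written against, but a list with one flag per image matches the expected shape more closely. A hedged drop-in variant, assuming `pipe` is the pipeline constructed above:

```python
# Sketch of a shape-safe override; assumes `pipe` is the
# StableDiffusionPipeline loaded earlier in app.py.
def safety_checker(images, clip_input):
    # One "not NSFW" flag per generated image, instead of a single bool.
    return images, [False] * len(images)

pipe.safety_checker = safety_checker
```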
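Similarly, `pipeline("summarization")` without a model argument falls back to the library's default checkpoint, which can change between transformers releases; pinning it keeps the Space reproducible. A sketch, assuming `sshleifer/distilbart-cnn-12-6` (the usual default around the time of this commit):

```python
from transformers import pipeline

# Pin the summarization checkpoint rather than relying on the default.
summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")

text = "Frodo left the Shire at dawn and carried the Ring toward Mordor with Sam."
print(summarizer(text, min_length=5, max_length=15)[0]["summary_text"])
```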
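The Blocks rewrite gates the slow image step behind a hidden button: `generate_story` returns `gr.update(visible=True)` as its third output, and because `bt_make_image` is listed as the third output of the click handler, Gradio applies that update to the button. A minimal, self-contained sketch of the same pattern (component names here are illustrative, not from the app):

```python
import gradio as gr

def make_text(prompt):
    # Each return value maps to the matching output component;
    # the update dict reveals the hidden button.
    return f"Once upon a time, {prompt}", gr.update(visible=True)

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    story = gr.Textbox(label="Story")
    bt_text = gr.Button("Generate text")
    bt_next = gr.Button("Next step", visible=False)  # hidden until text exists
    bt_text.click(fn=make_text, inputs=prompt, outputs=[story, bt_next])

demo.launch()
```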