mikegarts commited on
Commit
05d7e61
1 Parent(s): bcede34

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ import torch
4
+ import transformers
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+ from transformers import pipeline
7
+ from diffusers import StableDiffusionPipeline
8
+
9
+ summarizer = pipeline("summarization")
10
+ model_id = "runwayml/stable-diffusion-v1-5"
11
+
12
+ SAVED_CHECKPOINT = 'mikegarts/distilgpt2-lotr'
13
+ MIN_WORDS = 120
14
+
15
+ def get_image_pipe():
16
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, revision="fp16")
17
+ pipe.to(pipe.device)
18
+ return pipe
19
+
20
+ def get_model():
21
+ model = AutoModelForCausalLM.from_pretrained(SAVED_CHECKPOINT)
22
+ tokenizer = AutoTokenizer.from_pretrained(SAVED_CHECKPOINT)
23
+ return model, tokenizer
24
+
25
+
26
+ def generate(prompt):
27
+ model, tokenizer = get_model()
28
+
29
+ input_context = prompt
30
+ input_ids = tokenizer.encode(input_context, return_tensors="pt").to(model.device)
31
+
32
+ outputs = model.generate(
33
+ input_ids=input_ids,
34
+ max_length=100,
35
+ temperature=0.7,
36
+ num_return_sequences=3,
37
+ do_sample=True
38
+ )
39
+
40
+ return tokenizer.decode(outputs[0], skip_special_tokens=True).rsplit('.', 1)[0] + '.'
41
+
42
+ def make_image(prompt):
43
+ pipe = get_image_pipe()
44
+ image = pipe(prompt).images[0]
45
+
46
+ def predict(prompt):
47
+ story = generate(prompt=prompt)
48
+ summary = summarizer(story, min_length=5, max_length=20)[0]['summary_text']
49
+ image = make_image(summary)
50
+ return story, summarizer(story, min_length=5, max_length=20), image
51
+
52
+
53
+ title = "Lord of the rings app"
54
+ description = """A Lord of the rings insired app that combines text and image generation"""
55
+
56
+ gr.Interface(
57
+ fn=predict,
58
+ inputs="textbox",
59
+ outputs=["text", "text", "image"],
60
+ title=title,
61
+ description=description,
62
+ examples=[["My new adventure would be"], ["Then I a hobbit appeared"], ["Frodo told me"]]
63
+ ).launch(share=True)