mikegarts commited on
Commit
7ecd55b
1 Parent(s): 53bdc8a

refactor text generation code

Browse files
Files changed (1) hide show
  1. app.py +11 -20
app.py CHANGED
@@ -5,7 +5,8 @@ import gradio as gr
5
 
6
  import torch
7
  import transformers
8
- from transformers import AutoTokenizer, AutoModelForCausalLM
 
9
  from transformers import pipeline
10
  from diffusers import StableDiffusionPipeline
11
 
@@ -29,30 +30,20 @@ def safety_checker(images, clip_input):
29
  pipe.safety_checker = safety_checker
30
 
31
  SAVED_CHECKPOINT = 'mikegarts/distilgpt2-lotr'
32
- model = AutoModelForCausalLM.from_pretrained(SAVED_CHECKPOINT)
33
- tokenizer = AutoTokenizer.from_pretrained(SAVED_CHECKPOINT)
34
 
35
  summarizer = pipeline("summarization")
36
 
37
  #######################################################
38
 
 
 
39
  def break_until_dot(txt):
40
  return txt.rsplit('.', 1)[0] + '.'
41
 
42
  def generate(prompt):
43
- input_context = prompt
44
- input_ids = tokenizer.encode(input_context, return_tensors="pt").to(model.device)
45
-
46
- outputs = model.generate(
47
- input_ids=input_ids,
48
- max_length=120,
49
- min_length=50,
50
- temperature=0.7,
51
- num_return_sequences=3,
52
- do_sample=True
53
- )
54
- decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
55
- return break_until_dot(decoded)
56
 
57
 
58
  def generate_story(prompt):
@@ -110,8 +101,8 @@ with gr.Blocks() as demo:
110
  })
111
  title = gr.Markdown('## Lord of the rings app')
112
  description = gr.Markdown(f'#### A Lord of the rings inspired app that combines text and image generation.'
113
- f' The language modeling is done by fine tuning distilgpt2 on the LOTR trilogy.'
114
- f' The text2img model is {model_id}. The summarization is done using distilbart.')
115
  prompt = gr.Textbox(label="Your prompt", value="Frodo took the sword and")
116
  story = gr.Textbox(label="Your story")
117
  summary = gr.Textbox(label="Summary")
@@ -122,7 +113,7 @@ with gr.Blocks() as demo:
122
  img_description = gr.Markdown('Image generation takes some time'
123
  ' but here you can see what is generated from the latent state of the diffuser every few steps.'
124
  ' Usually there is a significant improvement around step 12 that yields a much better image')
125
- status_msg = gr.Markdown()
126
 
127
  gallery = gr.Gallery()
128
  image = gr.Image(label='Illustration for your story', show_label=True)
@@ -147,4 +138,4 @@ with gr.Blocks() as demo:
147
  if READ_TOKEN:
148
  demo.queue().launch()
149
  else:
150
- demo.queue().launch(share=True, debug=True)
 
5
 
6
  import torch
7
  import transformers
8
+ # from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ from transformers import pipeline
10
  from transformers import pipeline
11
  from diffusers import StableDiffusionPipeline
12
 
 
30
  pipe.safety_checker = safety_checker
31
 
32
  SAVED_CHECKPOINT = 'mikegarts/distilgpt2-lotr'
33
+ text_generation_pipe = pipeline("text-generation", model=SAVED_CHECKPOINT)
 
34
 
35
  summarizer = pipeline("summarization")
36
 
37
  #######################################################
38
 
39
+ #######################################################
40
+
41
  def break_until_dot(txt):
42
  return txt.rsplit('.', 1)[0] + '.'
43
 
44
  def generate(prompt):
45
+ generated = text_generation_pipe(prompt, max_length=140)[0]['generated_text']
46
+ return break_until_dot(generated)
 
 
 
 
 
 
 
 
 
 
 
47
 
48
 
49
  def generate_story(prompt):
 
101
  })
102
  title = gr.Markdown('## Lord of the rings app')
103
  description = gr.Markdown(f'#### A Lord of the rings inspired app that combines text and image generation.'
104
+ f' The language modeling is done by fine tuning distilgpt2 on the LOTR trilogy ({SAVED_CHECKPOINT}).'
105
+ f' The text2img model is {model_id}. The summarization is done using the default distilbart.')
106
  prompt = gr.Textbox(label="Your prompt", value="Frodo took the sword and")
107
  story = gr.Textbox(label="Your story")
108
  summary = gr.Textbox(label="Summary")
 
113
  img_description = gr.Markdown('Image generation takes some time'
114
  ' but here you can see what is generated from the latent state of the diffuser every few steps.'
115
  ' Usually there is a significant improvement around step 12 that yields a much better image')
116
+ status_msg = gr.Markdown('')
117
 
118
  gallery = gr.Gallery()
119
  image = gr.Image(label='Illustration for your story', show_label=True)
 
138
  if READ_TOKEN:
139
  demo.queue().launch()
140
  else:
141
+ demo.queue().launch(share=True, debug=True)