upgrade to gradio blocks

#1
by abd-meda - opened
Files changed (1) hide show
  1. app.py +31 -27
app.py CHANGED
@@ -1,10 +1,10 @@
1
- import torch
2
  import re
 
 
3
  import gradio as gr
4
- from pathlib import Path
5
  from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel
6
 
7
-
8
  # Pattern to ignore all the text after 2 or more full stops
9
  regex_pattern = "[.]{2,}"
10
 
@@ -19,6 +19,10 @@ def post_process(text):
19
  return text
20
 
21
 
 
 
 
 
22
  def predict(image, max_length=64, num_beams=4):
23
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
24
  pixel_values = pixel_values.to(device)
@@ -52,29 +56,29 @@ print("Loaded feature_extractor")
52
  tokenizer = AutoTokenizer.from_pretrained(model.decoder.name_or_path, use_fast=True)
53
  if model.decoder.name_or_path == "gpt2":
54
  tokenizer.pad_token = tokenizer.eos_token
55
-
56
  print("Loaded tokenizer")
57
 
58
- title = "Poster2Plot: Upload a Movie/T.V show poster to generate a plot"
59
- description = ""
60
-
61
- input = gr.inputs.Image(type="pil")
62
-
63
- example_images = sorted(
64
- [f.as_posix() for f in Path("examples").glob("*.jpg")]
65
- )
66
- print(f"Loaded {len(example_images)} example images")
67
-
68
- interface = gr.Interface(
69
- fn=predict,
70
- inputs=input,
71
- outputs="textbox",
72
- title=title,
73
- description=description,
74
- examples=example_images,
75
- examples_per_page=20,
76
- live=True,
77
- article='<p>Made by: <a href="https://twitter.com/kartik_godawat" target="_blank" rel="noopener noreferrer">dk-crazydiv</a> and <a href="https://twitter.com/dsr_ai" target="_blank" rel="noopener noreferrer">dsr</a></p>'
78
- )
79
-
80
- interface.launch()
 
 
1
+ import os
2
  import re
3
+
4
+ import torch
5
  import gradio as gr
 
6
  from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel
7
 
 
8
  # Pattern to ignore all the text after 2 or more full stops
9
  regex_pattern = "[.]{2,}"
10
 
 
19
  return text
20
 
21
 
22
+ def set_example_image(example: list) -> dict:
23
+ return gr.Image.update(value=example[0])
24
+
25
+
26
  def predict(image, max_length=64, num_beams=4):
27
  pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
28
  pixel_values = pixel_values.to(device)
 
56
  tokenizer = AutoTokenizer.from_pretrained(model.decoder.name_or_path, use_fast=True)
57
  if model.decoder.name_or_path == "gpt2":
58
  tokenizer.pad_token = tokenizer.eos_token
 
59
  print("Loaded tokenizer")
60
 
61
+ examples = [[f"examples/{filename}"] for filename in next(os.walk('examples'), (None, None, []))[2]]
62
+ print(f"Loaded {len(examples)} example images")
63
+
64
+ with gr.Blocks(css="#title { margin: 0 auto; padding: 25px 25px 25px 25px }") as poster2plot:
65
+ with gr.Column():
66
+ with gr.Row():
67
+ gr.Markdown("# Poster2Plot: Upload a Movie/T.V show poster to generate a plot", elem_id='title')
68
+ with gr.Row():
69
+ with gr.Column():
70
+ with gr.Row():
71
+ input_image = gr.Image(label='Input Image', type='numpy')
72
+ with gr.Row():
73
+ submit_button = gr.Button(value="Submit", variant='primary')
74
+ with gr.Column():
75
+ plot = gr.Textbox(label="Plot")
76
+ with gr.Row():
77
+ example_images = gr.Dataset(components=[input_image], samples=examples)
78
+ with gr.Row():
79
+ gr.Markdown("Made by: [dk-crazydiv](https://twitter.com/kartik_godawat) and [dsr](https://twitter.com/dsr_ai)")
80
+
81
+ submit_button.click(fn=predict, inputs=[input_image], outputs=[plot])
82
+ example_images.click(fn=set_example_image, inputs=[example_images], outputs=example_images.components)
83
+
84
+ poster2plot.launch()