krishnv commited on
Commit
6bb6d88
1 Parent(s): 2e70508

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -26
app.py CHANGED
@@ -1,6 +1,5 @@
1
  #From
2
  import torch
3
- import re
4
  import gradio as gr
5
  from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
6
 
@@ -12,34 +11,34 @@ feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
12
  tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
13
  model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
14
 
 
 
 
 
 
 
 
15
 
16
- def predict(image,max_length=64, num_beams=4):
17
- image = image.convert('RGB')
18
- image = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
19
- clean_text = lambda x: x.replace('<|endoftext|>','').split('\n')[0]
20
- caption_ids = model.generate(image, max_length = max_length)[0]
21
- caption_text = clean_text(tokenizer.decode(caption_ids))
22
- return caption_text
23
 
 
24
 
25
-
26
- input = gr.inputs.Image(label="Upload your Image", type = 'pil', optional=True)
27
- output = gr.outputs.Textbox(type="auto",label="Captions")
28
- examples = [f"example{i}.jpg" for i in range(1,7)]
29
-
30
- description= "Image captioning application made using transformers"
31
  title = "Image Captioning 🖼️"
 
32
 
33
- article = "Created By : Shreyas Dixit "
34
-
35
  interface = gr.Interface(
36
- fn=predict,
37
- inputs = input,
38
- theme="grass",
39
- outputs=output,
40
- examples = examples,
41
- title=title,
42
- description=description,
43
- article = article,
44
- )
45
- interface.launch(debug=True)
 
 
 
1
  #From
2
  import torch
 
3
  import gradio as gr
4
  from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
5
 
 
11
  tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
12
  model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
13
 
14
+ def predict(image, max_length=64, num_beams=4):
15
+ image = image.convert('RGB')
16
+ image = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
17
+ clean_text = lambda x: x.replace('','').split('\n')[0]
18
+ caption_ids = model.generate(image, max_length=max_length, num_beams=num_beams)[0]
19
+ caption_text = clean_text(tokenizer.decode(caption_ids, skip_special_tokens=True))
20
+ return caption_text
21
 
22
+ input_image = gr.inputs.Image(label="Upload your Image", type='pil', optional=True)
23
+ output_text = gr.outputs.Textbox(type="text", label="Captions")
 
 
 
 
 
24
 
25
+ examples = [f"example{i}.jpg" for i in range(1, 7)]
26
 
27
+ description = "Image captioning application made using transformers"
 
 
 
 
 
28
  title = "Image Captioning 🖼️"
29
+ article = "Created By : Shreyas Dixit"
30
 
31
+ # Create the Gradio interface
 
32
  interface = gr.Interface(
33
+ fn=predict,
34
+ inputs=input_image,
35
+ outputs=output_text,
36
+ examples=examples,
37
+ title=title,
38
+ description=description,
39
+ article=article,
40
+ theme="grass"
41
+ )
42
+
43
+ # Launch the interface
44
+ interface.launch(debug=True,share=True)