krishnv committed on
Commit
947d2f8
1 Parent(s): 38284f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -1,25 +1,28 @@
1
  import torch
2
  import gradio as gr
3
- from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
4
 
5
  device = 'cpu'
6
  encoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
7
  decoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
8
  model_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
9
- feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
 
 
10
  tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
11
  model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
12
 
13
  def predict(image, max_length=64, num_beams=4):
14
  image = image.convert('RGB')
15
- image = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
16
  clean_text = lambda x: x.replace('', '').split('\n')[0]
17
  caption_ids = model.generate(image, max_length=max_length, num_beams=num_beams)[0]
18
  caption_text = clean_text(tokenizer.decode(caption_ids, skip_special_tokens=True))
19
  return caption_text
20
 
21
- input_image = gr.inputs.Image(label="Upload your Image", type='pil', optional=True)
22
- output_text = gr.outputs.Textbox(type="text", label="Captions")
 
23
 
24
  examples = [f"example{i}.jpg" for i in range(1, 7)]
25
 
 
1
  import torch
2
  import gradio as gr
3
+ from transformers import AutoTokenizer, ViTImageProcessor, VisionEncoderDecoderModel
4
 
5
  device = 'cpu'
6
  encoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
7
  decoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
8
  model_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
9
+
10
+ # Replace ViTFeatureExtractor with ViTImageProcessor
11
+ feature_extractor = ViTImageProcessor.from_pretrained(encoder_checkpoint)
12
  tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
13
  model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
14
 
15
  def predict(image, max_length=64, num_beams=4):
16
  image = image.convert('RGB')
17
+ image = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
18
  clean_text = lambda x: x.replace('', '').split('\n')[0]
19
  caption_ids = model.generate(image, max_length=max_length, num_beams=num_beams)[0]
20
  caption_text = clean_text(tokenizer.decode(caption_ids, skip_special_tokens=True))
21
  return caption_text
22
 
23
# Components come from the top-level Gradio namespace (gr.Image / gr.Textbox)
# rather than the legacy gr.inputs / gr.outputs modules.
# The legacy `optional` keyword is intentionally not passed: it is deprecated
# on the top-level component and newer Gradio releases reject it.
input_image = gr.Image(label="Upload your Image", type='pil')
output_text = gr.Textbox(label="Captions")

# Sample image file names: example1.jpg through example6.jpg.
examples = [f"example{i}.jpg" for i in range(1, 7)]
28