ChandraP12330 committed on
Commit
c8b7141
1 Parent(s): cb5683a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -10
app.py CHANGED
@@ -1,24 +1,70 @@
1
  import streamlit as st
2
- from transformers import BlipForConditionalGeneration, BlipProcessor
 
3
 
4
  # Load the BLIP model and processor
5
- model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
6
- processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  def generate_caption(image):
9
  # Preprocess the image
10
- pixel_values = processor(images=image, return_tensors="pt").pixel_values
11
 
12
  # Generate caption using the BLIP model
13
- output_ids = model.generate(pixel_values, max_length=50, num_beams=4, early_stopping=True)
14
 
15
  # Decode the caption
16
- caption = processor.decode(output_ids[0], skip_special_tokens=True)
17
 
18
  return caption
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def main():
21
- st.title("Image Caption Generator")
22
 
23
  # Upload image
24
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
@@ -27,11 +73,13 @@ def main():
27
  # Display the uploaded image
28
  image = st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
29
 
30
- # Generate caption
31
- if st.button("Generate Caption"):
32
- with st.spinner("Generating caption..."):
33
  caption = generate_caption(uploaded_file.getvalue())
 
34
  st.success(f"Caption: {caption}")
 
35
 
36
  if __name__ == "__main__":
37
  main()
 
1
  import streamlit as st
2
+ from transformers import BlipForConditionalGeneration, BlipProcessor, CLIPProcessor, CLIPModel
3
+ import torch
4
 
# BLIP: image-captioning model + its paired processor.
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")

# CLIP: image-text similarity model + its paired processor, used for
# zero-shot classification against the label texts below.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

# Candidate labels for zero-shot classification (surveillance-style events).
# NOTE(review): strings are kept verbatim — including the trailing space in
# 'burglary ' — because they are fed to the model as-is.
labels = [
    'Arrest', 'Arson', 'Explosion', 'public fight', 'Normal',
    'Road Accident', 'Robbery', 'Shooting', 'Stealing', 'Vandalism',
    'Suspicious activity', 'Tailgating', 'Unauthorized entry',
    'Protest/Demonstration', 'Drone suspicious activity',
    'Fire/Smoke detection', 'Medical emergency',
    'Suspicious package/object', 'Threatening', 'Attack',
    'Shoplifting', 'burglary ', 'distress', 'assault',
]
 
39
def generate_caption(image):
    """Generate a natural-language caption for *image* with BLIP.

    The image is preprocessed into pixel values, decoded with beam search
    (4 beams, at most 50 tokens), and the special tokens are stripped from
    the result.

    NOTE(review): the caller passes ``uploaded_file.getvalue()`` (raw bytes);
    the BLIP processor normally expects a PIL image / array — confirm this
    input type actually works end to end.
    """
    # Turn the raw image into model-ready pixel values.
    features = blip_processor(images=image, return_tensors="pt")

    # Beam-search decode a caption from the visual features.
    generated_ids = blip_model.generate(
        features.pixel_values, max_length=50, num_beams=4, early_stopping=True
    )

    # Convert the generated token ids back into readable text.
    return blip_processor.decode(generated_ids[0], skip_special_tokens=True)
def classify_image(image):
    """Zero-shot classify *image* against the module-level ``labels``.

    Returns the label string with the highest CLIP image-text similarity.

    Fix: CLIP zero-shot classification requires BOTH the image and the
    candidate label texts. The previous code called the processor with
    ``images=`` only, so ``clip_model(**inputs)`` had no ``input_ids`` and
    the forward pass could not produce per-label logits.
    """
    # Encode the image together with every candidate label text.
    inputs = clip_processor(
        text=labels, images=image, return_tensors="pt", padding=True
    )

    # Score the image against each label; no gradients needed at inference.
    with torch.no_grad():
        outputs = clip_model(**inputs)
        # logits_per_image has shape (1, len(labels)): one similarity score
        # per candidate label. Softmax turns the scores into probabilities.
        probs = outputs.logits_per_image.softmax(dim=-1)

    # Pick the most probable label.
    top_prob, top_idx = torch.max(probs, dim=-1)
    return labels[top_idx.item()]
66
  def main():
67
+ st.title("Image Caption and Classification")
68
 
69
  # Upload image
70
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
 
73
  # Display the uploaded image
74
  image = st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
75
 
76
+ # Generate caption and classify the image
77
+ if st.button("Generate Caption and Classify"):
78
+ with st.spinner("Processing image..."):
79
  caption = generate_caption(uploaded_file.getvalue())
80
+ top_label = classify_image(uploaded_file.getvalue())
81
  st.success(f"Caption: {caption}")
82
+ st.success(f"Top Predicted Label: {top_label}")
83
 
84
  if __name__ == "__main__":
85
  main()