Spaces:
Runtime error
Runtime error
ChandraP12330
committed on
Commit
•
c8b7141
1
Parent(s):
cb5683a
Update app.py
Browse files
app.py
CHANGED
@@ -1,24 +1,70 @@
|
|
1 |
import streamlit as st
|
2 |
-
from transformers import BlipForConditionalGeneration, BlipProcessor
|
|
|
3 |
|
4 |
# Load the BLIP model and processor
|
5 |
-
|
6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
def generate_caption(image):
|
9 |
# Preprocess the image
|
10 |
-
pixel_values =
|
11 |
|
12 |
# Generate caption using the BLIP model
|
13 |
-
output_ids =
|
14 |
|
15 |
# Decode the caption
|
16 |
-
caption =
|
17 |
|
18 |
return caption
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
def main():
|
21 |
-
st.title("Image Caption
|
22 |
|
23 |
# Upload image
|
24 |
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
|
@@ -27,11 +73,13 @@ def main():
|
|
27 |
# Display the uploaded image
|
28 |
image = st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
|
29 |
|
30 |
-
# Generate caption
|
31 |
-
if st.button("Generate Caption"):
|
32 |
-
with st.spinner("
|
33 |
caption = generate_caption(uploaded_file.getvalue())
|
|
|
34 |
st.success(f"Caption: {caption}")
|
|
|
35 |
|
36 |
if __name__ == "__main__":
|
37 |
main()
|
|
|
1 |
import streamlit as st
|
2 |
+
from transformers import BlipForConditionalGeneration, BlipProcessor, CLIPProcessor, CLIPModel
|
3 |
+
import torch
|
4 |
|
5 |
# --- Model setup (runs once at module import; Streamlit reruns reuse the process) ---

# BLIP: image-captioning model plus its paired processor.
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")

# CLIP: vision-language model used for zero-shot classification, plus its processor.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

# Candidate labels for classification.
labels = [
    'Arrest',
    'Arson',
    'Explosion',
    'public fight',
    'Normal',
    'Road Accident',
    'Robbery',
    'Shooting',
    'Stealing',
    'Vandalism',
    'Suspicious activity',
    'Tailgating',
    'Unauthorized entry',
    'Protest/Demonstration',
    'Drone suspicious activity',
    'Fire/Smoke detection',
    'Medical emergency',
    'Suspicious package/object',
    'Threatening',
    'Attack',
    'Shoplifting',
    'burglary ',
    'distress',
    'assault',
]
|
38 |
|
39 |
def generate_caption(image):
    """Return a natural-language caption for *image* using the BLIP model.

    NOTE(review): the caller in main() passes ``uploaded_file.getvalue()``
    (raw bytes), but BlipProcessor expects a decoded image (PIL image,
    numpy array, or tensor) — confirm the input is decoded upstream.
    """
    # Turn the image into model-ready pixel tensors.
    encoded = blip_processor(images=image, return_tensors="pt")

    # Beam-search decode a caption (up to 50 tokens, 4 beams).
    generated_ids = blip_model.generate(
        encoded.pixel_values, max_length=50, num_beams=4, early_stopping=True
    )

    # Convert the token ids of the first (only) sequence back into text.
    return blip_processor.decode(generated_ids[0], skip_special_tokens=True)
|
50 |
|
51 |
+
def classify_image(image):
    """Zero-shot classify *image* against the module-level ``labels`` list.

    Returns the single best-matching label string.

    NOTE(review): the caller in main() passes ``uploaded_file.getvalue()``
    (raw bytes), but CLIPProcessor expects a decoded image (PIL image,
    numpy array, or tensor) — confirm the input is decoded upstream.
    """
    # Encode BOTH the candidate label texts and the image. The previous
    # code passed only the image, but CLIPModel.forward requires
    # input_ids as well, so the call raised at runtime — and without the
    # label texts there was nothing to score the image against.
    inputs = clip_processor(text=labels, images=image, return_tensors="pt", padding=True)

    # Score image-to-text similarity; inference only, so skip gradients.
    with torch.no_grad():
        outputs = clip_model(**inputs)
        # logits_per_image has shape (1, len(labels)).
        probs = outputs.logits_per_image.softmax(dim=-1)

    # Get the top predicted label.
    top_prob, top_idx = torch.max(probs, dim=-1)
    top_label = labels[top_idx.item()]

    return top_label
|
65 |
+
|
66 |
def main():
|
67 |
+
st.title("Image Caption and Classification")
|
68 |
|
69 |
# Upload image
|
70 |
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
|
|
|
73 |
# Display the uploaded image
|
74 |
image = st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
|
75 |
|
76 |
+
# Generate caption and classify the image
|
77 |
+
if st.button("Generate Caption and Classify"):
|
78 |
+
with st.spinner("Processing image..."):
|
79 |
caption = generate_caption(uploaded_file.getvalue())
|
80 |
+
top_label = classify_image(uploaded_file.getvalue())
|
81 |
st.success(f"Caption: {caption}")
|
82 |
+
st.success(f"Top Predicted Label: {top_label}")
|
83 |
|
84 |
if __name__ == "__main__":
|
85 |
main()
|