|
import streamlit as st |
|
from tensorflow.keras.models import load_model |
|
from tensorflow.keras.preprocessing.text import Tokenizer |
|
from tensorflow.keras.preprocessing.sequence import pad_sequences |
|
from tensorflow.keras.applications.vgg16 import preprocess_input |
|
from tensorflow.keras.applications.vgg16 import VGG16 |
|
from tensorflow.keras.models import Model |
|
from tensorflow.keras.preprocessing.image import load_img, img_to_array |
|
import numpy as np |
|
from PIL import Image |
|
from pickle import load |
|
|
|
|
|
# --- Resources loaded once when this module is executed ----------------------
# NOTE(review): Streamlit re-executes the whole script on every widget
# interaction, so these loads repeat on each rerun. Consider wrapping them in
# a cached loader (e.g. st.cache_resource) if the installed Streamlit version
# provides it.

# Tokenizer used by generate_caption (texts_to_sequences) and word_for_id
# (word_index); presumably a pickled Keras Tokenizer fitted on training captions.
# The `with` block closes the file handle as soon as unpickling finishes.
# NOTE(review): pickle can execute arbitrary code while loading -- only ship a
# tokenizer file you produced yourself.
with open('tokenizer1.pkl', 'rb') as tokenizer_file:
    tokenizer = load(tokenizer_file)

# Length every partial caption is padded to before it is fed to the decoder
# (main passes it to generate_caption as max_length). Presumably this must match
# the sequence length used when model_18.h5 was trained -- confirm.
max_len = 34

# Caption decoder: generate_caption calls model.predict([photo, sequence]) and
# takes the argmax of the result as the next word id.
model = load_model('model_18.h5')

# VGG16 cut off at its penultimate layer so it emits a feature vector instead of
# class scores. The layer is selected by name: `layers.pop()` does not remove a
# layer from a tf.keras model (the `layers` property returns a new list on each
# access), so positional indexing after a pop is easy to get wrong.
vgg_model = VGG16()
vgg_model = Model(inputs=vgg_model.inputs, outputs=vgg_model.get_layer('fc2').output)
|
|
|
|
|
def word_for_id(integer, tokenizer):
    """Map a predicted integer id back to its word.

    Args:
        integer: The id to look up (any value comparable with ``==`` to the
            ids stored in ``tokenizer.word_index``).
        tokenizer: Object exposing a ``word_index`` mapping of word -> id.

    Returns:
        The first word, in ``word_index`` iteration order, whose id equals
        *integer*, or ``None`` if no entry matches.
    """
    candidates = (
        word
        for word, word_id in tokenizer.word_index.items()
        if word_id == integer
    )
    return next(candidates, None)
|
|
|
|
|
def generate_caption(model, tokenizer, photo, max_length):
    """Greedily decode a caption for one image.

    Starting from the ``'startseq'`` token, the current caption is encoded
    with *tokenizer*, padded to *max_length*, and passed to *model* together
    with *photo*. The highest-scoring id (argmax) is turned back into a word
    and appended. Decoding stops after *max_length* steps, when the predicted
    id has no word in the vocabulary, or right after ``'endseq'`` is appended.

    Args:
        model: Decoder whose ``predict([photo, sequence])`` returns next-word
            scores.
        tokenizer: Tokenizer providing ``texts_to_sequences`` and the
            ``word_index`` used by ``word_for_id``.
        photo: Image features passed unchanged to ``model.predict``; in this
            file they come from ``extract_features``.
        max_length: Number of decoding steps and padding length of the input
            sequence.

    Returns:
        The caption as a space-separated string, including the leading
        ``'startseq'`` and, if it was produced, the trailing ``'endseq'``.
    """
    caption_words = ['startseq']

    for _ in range(max_length):
        encoded = tokenizer.texts_to_sequences([' '.join(caption_words)])[0]
        padded = pad_sequences([encoded], maxlen=max_length)
        scores = model.predict([photo, padded], verbose=0)

        next_word = word_for_id(np.argmax(scores), tokenizer)
        if next_word is None:
            # Predicted id is not in the vocabulary: nothing more to add.
            break

        caption_words.append(next_word)
        if next_word == 'endseq':
            break

    return ' '.join(caption_words)
|
|
|
|
|
def extract_features(filename):
    """Run one image through the truncated VGG16 model.

    The image is loaded and resized to 224x224, converted to an array, given
    a leading batch axis of size 1, and passed through VGG16's
    ``preprocess_input``.

    Args:
        filename: Path or file-like object accepted by ``load_img``; ``main``
            passes the Streamlit upload directly.

    Returns:
        The output of ``vgg_model.predict`` for the single-image batch.
    """
    pixels = img_to_array(load_img(filename, target_size=(224, 224)))
    batch = preprocess_input(np.expand_dims(pixels, axis=0))
    return vgg_model.predict(batch, verbose=0)
|
|
|
|
|
def remove_start_end_tokens(caption):
    """Drop the ``startseq``/``endseq`` sentinel tokens from a caption.

    The caption is split on whitespace. Tokens that equal a sentinel,
    compared case-insensitively, are removed. The remaining tokens are joined
    with single spaces.

    Args:
        caption: Decoded caption text.

    Returns:
        The caption without sentinel tokens (``''`` if nothing remains).
    """
    sentinels = {'startseq', 'endseq'}
    kept = []
    for token in caption.split():
        if token.lower() in sentinels:
            continue
        kept.append(token)
    return ' '.join(kept)
|
|
|
def main():
    """Render the Streamlit image-captioning page.

    Shows an uploader for JPG/JPEG/PNG files and previews the upload resized
    to 400x400. When "Generate Caption" is clicked, the page extracts VGG16
    features, greedily decodes a caption, strips the ``startseq``/``endseq``
    sentinels, and shows the result between two horizontal rules.
    """
    import html  # local import: only needed to escape the caption below

    st.set_page_config(page_title="Image Captioning", page_icon="📷")
    st.title("Image Captioning")
    st.markdown("Upload an image and get a caption for it.")

    uploaded_file = st.file_uploader("Choose an image file", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    image = Image.open(uploaded_file)
    resized_image = image.resize((400, 400))
    st.image(resized_image, caption='Uploaded Image')

    # Streamlit reruns this function on every interaction. Running the VGG16
    # forward pass only after the button is clicked avoids a wasted pass on
    # every rerun where no caption is requested.
    if st.button("Generate Caption"):
        with st.spinner("Generating caption..."):
            photo = extract_features(uploaded_file)
            description = generate_caption(model, tokenizer, photo, max_len)

        caption = remove_start_end_tokens(description)

        st.subheader(" Generated Caption")
        st.markdown("---")
        # The caption goes into raw HTML (unsafe_allow_html=True), so escape
        # it to keep any markup characters in the vocabulary from being
        # interpreted as HTML.
        st.markdown(
            f"<p style='font-size: 18px; text-align: center;'>{html.escape(caption)}</p>",
            unsafe_allow_html=True,
        )
        st.markdown("---")


if __name__ == '__main__':
    main()
|
|