# Author: kusumakar — last commit ff7c5de ("Update app.py").
# Load all necessary libraries. Check the system requirements and dependencies before running.
import torch
import numpy as np
from PIL import Image
import streamlit as st
from transformers import AutoTokenizer, VisionEncoderDecoderModel, ViTFeatureExtractor, GPT2Tokenizer, GPT2LMHeadModel
# Load the ViT-GPT2 captioning model, its feature extractor and its tokenizer.
# Streamlit re-executes this whole script on every user interaction, so the
# loading is wrapped in ``st.cache_resource``: the checkpoint is loaded once per
# server process and the same objects are reused on every rerun.
_CAPTION_MODEL_ID = "nlpconnect/vit-gpt2-image-captioning"


@st.cache_resource
def _load_caption_components():
    """Return ``(model, feature_extractor, tokenizer)`` for the captioning checkpoint."""
    caption_model = VisionEncoderDecoderModel.from_pretrained(_CAPTION_MODEL_ID)
    feature_extractor = ViTFeatureExtractor.from_pretrained(_CAPTION_MODEL_ID)
    caption_tokenizer = AutoTokenizer.from_pretrained(_CAPTION_MODEL_ID)
    return caption_model, feature_extractor, caption_tokenizer


# Module-level names kept as before; generate_captions() reads them.
model, extractor, tokeniser = _load_caption_components()
# Caption an image with the ViT-GPT2 model.
def generate_captions(image):
    """Return a short caption for *image*.

    Args:
        image: The image to caption (``main`` passes an RGB ``PIL.Image``).
            It is preprocessed by the module-level ``extractor``.

    Returns:
        The decoded caption with special tokens removed and surrounding
        whitespace stripped, ready to be used as a text prompt.
    """
    pixel_values = extractor(image, return_tensors="pt").pixel_values
    # Put the inputs on whatever device the model lives on (CPU unless moved).
    output_ids = model.generate(pixel_values.to(model.device))
    # skip_special_tokens drops tokenizer special tokens such as the
    # "<|endoftext|>" marker, which was previously removed by hand with replace().
    caption = tokeniser.decode(output_ids[0], skip_special_tokens=True)
    # The decoded text can carry leading/trailing spaces. Strip them because the
    # caption becomes a GPT-2 prompt, where a trailing space is encoded as its own
    # token and degrades the continuation.
    return caption.strip()
# Load the pre-trained GPT-2 tokenizer and language model used to expand the caption.
# Cached with ``st.cache_resource`` so Streamlit reruns reuse the loaded weights
# instead of reloading them on every interaction.
model_name = "gpt2"


@st.cache_resource
def _load_gpt2(name):
    """Return ``(tokenizer, model)`` for the GPT-2 checkpoint *name*."""
    return GPT2Tokenizer.from_pretrained(name), GPT2LMHeadModel.from_pretrained(name)


# Module-level names kept as before; generate_paragraph() reads them.
tokenizer_2, model_2 = _load_gpt2(model_name)
# Expand a short prompt into a longer paragraph with GPT-2.
def generate_paragraph(prompt):
    """Continue *prompt* with GPT-2 and return the resulting text.

    Args:
        prompt: Seed text (in this app, the image caption).

    Returns:
        The stripped prompt followed by GPT-2's greedy continuation (at most
        200 tokens in total, with no repeated 2-grams). Only the first
        character is upper-cased; the rest of the text is left untouched.
        Returns ``""`` when *prompt* is empty or whitespace-only.
    """
    # Trailing whitespace would be encoded as a separate token and hurt the continuation.
    prompt = prompt.strip()
    if not prompt:
        # An empty prompt encodes to zero tokens, which GPT-2 cannot continue from.
        return ""
    # Calling the tokenizer (rather than .encode) also yields an attention mask,
    # so generate() does not have to infer one.
    inputs = tokenizer_2(prompt, return_tensors="pt")
    # Greedy decoding, as before. early_stopping only affects beam search, so it
    # was a no-op here and is dropped. GPT-2 defines no pad token; pass EOS
    # explicitly instead of letting generate() fall back to it with a warning.
    output = model_2.generate(
        **inputs,
        max_length=200,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        pad_token_id=tokenizer_2.eos_token_id,
    )
    # Decode the generated output into text.
    paragraph = tokenizer_2.decode(output[0], skip_special_tokens=True)
    # str.capitalize() would lower-case every other character (proper nouns,
    # "I", acronyms), so upper-case only the first character.
    return paragraph[:1].upper() + paragraph[1:]
# Define the Streamlit app.
def main():
    """Render the app: upload a picture, caption it, then expand the caption.

    The caption from ``generate_captions`` is used as the prompt for
    ``generate_paragraph``, and the resulting paragraph is written below the
    uploaded image.
    """
    # Set the Streamlit app title and description.
    st.title("Have a Picture! Don't Know how to Describe?. Here's Some Help")
    st.subheader("Upload the Picture to get Catchy Description.")
    uploaded_file = st.file_uploader("Drag and Drop or Upload the picture", type=["jpg", "jpeg", "png"])
    # Nothing to do until the user has uploaded a file.
    if uploaded_file is None:
        return
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # The type= filter is extension-based, so content Pillow cannot decode can
        # still arrive (Pillow's UnidentifiedImageError and truncated-file errors
        # are OSError subclasses). Report it instead of crashing the page.
        st.error("Could not read the uploaded file as an image.")
        return
    st.image(uploaded_file)
    # Both model calls are slow, so show a progress indicator while they run.
    with st.spinner("Generating description..."):
        prompt = generate_captions(image)
        generated_paragraph = generate_paragraph(prompt)
    st.write(generated_paragraph)


if __name__ == "__main__":
    main()