kusumakar's picture
Update app.py
f35602f
raw
history blame
2.21 kB
import numpy as np
from PIL import Image
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import AutoTokenizer, VisionEncoderDecoderModel, ViTFeatureExtractor
# Load the image-captioning model, its feature extractor and its tokenizer.
# NOTE: the captioner is bound to ``caption_model`` rather than ``model``
# because the GPT-2 loading code later in this file assigns ``model`` at module
# level; with a shared name, generate_captions() looked up ``model`` at call
# time and ended up calling the GPT-2 language model with image pixel values.
CAPTION_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"
caption_model = VisionEncoderDecoderModel.from_pretrained(CAPTION_CHECKPOINT)
extractor = ViTFeatureExtractor.from_pretrained(CAPTION_CHECKPOINT)
tokeniser = AutoTokenizer.from_pretrained(CAPTION_CHECKPOINT)


def generate_captions(image):
    """Return a text caption for ``image`` using the ViT-GPT2 captioner.

    Args:
        image: a PIL image (the app passes an upload converted to RGB).

    Returns:
        The decoded caption of the first generated sequence, with special
        tokens removed and surrounding whitespace stripped.
    """
    pixel_values = extractor(image, return_tensors="pt").pixel_values.to("cpu")
    output_ids = caption_model.generate(pixel_values)
    # skip_special_tokens drops <|endoftext|> (and any other special token)
    # instead of string-replacing one hard-coded marker.
    return tokeniser.decode(output_ids[0], skip_special_tokens=True).strip()
# Load the pre-trained GPT-2 tokenizer and language model; generate_paragraph
# reads them through the module-level names ``tokenizer`` and ``model``.
model_name = "gpt2"
tokenizer, model = (
    GPT2Tokenizer.from_pretrained(model_name),
    GPT2LMHeadModel.from_pretrained(model_name),
)
# Text generation: expand a short prompt into a paragraph with GPT-2.
def generate_paragraph(prompt, max_length=200):
    """Continue ``prompt`` with the module-level GPT-2 ``model`` (greedy decoding).

    Args:
        prompt: non-empty seed text (in this app, the image caption).
        max_length: maximum total length, in tokens, of the generated
            sequence including the prompt tokens (forwarded to
            ``model.generate``). Defaults to 200, the previous fixed value.

    Returns:
        The decoded first sequence (prompt followed by its continuation)
        with special tokens removed.

    Raises:
        ValueError: if ``prompt`` is empty; GPT-2 cannot start from an empty
            token sequence.
    """
    if not prompt:
        raise ValueError("prompt must be a non-empty string")

    # Calling the tokenizer (instead of .encode) also yields the attention mask.
    encoded = tokenizer(prompt, return_tensors="pt")
    output = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_length=max_length,
        num_return_sequences=1,
        # GPT-2 defines no pad token; reuse EOS explicitly so generate() does
        # not have to guess one (and warn about it).
        pad_token_id=tokenizer.eos_token_id,
    )
    # early_stopping was dropped: it only affects beam search, and with the
    # default single beam it had no effect beyond a warning.
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Streamlit app
def main():
    """Render the app: upload an image, caption it, then expand the caption.

    The caption from ``generate_captions`` is shown as the "context" and used
    as the GPT-2 prompt for ``generate_paragraph``; both results are written
    to the page.
    """
    st.title("Paragraph Generation From Context of an Image")
    st.subheader("Upload the Image to generate a paragraph.")

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    # Nothing uploaded yet (or the upload was cleared): only the header shows.
    if uploaded_file is None:
        return

    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # PIL's UnidentifiedImageError (unrecognised data) and truncated-file
        # errors both derive from OSError; report them instead of a traceback.
        st.error("Could not read the uploaded file as an image.")
        return

    # Show the image before the slow model calls so the user gets feedback.
    st.image(uploaded_file)

    with st.spinner("Generating caption..."):
        prompt = generate_captions(image)
    st.write("The Context is:", prompt)

    # An empty caption cannot seed GPT-2, so stop here with a message.
    if not prompt.strip():
        st.warning("No caption could be generated for this image.")
        return

    with st.spinner("Generating paragraph..."):
        generated_paragraph = generate_paragraph(prompt)
    st.write(generated_paragraph)


if __name__ == "__main__":
    main()