# Spaces status: Runtime error
# Load all necessary libraries. Don't forget to check the system requirements / dependencies.
import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import (
    AutoTokenizer,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    VisionEncoderDecoderModel,
    ViTFeatureExtractor,
)
# Load the captioning model, feature extractor and tokenizer.
@st.cache_resource
def _load_caption_components():
    """Load the ViT-GPT2 captioning model, image feature extractor and tokenizer.

    Decorated with ``st.cache_resource`` because Streamlit re-executes the
    whole script on every user interaction; without the cache the weights
    would be reloaded from disk on each rerun. The cached objects are shared
    across sessions and are only used for inference here.

    Returns:
        ``(model, extractor, tokenizer)`` for "nlpconnect/vit-gpt2-image-captioning".
    """
    checkpoint = "nlpconnect/vit-gpt2-image-captioning"
    return (
        VisionEncoderDecoderModel.from_pretrained(checkpoint),
        ViTFeatureExtractor.from_pretrained(checkpoint),
        AutoTokenizer.from_pretrained(checkpoint),
    )


model, extractor, tokeniser = _load_caption_components()
# Image -> short caption with the ViT-GPT2 captioning model.
def generate_captions(image, **generate_kwargs):
    """Return a text caption for ``image``.

    Args:
        image: the picture to caption (``main`` passes an RGB ``PIL.Image``).
        **generate_kwargs: optional extra arguments forwarded unchanged to
            ``model.generate`` (e.g. ``max_length``, ``num_beams``). With none
            given, generation uses the model's defaults, as before.

    Returns:
        The decoded caption with special tokens removed and surrounding
        whitespace stripped.
    """
    # Send the pixel tensor to whatever device the model is on instead of
    # hard-coding "cpu", so moving the model to a GPU does not break this.
    pixel_values = extractor(image, return_tensors="pt").pixel_values.to(model.device)
    output_ids = model.generate(pixel_values, **generate_kwargs)
    # skip_special_tokens drops <|endoftext|> (and any other special token)
    # rather than string-replacing one hard-coded marker.
    caption = tokeniser.decode(output_ids[0], skip_special_tokens=True)
    # The caption becomes the GPT-2 prompt; stray leading/trailing spaces
    # would be tokenized as part of it.
    return caption.strip()
# Load the pre-trained GPT-2 model and tokenizer used to expand captions.
model_name = "gpt2"


@st.cache_resource
def _load_text_generator(name):
    """Load the GPT-2 tokenizer and language-model head for checkpoint ``name``.

    Cached with ``st.cache_resource`` (keyed on ``name``) so Streamlit's
    rerun-on-interaction does not reload the weights every time.

    Returns:
        ``(tokenizer, model)``.
    """
    return GPT2Tokenizer.from_pretrained(name), GPT2LMHeadModel.from_pretrained(name)


tokenizer_2, model_2 = _load_text_generator(model_name)
# Caption -> longer paragraph with GPT-2.
def generate_paragraph(prompt):
    """Continue ``prompt`` with GPT-2 and return the resulting paragraph.

    Args:
        prompt: seed text (in ``main`` this is the image caption).

    Returns:
        The prompt followed by GPT-2's greedy continuation (at most 200
        tokens in total, no repeated bigrams), with only its first character
        upper-cased. An empty/``None`` prompt yields ``""``.
    """
    # generate() cannot start from zero input tokens; return early instead
    # of crashing.
    if not prompt:
        return ""
    # Calling the tokenizer directly also yields the attention mask, so
    # generate() does not have to infer one.
    encoded = tokenizer_2(prompt, return_tensors="pt")
    output = model_2.generate(
        encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=200,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        # GPT-2 has no pad token; use EOS explicitly (generate() would fall
        # back to it anyway, with a warning). early_stopping was dropped: it
        # only affects beam search, and this call decodes greedily.
        pad_token_id=tokenizer_2.eos_token_id,
    )
    paragraph = tokenizer_2.decode(output[0], skip_special_tokens=True)
    # str.capitalize() would also lower-case every later character (proper
    # nouns, "I", acronyms); upper-case the first character only.
    return paragraph[:1].upper() + paragraph[1:]
# Define the Streamlit app.
def main():
    """Render the app: upload a picture, caption it, then expand the caption."""
    st.title("Have a Picture! Don't Know how to Describe?. Here's Some Help")
    st.subheader("Upload the Picture to get Catchy Description.")
    uploaded_file = st.file_uploader(
        "Drag and Drop or Upload the picture", type=["jpg", "jpeg", "png"]
    )
    # Nothing uploaded yet: just show the uploader.
    if uploaded_file is None:
        return
    try:
        # Normalise the image mode (e.g. RGBA PNGs, greyscale) to RGB.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # The uploader filters by extension only. PIL.UnidentifiedImageError
        # (not an image) and truncated-file errors are both OSError subclasses.
        st.error("Could not read that file as an image. Please upload a JPG or PNG.")
        return
    # Show the picture before the slow model calls so the user sees it
    # while the description is being generated.
    st.image(uploaded_file)
    with st.spinner("Describing your picture..."):
        # The generated caption is used as the GPT-2 prompt.
        prompt = generate_captions(image)
        generated_paragraph = generate_paragraph(prompt)
    st.write(generated_paragraph)
if __name__ == "__main__":
    # Entry point when the file is executed as a script (this includes
    # `streamlit run`, which executes it with __name__ set to "__main__").
    main()