# Hugging Face Space (status when this page was captured: Runtime error)
import numpy as np | |
from PIL import Image | |
import streamlit as st | |
from transformers import GPT2Tokenizer, GPT2LMHeadModel | |
from transformers import AutoTokenizer, VisionEncoderDecoderModel, ViTFeatureExtractor | |
# Captioning stack: the vision encoder-decoder model, its image feature
# extractor and its tokenizer are all loaded from the same pretrained
# checkpoint identifier.
_CAPTION_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"
model = VisionEncoderDecoderModel.from_pretrained(_CAPTION_CHECKPOINT)
extractor = ViTFeatureExtractor.from_pretrained(_CAPTION_CHECKPOINT)
tokeniser = AutoTokenizer.from_pretrained(_CAPTION_CHECKPOINT)
def generate_captions(image):
    """Generate a text caption for an image.

    Relies on the module-level ``extractor``, ``model`` and ``tokeniser``
    objects. ``model`` must refer to the image-captioning model at call
    time; rebinding that name elsewhere in the module breaks captioning.

    Args:
        image: The image to describe, in a form accepted by ``extractor``
            (``main`` passes an RGB ``PIL.Image.Image``).

    Returns:
        The decoded caption as a string, with special tokens such as
        ``<|endoftext|>`` removed and surrounding whitespace stripped.
    """
    pixel_values = extractor(image, return_tensors="pt").pixel_values.to("cpu")
    output_ids = model.generate(pixel_values)
    # Let the tokenizer drop special tokens instead of deleting the
    # "<|endoftext|>" literal by hand, and strip surrounding whitespace so
    # the caption is a clean prompt for the paragraph generator.
    return tokeniser.decode(output_ids[0], skip_special_tokens=True).strip()
# Load the pre-trained GPT-2 model and tokenizer used for paragraph generation.
# The language model gets its own name: binding it to `model` would replace
# the captioning model that generate_captions() looks up at call time.
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
paragraph_model = GPT2LMHeadModel.from_pretrained(model_name)


def generate_paragraph(prompt):
    """Continue ``prompt`` into a longer passage of text with GPT-2.

    Args:
        prompt: The text to continue (``main`` passes the image caption).

    Returns:
        The decoded generated sequence with special tokens removed. The
        sequence is capped at 200 tokens (``max_length=200``). An empty
        prompt yields an empty string, because there are no input tokens
        for the model to continue from.
    """
    if not prompt:
        return ""

    # Calling the tokenizer (rather than `encode`) also returns the attention
    # mask, which `generate` would otherwise have to guess.
    encoded = tokenizer(prompt, return_tensors="pt")

    # `early_stopping` only affects beam-based generation, so it is omitted
    # here. GPT-2 has no dedicated padding token, so the end-of-sequence token
    # is passed explicitly as the pad token.
    output = paragraph_model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_length=200,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode the generated token ids back into text.
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Streamlit app
def main():
    """Render the page: upload an image, caption it, expand the caption.

    Once a file has been uploaded, the page shows the generated caption
    ("context"), then the image, then the paragraph generated from that
    caption. If the upload cannot be read as an image, an error message is
    shown and nothing else is rendered.
    """
    # Set Streamlit app title and description.
    st.title("Paragraph Generation From Context of an Image")
    st.subheader("Upload the Image to generate a paragraph.")

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        # Nothing uploaded yet: only the title and the uploader are shown.
        return

    try:
        # `convert` forces the image data to be read, so unreadable or
        # truncated files fail here rather than inside the model code.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        st.error("The uploaded file could not be read as an image.")
        return

    # The caption doubles as the prompt for paragraph generation.
    with st.spinner("Generating caption..."):
        prompt = generate_captions(image)
    st.write("The Context is:", prompt)

    st.image(uploaded_file)

    with st.spinner("Generating paragraph..."):
        generated_paragraph = generate_paragraph(prompt)
    st.write(generated_paragraph)


if __name__ == "__main__":
    main()