# LayoutLMv3-1 / app.py
# Author: lamiaaEl — commit 968b8b2 ("Create app.py"), originally published as a Hugging Face Space.
import streamlit as st
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor
from PIL import Image
import torch
import easyocr
from PIL import Image
import re
# Load the fine-tuned model and its processor from the Hugging Face Hub.
model_name = "capitaletech/LayoutLMv3-v1"  # Replace with your model repository name
model = LayoutLMv3ForTokenClassification.from_pretrained(model_name)
# apply_ocr=False: this app runs its own EasyOCR pass and hands the processor
# explicit words and boxes. LayoutLMv3Processor raises a ValueError when boxes
# are passed while its image processor is configured to run OCR itself.
processor = LayoutLMv3Processor.from_pretrained(model_name, apply_ocr=False)

st.title("LayoutLMv3 Text Extraction")
st.write("Upload an image to get text predictions using the fine-tuned LayoutLMv3 model.")
# Accept the common raster formats, not only PNG; PIL and EasyOCR read all of them.
uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
def _keep_letters(text):
    """Return *text* with every non-alphabetic character removed.

    ``str.isalpha`` is Unicode-aware, so letters from any script (Latin,
    Cyrillic, ...) survive. An ASCII-only ``[^a-zA-Z]`` filter would erase all
    text from the Cyrillic languages the OCR reader below is configured for.
    """
    return "".join(ch for ch in text if ch.isalpha())


def _normalize_box(box, width, height):
    """Scale a pixel-space ``[x0, y0, x1, y1]`` box to LayoutLMv3's 0-1000 grid.

    Coordinates are truncated to ``int`` (the tokenizer turns boxes into an
    integer tensor) and clamped to ``[0, 1000]``.
    """
    scales = (width, height, width, height)
    return [
        min(1000, max(0, int(1000 * coord / scale)))
        for coord, scale in zip(box, scales)
    ]


def unnormalize_box(bbox, width, height):
    """Map a 0-1000 LayoutLMv3 box back to pixel coordinates of a ``width`` x ``height`` image."""
    return [
        width * (bbox[0] / 1000),
        height * (bbox[1] / 1000),
        width * (bbox[2] / 1000),
        height * (bbox[3] / 1000),
    ]


def _ocr_words_and_boxes(reader, image_bytes):
    """Run EasyOCR and return parallel lists of letter-only words and pixel boxes.

    A detection whose text has no alphabetic characters is dropped together
    with its box, so ``words[i]`` always belongs to ``boxes[i]``. Boxes are the
    axis-aligned hull of EasyOCR's corner points, which also covers rotated
    detections.
    """
    words, boxes = [], []
    for corners, text, _confidence in reader.readtext(image_bytes, detail=1):
        word = _keep_letters(text)
        if not word:
            continue
        xs = [point[0] for point in corners]
        ys = [point[1] for point in corners]
        words.append(word)
        boxes.append([min(xs), min(ys), max(xs), max(ys)])
    return words, boxes


def _word_level_predictions(encoding, token_predictions, id2label):
    """Return ``{word_index: label}`` using each word's first sub-token prediction.

    Special tokens (``word_id is None``) and continuation sub-tokens are
    skipped. Words cut off by truncation do not appear in the result.
    Requires a fast tokenizer (``BatchEncoding.word_ids``).
    """
    labels = {}
    for token_index, word_index in enumerate(encoding.word_ids(0)):
        if word_index is not None and word_index not in labels:
            labels[word_index] = id2label[token_predictions[token_index]]
    return labels


def _pair_languages_with_levels(tagged_words):
    """Map every word predicted as ``language`` to the label predicted for the next word.

    *tagged_words* is an ordered list of ``(word, label)`` pairs. A ``language``
    word in the last position has no successor and is skipped.
    NOTE(review): this keeps the original design of reporting the *label* of
    the following word as the level; confirm the model's label set encodes levels.
    """
    pairs = {}
    for (word, label), (_next_word, next_label) in zip(tagged_words, tagged_words[1:]):
        if label == "language":
            pairs[word] = next_label
    return pairs


if uploaded_file is not None:
    # Convert to RGB: PNG uploads may carry an alpha channel or a palette.
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption='Uploaded Image.', use_column_width=True)
    st.write("")
    st.write("Classifying...")

    # 1. OCR must run BEFORE the model: its words and boxes are the model's text/layout input.
    #    Set up the EasyOCR reader for multiple languages.
    languages = ["ru", "rs_cyrillic", "be", "bg", "uk", "mn", "en"]
    reader = easyocr.Reader(languages)
    words, pixel_boxes = _ocr_words_and_boxes(reader, uploaded_file.getvalue())

    if not words:
        st.write("No text was detected in the image.")
    else:
        width, height = image.size
        boxes = [_normalize_box(box, width, height) for box in pixel_boxes]

        # 2. Encode and classify. No word_labels: there is no ground truth at inference.
        #    truncation=True keeps long documents within the model's sequence limit.
        encoding = processor(image, words, boxes=boxes, truncation=True, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**encoding)
        predictions = outputs.logits.argmax(-1).squeeze(0).cpu().tolist()

        # 3. Collapse sub-token predictions to one label per OCR word, then pair languages with levels.
        word_labels = _word_level_predictions(encoding, predictions, model.config.id2label)
        tagged_words = [(words[i], word_labels[i]) for i in sorted(word_labels)]
        languages_with_levels = _pair_languages_with_levels(tagged_words)

        st.write("Predicted labels:")
        # Print languages with their levels
        for language, level in languages_with_levels.items():
            st.write(f"{language}: {level}")