File size: 3,581 Bytes
968b8b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import streamlit as st
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor
from PIL import Image
import torch
import easyocr
from PIL import Image
import re

# Hugging Face repository holding the fine-tuned LayoutLMv3 checkpoint and its processor.
model_name = "capitaletech/LayoutLMv3-v1"  # Replace with your model repository name


@st.cache_resource
def _load_model_and_processor(repo_name):
    """Load the token-classification model and its processor from the Hub.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps one loaded copy per ``repo_name`` alive across
    those reruns instead of re-reading the weights each time.

    Args:
        repo_name: Hugging Face repository id of the checkpoint.

    Returns:
        ``(model, processor)`` tuple.
    """
    loaded_model = LayoutLMv3ForTokenClassification.from_pretrained(repo_name)
    loaded_processor = LayoutLMv3Processor.from_pretrained(repo_name)
    return loaded_model, loaded_processor


model, processor = _load_model_and_processor(model_name)

st.title("LayoutLMv3 Text Extraction")
st.write("Upload an image to get text predictions using the fine-tuned LayoutLMv3 model.")

# Any common raster format works: the image is converted to RGB before inference.
uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])

# Languages handed to EasyOCR: several Cyrillic-script languages plus English.
OCR_LANGUAGES = ("ru", "rs_cyrillic", "be", "bg", "uk", "mn", "en")


@st.cache_resource
def _load_ocr_reader(languages):
    """Build the EasyOCR reader once per language set and reuse it across reruns."""
    return easyocr.Reader(list(languages))


def _keep_letters(text):
    """Return ``text`` with every non-letter character removed.

    ``str.isalpha`` is Unicode-aware, so Cyrillic letters survive alongside
    Latin ones; an ASCII-only ``[a-zA-Z]`` filter would erase all of the
    Cyrillic text the reader is configured to recognise.
    """
    return "".join(ch for ch in text if ch.isalpha())


def _normalize_box(points, width, height):
    """Convert EasyOCR corner points into a LayoutLM ``[x0, y0, x1, y1]`` box.

    EasyOCR reports corners in image pixels, while LayoutLMv3 expects integer
    coordinates scaled to the 0-1000 range relative to the page size.
    Using min/max over all corners also covers slanted detections.
    # assumes ``points`` is a sequence of (x, y) pairs (4 corners) — confirm
    against the installed EasyOCR version.
    """
    xs = [float(point[0]) for point in points]
    ys = [float(point[1]) for point in points]

    def scale(value, size):
        # Clamp so rounding or slightly out-of-frame detections stay in range.
        return max(0, min(1000, int(1000 * value / size)))

    return [
        scale(min(xs), width),
        scale(min(ys), height),
        scale(max(xs), width),
        scale(max(ys), height),
    ]


def _ocr_words_and_boxes(reader, image_bytes, width, height):
    """Run OCR and return aligned lists of letter-only words and normalised boxes.

    Detections that contain no letters at all are dropped together with their
    box, so ``words[i]`` always corresponds to ``boxes[i]``.
    """
    words, boxes = [], []
    for points, text, _confidence in reader.readtext(image_bytes, detail=1):
        letters = _keep_letters(text)
        if letters:
            words.append(letters)
            boxes.append(_normalize_box(points, width, height))
    return words, boxes


def _predict_word_labels(words, boxes, image):
    """Classify each OCR word with the LayoutLMv3 model.

    Each word takes the prediction of its first sub-token. Words cut off by
    truncation (inputs longer than the model's maximum length) get ``None``.
    NOTE(review): passing ``boxes`` requires the saved processor to have
    ``apply_ocr`` disabled — confirm for this checkpoint.
    """
    encoding = processor(image, words, boxes=boxes, truncation=True, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**encoding)

    predicted_ids = outputs.logits.argmax(-1).squeeze(0).tolist()

    word_labels = [None] * len(words)
    for token_index, word_index in enumerate(encoding.word_ids(0)):
        # word_index is None for special tokens; only the first sub-token of a word is used.
        if word_index is not None and word_labels[word_index] is None:
            word_labels[word_index] = model.config.id2label[predicted_ids[token_index]]
    return word_labels


def _pair_languages_with_levels(words, word_labels):
    """Map each word predicted as ``'language'`` to the label of the word after it.

    # presumably the following word is the language's proficiency level and
    # its label encodes that level — verify against the model's label set.
    """
    pairs = {}
    for index, label in enumerate(word_labels):
        if label != "language" or index + 1 >= len(word_labels):
            continue
        next_label = word_labels[index + 1]
        if next_label is not None:
            pairs[words[index]] = next_label
    return pairs


if uploaded_file is not None:
    # RGB conversion avoids failures on palette / RGBA PNGs in the image processor.
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption='Uploaded Image.', use_column_width=True)
    st.write("")
    st.write("Classifying...")

    width, height = image.size
    reader = _load_ocr_reader(OCR_LANGUAGES)
    # Feed EasyOCR the same uploaded bytes so its pixel coordinates match ``image``.
    words, boxes = _ocr_words_and_boxes(reader, uploaded_file.getvalue(), width, height)

    if not words:
        st.write("No text was detected in the uploaded image.")
    else:
        word_labels = _predict_word_labels(words, boxes, image)
        languages_with_levels = _pair_languages_with_levels(words, word_labels)

        st.write("Predicted labels:")
        # Print languages with their levels
        for language, level in languages_with_levels.items():
            st.write(f"{language}: {level}")