File size: 4,658 Bytes
2493114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d185e95
2493114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b25b478
440913e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5fe63a3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import numpy as np
import pandas as pd
import streamlit as st

from PIL import Image

import torch
import torch.nn.functional as F
import pytesseract

import plotly.express as px

from torch.utils.data import Dataset, DataLoader,  Subset
import os
import io
import pytesseract
import fitz
from typing import List
import json

import sys
from pathlib import Path

from transformers import LayoutLMv3FeatureExtractor, LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForSequenceClassification

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Hugging Face Hub ids: base tokenizer and the fine-tuned classifier checkpoint.
TOKENIZER = "microsoft/layoutlmv3-base"
MODEL_NAME = "fsommers/layoutlmv3-autofinance-classification-us-v01"

# Tesseract flags: --psm 3 selects automatic page segmentation.
TESS_OPTIONS = "--psm 3"

@st.cache_resource
def create_ocr_reader():
    """Build (and cache via ``st.cache_resource``) the OCR front-end for LayoutLMv3.

    Returns:
        A callable ``prepare_image(image) -> (words, boxes)`` that runs
        Tesseract on a PIL image and returns the recognised words plus a
        parallel list of ``[x0, y0, x1, y1]`` integer boxes rescaled so the
        image spans 0-1000 on each axis.
    """

    def scale_bounding_box(box: List[int], w_scale: float = 1.0, h_scale: float = 1.0) -> List[int]:
        """Scale an ``[x0, y0, x1, y1]`` box by per-axis factors, truncating to int."""
        return [
            int(box[0] * w_scale),
            int(box[1] * h_scale),
            int(box[2] * w_scale),
            int(box[3] * h_scale),
        ]

    def ocr_page(image) -> dict:
        """
        OCR a given image. Return a dictionary of words and the bounding boxes
        for each word. For each word, there is a corresponding bounding box.

        Returns:
            ``{"bbox": [[left, top, right, bottom], ...], "words": [str, ...]}``
            in pixel coordinates; both lists have the same length.
        """
        ocr_df = pytesseract.image_to_data(image, output_type='data.frame', config=TESS_OPTIONS)
        # Keep only fully populated rows (this drops entries with no text).
        ocr_df = ocr_df.dropna().reset_index(drop=True)
        # Round any float columns back to integers.
        # NOTE(review): if pandas parsed "text" as float (an all-numeric page),
        # this rounds the words too — confirm that is acceptable.
        float_cols = ocr_df.select_dtypes('float').columns
        ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
        # Whitespace-only cells count as missing; drop those rows as well.
        ocr_df = ocr_df.replace(r'^\s*$', np.nan, regex=True)
        ocr_df = ocr_df.dropna().reset_index(drop=True)

        words = [str(w) for w in ocr_df.text]
        # Tesseract reports (left, top, width, height); convert to corner form.
        boxes = [
            [x, y, x + w, y + h]
            for x, y, w, h in zip(ocr_df['left'], ocr_df['top'], ocr_df['width'], ocr_df['height'])
        ]

        assert len(words) == len(boxes)
        return {"bbox": boxes, "words": words}

    def prepare_image(image):
        """OCR *image* and return ``(words, boxes)`` with boxes on the 0-1000 grid.

        Raises:
            ValueError: if any scaled box coordinate exceeds 1000.
        """
        ocr_data = ocr_page(image)
        width, height = image.size
        width_scale = 1000 / width
        height_scale = 1000 / height
        words = list(ocr_data["words"])
        boxes = [scale_bounding_box(b, width_scale, height_scale) for b in ocr_data["bbox"]]

        assert len(words) == len(boxes)
        for box in boxes:
            if any(coord > 1000 for coord in box):
                # Raise a real exception: a bare `raise` outside an `except`
                # only yields "No active exception to reraise".
                raise ValueError(
                    f"Scaled bounding box {box} exceeds 1000; "
                    "expected coordinates in the 0-1000 range"
                )
        return words, boxes

    return prepare_image

@st.cache_resource
def create_model():
    """Load the fine-tuned LayoutLMv3 classifier in eval mode on ``DEVICE``.

    Decorated with ``st.cache_resource`` so the weights are loaded once and
    reused across Streamlit reruns instead of on every interaction.
    """
    classifier = LayoutLMv3ForSequenceClassification.from_pretrained(MODEL_NAME)
    classifier.eval()  # inference only: switch off training-mode layers
    return classifier.to(DEVICE)

@st.cache_resource
def create_processor():
    """Build the cached LayoutLMv3 processor (image features + fast tokenizer).

    The feature extractor's built-in OCR is turned off because words and
    boxes are produced separately by the Tesseract reader and passed in.
    """
    return LayoutLMv3Processor(
        feature_extractor=LayoutLMv3FeatureExtractor(apply_ocr=False),
        tokenizer=LayoutLMv3TokenizerFast.from_pretrained(TOKENIZER),
    )

def _to_rgb(image):
    """Return *image* as a 3-channel RGB PIL image.

    Images with an alpha band are composited onto a white background so text
    drawn on transparency stays visible; other non-RGB modes (grayscale,
    palette, ...) are converted directly. RGB images are returned unchanged.
    """
    if "A" in image.getbands():
        background = Image.new("RGB", image.size, (255, 255, 255))
        background.paste(image, (0, 0), image.getchannel("A"))
        return background
    if image.mode != "RGB":
        return image.convert("RGB")
    return image


def predict(image, reader, processor: LayoutLMv3Processor, model: LayoutLMv3ForSequenceClassification):
    """Classify a single document image.

    Args:
        image: PIL image of the page, in any mode. It is converted to RGB
            first because the LayoutLMv3 image pipeline expects three channels
            (RGBA/palette/grayscale PNGs would otherwise fail).
        reader: callable returning ``(words, boxes)`` for an image, such as the
            one built by ``create_ocr_reader``.
        processor: LayoutLMv3 processor that tokenises words/boxes and
            prepares pixel values.
        model: sequence-classification model to run on ``DEVICE``.

    Returns:
        ``(predicted_class, probabilities)``: the argmax class id, and the
        softmax scores as a list indexed by class id.
    """
    image = _to_rgb(image)
    words, boxes = reader(image)
    encoding = processor(
        image,
        words,
        boxes=boxes,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    with torch.inference_mode():
        output = model(
            input_ids=encoding["input_ids"].to(DEVICE),
            attention_mask=encoding["attention_mask"].to(DEVICE),
            bbox=encoding["bbox"].to(DEVICE),
            pixel_values=encoding["pixel_values"].to(DEVICE),
        )
        logits = output.logits
        # One image -> a batch of one, so the flattened argmax is the class id.
        predicted_class = logits.argmax()
        probabilities = F.softmax(logits, dim=-1).flatten().tolist()
        return predicted_class.item(), probabilities

reader = create_ocr_reader()
processor = create_processor()
model = create_model()

uploaded_file = st.file_uploader("Choose a JPG or PNG file", ["jpg", "jpeg", "png"])
if uploaded_file is not None:
    bytes_data = io.BytesIO(uploaded_file.read())
    image = Image.open(bytes_data)
    st.image(image, caption="Uploaded Image", use_column_width=True)
    predicted, probabilities = predict(image, reader, processor, model)
    id2label = model.config.id2label
    predicted_label = id2label[predicted]
    st.markdown(f"Predicted Label: {predicted_label}")

    # probabilities[i] is the score for class id i, so look labels up by id.
    # Iterating id2label.values() follows the dict's insertion order, which
    # mirrors how the config stored its keys and is not guaranteed to be
    # 0, 1, 2, ... (string-sorted keys, for example, put "10" before "2").
    df = pd.DataFrame({
        "Label": [id2label[i] for i in range(len(probabilities))],
        "Probability": probabilities,
    })
    fig = px.bar(df, x="Label", y="Probability")
    st.plotly_chart(fig, use_container_width=True)