import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image
import torch
import torch.nn.functional as F
import pytesseract
import plotly.express as px
from torch.utils.data import Dataset, DataLoader, Subset
import os
import io
import fitz
from typing import List
import json
import sys
from pathlib import Path
from transformers import LayoutLMv3FeatureExtractor, LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForSequenceClassification
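# Runtime configuration: compute device, the base tokenizer checkpoint, and the
# LayoutLMv3 document-classification checkpoint loaded from the Hugging Face Hub.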
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TOKENIZER = "microsoft/layoutlmv3-base"
MODEL_NAME = "fsommers/layoutlmv3-autofinance-classification-us-v01"
TESS_OPTIONS = "--psm 3" # Automatic page segmentation for Tesseract
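# Build the OCR reader once and cache it; it returns a closure that OCRs a page
# with Tesseract and rescales each word's bounding box to LayoutLMv3's 0-1000 grid.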
@st.cache_resource
def create_ocr_reader():
    def scale_bounding_box(box: List[int], w_scale: float = 1.0, h_scale: float = 1.0):
        return [
            int(box[0] * w_scale),
            int(box[1] * h_scale),
            int(box[2] * w_scale),
            int(box[3] * h_scale)
        ]

    def ocr_page(image) -> dict:
        """
        OCR a given image. Return a dictionary with the recognized words and
        one bounding box per word.
        """
        ocr_df = pytesseract.image_to_data(image, output_type='data.frame', config=TESS_OPTIONS)
        ocr_df = ocr_df.dropna().reset_index(drop=True)
        float_cols = ocr_df.select_dtypes('float').columns
        ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
        # Drop rows whose recognized text is empty or whitespace only.
        ocr_df = ocr_df.replace(r'^\s*$', np.nan, regex=True)
        ocr_df = ocr_df.dropna().reset_index(drop=True)
        words = [str(w) for w in ocr_df.text]
        coordinates = ocr_df[['left', 'top', 'width', 'height']]
        boxes = []
        for _, row in coordinates.iterrows():
            x, y, w, h = tuple(row)
            # Convert from (left, top, width, height) to (x1, y1, x2, y2).
            boxes.append([x, y, x + w, y + h])
        assert len(words) == len(boxes)
        return {"bbox": boxes, "words": words}

    def prepare_image(image):
        # OCR the page, then normalize each bounding box to LayoutLMv3's
        # 0-1000 coordinate space.
        ocr_data = ocr_page(image)
        width, height = image.size
        width_scale = 1000 / width
        height_scale = 1000 / height
        words = []
        boxes = []
        for w, b in zip(ocr_data["words"], ocr_data["bbox"]):
            words.append(w)
            boxes.append(scale_bounding_box(b, width_scale, height_scale))
        assert len(words) == len(boxes)
        for box in boxes:
            for coord in box:
                if coord > 1000:
                    raise ValueError(f"Bounding box coordinate out of range: {box}")
        return words, boxes

    return prepare_image
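# The model and processor are cached as Streamlit resources so they are loaded
# only once rather than on every rerun.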
@st.cache_resource
def create_model():
    model = LayoutLMv3ForSequenceClassification.from_pretrained(MODEL_NAME)
    return model.eval().to(DEVICE)
@st.cache_resource
def create_processor():
    feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
    tokenizer = LayoutLMv3TokenizerFast.from_pretrained(TOKENIZER)
    return LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
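# Classify a single page image: OCR the words and boxes, encode them together with
# the pixel values, run a forward pass, and return the predicted class id along
# with the per-class softmax probabilities.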
def predict(image, reader, processor: LayoutLMv3Processor, model: LayoutLMv3ForSequenceClassification):
    words, boxes = reader(image)
    encoding = processor(
        image,
        words,
        boxes=boxes,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )
    with torch.inference_mode():
        output = model(
            input_ids=encoding["input_ids"].to(DEVICE),
            attention_mask=encoding["attention_mask"].to(DEVICE),
            bbox=encoding["bbox"].to(DEVICE),
            pixel_values=encoding["pixel_values"].to(DEVICE)
        )
    logits = output.logits
    predicted_class = logits.argmax()
    probabilities = F.softmax(logits, dim=-1).flatten().tolist()
    return predicted_class.detach().item(), probabilities
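# Streamlit app: build the cached resources, accept an uploaded document image,
# show the predicted label, and chart the per-class probabilities.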
reader = create_ocr_reader()
processor = create_processor()
model = create_model()
uploaded_file = st.file_uploader("Choose an image file (JPG or PNG)", ["jpg", "png"])
if uploaded_file is not None:
    bytes_data = io.BytesIO(uploaded_file.read())
    # Ensure a 3-channel RGB image so the image processor gets the expected input.
    image = Image.open(bytes_data).convert("RGB")
    st.image(image, caption="Uploaded Image", use_column_width=True)
    predicted, probabilities = predict(image, reader, processor, model)
    predicted_label = model.config.id2label[predicted]
    st.markdown(f"Predicted Label: {predicted_label}")
    df = pd.DataFrame({
        "Label": list(model.config.id2label.values()),
        "Probability": probabilities
    })
    fig = px.bar(df, x="Label", y="Probability")
    st.plotly_chart(fig, use_container_width=True)