Spaces:
Running
Running
import numpy as np | |
import pandas as pd | |
import streamlit as st | |
from PIL import Image | |
import torch | |
import torch.nn.functional as F | |
import pytesseract | |
import plotly.express as px | |
from torch.utils.data import Dataset, DataLoader, Subset | |
import os | |
import io | |
import pytesseract | |
import fitz | |
from typing import List | |
import json | |
import sys | |
from pathlib import Path | |
from transformers import LayoutLMv3FeatureExtractor, LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForSequenceClassification | |
# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Identifier passed to LayoutLMv3TokenizerFast.from_pretrained.
TOKENIZER = "microsoft/layoutlmv3-base"

# Identifier passed to LayoutLMv3ForSequenceClassification.from_pretrained.
MODEL_NAME = "fsommers/layoutlmv3-autofinance-classification-us-v01"

# Tesseract command-line options: "--psm 3" selects automatic page segmentation.
TESS_OPTIONS = "--psm 3"
def create_ocr_reader():
    """Build the OCR front end that turns a page image into words and boxes.

    Returns:
        A callable ``prepare_image(image)`` that runs Tesseract on a PIL image
        and returns ``(words, boxes)``: a list of word strings and a parallel
        list of ``[x0, y0, x1, y1]`` integer boxes rescaled to the 0-1000 range.
    """

    def scale_bounding_box(box: List[int], w_scale: float = 1.0, h_scale: float = 1.0):
        """Scale an ``[x0, y0, x1, y1]`` box and truncate each value to int."""
        return [
            int(box[0] * w_scale),
            int(box[1] * h_scale),
            int(box[2] * w_scale),
            int(box[3] * h_scale),
        ]

    def ocr_page(image) -> dict:
        """
        OCR a given image. Return a dictionary of words and the bounding boxes
        for each word. For each word, there is a corresponding bounding box.

        The result has two parallel lists: ``"words"`` (strings) and ``"bbox"``
        (``[left, top, left + width, top + height]`` in the coordinates
        reported by Tesseract).
        """
        ocr_df = pytesseract.image_to_data(image, output_type='data.frame', config=TESS_OPTIONS)
        # Drop rows with any missing field before casting numeric columns.
        ocr_df = ocr_df.dropna().reset_index(drop=True)
        float_cols = ocr_df.select_dtypes('float').columns
        ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
        # Treat empty / whitespace-only cells as missing and drop those rows too.
        ocr_df = ocr_df.replace(r'^\s*$', np.nan, regex=True)
        ocr_df = ocr_df.dropna().reset_index(drop=True)

        words = [str(w) for w in ocr_df.text]
        coordinates = ocr_df[['left', 'top', 'width', 'height']]
        # Convert (left, top, width, height) into corner form (x0, y0, x1, y1).
        boxes = [
            [x, y, x + w, y + h]
            for x, y, w, h in coordinates.itertuples(index=False, name=None)
        ]
        assert len(words) == len(boxes)
        return {"bbox": boxes, "words": words}

    def prepare_image(image):
        """OCR ``image`` and rescale every word box to the 0-1000 range.

        Args:
            image: A PIL image; its ``size`` supplies the scaling factors.

        Returns:
            ``(words, boxes)`` as parallel lists.

        Raises:
            ValueError: If a scaled coordinate falls outside 0-1000. The
                original code used a bare ``raise`` here, which has no active
                exception and therefore produced an uninformative
                ``RuntimeError``.
        """
        ocr_data = ocr_page(image)
        width, height = image.size
        width_scale = 1000 / width
        height_scale = 1000 / height

        words = []
        boxes = []
        for word, raw_box in zip(ocr_data["words"], ocr_data["bbox"]):
            box = scale_bounding_box(raw_box, width_scale, height_scale)
            if any(coord < 0 or coord > 1000 for coord in box):
                raise ValueError(
                    f"Scaled bounding box {box} for word {word!r} falls outside "
                    f"the 0-1000 range (image size {width}x{height})"
                )
            words.append(word)
            boxes.append(box)
        return words, boxes

    return prepare_image
def create_model():
    """Load the classification checkpoint named by MODEL_NAME for inference.

    Returns:
        The model after ``eval()`` has been called and it has been moved to
        DEVICE.
    """
    classifier = LayoutLMv3ForSequenceClassification.from_pretrained(MODEL_NAME)
    classifier = classifier.eval()
    return classifier.to(DEVICE)
def create_processor():
    """Assemble the LayoutLMv3 processor from a feature extractor and tokenizer.

    The feature extractor is created with ``apply_ocr=False``; words and boxes
    are passed to the processor explicitly by ``predict``.
    """
    return LayoutLMv3Processor(
        feature_extractor=LayoutLMv3FeatureExtractor(apply_ocr=False),
        tokenizer=LayoutLMv3TokenizerFast.from_pretrained(TOKENIZER),
    )
def predict(image, reader, processor: LayoutLMv3Processor, model: LayoutLMv3ForSequenceClassification):
    """Classify one page image.

    Args:
        image: The page image; passed to both ``reader`` and ``processor``.
        reader: Callable returning ``(words, boxes)`` for ``image``.
        processor: Encodes the image, words and boxes into model tensors
            (padded / truncated to 512 tokens).
        model: Sequence-classification model whose output exposes ``logits``.

    Returns:
        ``(predicted_index, probabilities)`` where ``predicted_index`` is the
        argmax over the logits as a Python int and ``probabilities`` is the
        flattened softmax of the logits as a list of floats.
    """
    words, boxes = reader(image)
    encoding = processor(
        image,
        words,
        boxes=boxes,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # Move only the tensors the model consumes onto the target device.
    input_names = ("input_ids", "attention_mask", "bbox", "pixel_values")
    model_inputs = {name: encoding[name].to(DEVICE) for name in input_names}

    with torch.inference_mode():
        logits = model(**model_inputs).logits

    top_index = logits.argmax().detach().item()
    scores = F.softmax(logits, dim=-1).flatten().tolist()
    return top_index, scores
# Build the OCR reader, processor and model used for every prediction.
# NOTE(review): these run at module level, so they are rebuilt each time the
# script executes — consider caching them if the installed Streamlit supports it.
reader = create_ocr_reader()
processor = create_processor()
model = create_model()

# The uploader accepts both JPG and PNG, so the label names both.
uploaded_file = st.file_uploader("Choose a JPG or PNG file", ["jpg", "png"])
if uploaded_file is not None:
    bytes_data = io.BytesIO(uploaded_file.read())
    # PNG uploads can be RGBA, palette or grayscale; convert to 3-channel RGB
    # so the image processor always receives the channel layout it expects.
    image = Image.open(bytes_data).convert("RGB")
    st.image(image, caption="Uploaded Image", use_column_width=True)

    predicted, probabilities = predict(image, reader, processor, model)
    id2label = model.config.id2label
    predicted_label = id2label[predicted]
    st.markdown(f"Predicted Label: {predicted_label}")

    # Look labels up by class index so each probability is paired with its own
    # label regardless of the dict's insertion order.
    df = pd.DataFrame({
        "Label": [id2label[index] for index in range(len(probabilities))],
        "Probability": probabilities,
    })
    fig = px.bar(df, x="Label", y="Probability")
    st.plotly_chart(fig, use_container_width=True)