# Gradio demo app: Donut model for invoice header extraction.
import json
import logging
import os
import re
from io import BytesIO

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel
# Run on the GPU when torch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The processor and the model weights are both read from the same local
# checkpoint directory.
processor = DonutProcessor.from_pretrained("./donut-base-finetuned-inv")
model = VisionEncoderDecoderModel.from_pretrained("./donut-base-finetuned-inv")

# Move the weights to the selected device once, at import time.
model.to(device)
def _notify_telegram(pil_image):
    """Post *pil_image* to the configured Telegram chat, best effort.

    Reads the bot token and chat id from the ``TELEGRAM_BOT_TOKEN`` and
    ``TELEGRAM_CHANNEL_ID`` environment variables. Does nothing when either
    is unset, and never raises on a failed request: the notification must
    not be able to break document extraction.
    """
    token = os.getenv('TELEGRAM_BOT_TOKEN')
    chat_id = os.getenv('TELEGRAM_CHANNEL_ID')
    if not token or not chat_id:
        # Without credentials the request could only fail; skip it.
        return

    # The upload cannot be stored locally, so encode it into memory.
    buffer = BytesIO()
    buffer.name = 'image.jpeg'
    pil_image.save(buffer, 'JPEG')
    buffer.seek(0)

    try:
        response = requests.post(
            f'https://api.telegram.org/bot{token}/sendPhoto',
            # sendPhoto takes the caption as a plain form field.
            data={'chat_id': chat_id, 'caption': 'New doc is being tried out:'},
            files={'photo': buffer},
            timeout=10,  # never let a slow Telegram call hang the request
        )
        response.raise_for_status()
    except requests.RequestException as exc:
        # Log only the exception type: its message can contain the request
        # URL, which embeds the bot token.
        logging.getLogger(__name__).warning(
            "Telegram notification failed (%s)", type(exc).__name__
        )


def process_document(image):
    """Extract the invoice fields from an uploaded document image.

    Args:
        image: The uploaded page as a numpy array (it is converted with
            ``Image.fromarray``), or ``None`` when nothing was uploaded.

    Returns:
        A ``(fields, image)`` tuple: the result of ``processor.token2json``
        on the generated sequence, and the input image unchanged so the UI
        can display it. ``(None, None)`` when *image* is ``None``.
    """
    if image is None:
        # "Extract" was pressed without an upload; there is nothing to do.
        return None, None

    # Send a notification through Telegram (needs a PIL image, not an array).
    _notify_telegram(Image.fromarray(image))

    # Prepare encoder inputs.
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # Prepare decoder inputs: generation is primed with the task start token.
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    # Generate the answer.
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # Postprocess: strip EOS/padding tokens, then the first task start token.
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "")
    sequence = sequence.replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    return processor.token2json(sequence), image
# Static HTML fragments rendered above the demo widgets. Each long literal is
# split into adjacent string pieces (joined by the parser) so that the pieces
# containing apostrophes can be double-quoted.
title = (
    '<h1 style="text-align:center">'
    '<img alt="" src="circling_small.gif" />'
    'Welcome'
    '<img alt="" src="circling2_small.gif" />'
    '</h1>'
)

paragraph1 = (
    '<p>Basic idea of this 🍩 model is to give it an image as input and extract indexes as text. '
    'No bounding boxes or confidences are generated.<br /> '
    'For more info, see the <a href="https://arxiv.org/abs/2111.15664">original paper</a> '
    'and the 🤗 <a href="https://huggingface.co/naver-clova-ix/donut-base">model</a>.</p>'
)

paragraph2 = (
    '<p><strong>Training</strong>:<br />'
    "The model was trained with a few thousand of annotated invoices and non-invoices (for those the doctype will be 'Other'). "
    'They span across different countries and languages. '
    'They are always one page only. '
    'The dataset is proprietary unfortunately. '
    'Model is set to input resolution of 1280x1920 pixels. '
    'So any sample you want to try with higher dpi than 150 has no added value.<br />'
    'It was trained for about 4 hours on a NVIDIA RTX A4000 for 20k steps with a val_metric of 0.03413819904382196 at the end.<br />'
    'The <u>following indexes</u> were included in the train set:</p>'
    '<ul>'
    '<li><span style="font-family:Calibri"><span style="color:black">DocType</span></span></li>'
    '<li><span style="font-family:Calibri"><span style="color:black">Currency</span></span></li>'
    '<li><span style="font-family:Calibri"><span style="color:black">DocumentDate</span></span></li>'
    '<li><span style="font-family:Calibri"><span style="color:black">GrossAmount</span></span></li>'
    '<li><span style="font-family:Calibri"><span style="color:black">InvoiceNumber</span></span></li>'
    '<li><span style="font-family:Calibri"><span style="color:black">NetAmount</span></span></li>'
    '<li><span style="font-family:Calibri"><span style="color:black">TaxAmount</span></span></li>'
    '<li><span style="font-family:Calibri"><span style="color:black">OrderNumber</span></span></li>'
    '<li><span style="font-family:Calibri"><span style="color:black">CreditorCountry</span></span></li>'
    '</ul>'
)

paragraph3 = (
    '<p><strong>Try it out:</strong><br />'
    "To use it, simply upload your image and click 'submit', or click one of the examples to load them.<br />"
    '<em>(because this is running on the free cpu tier, it will take about 40 secs before you see a result)</em></p>'
    '<p> </p>'
    '<p>Have fun 😎</p>'
    '<p>Toon Beerten</p>'
)

# Page-level CSS: lets the element with id "inp" size itself to its content.
css = "#inp {height: auto !important; width: 100% !important;}"
# Alternatives that were tried:
# css = "@media screen and (max-width: 600px) { .output_image, .input_image {height:20rem !important; width: 100% !important;} }"
# css = ".output_image, .input_image {height: 600px !important}"
# css = ".image-preview {height: auto !important;}"
# Page layout: intro text, then an upload/examples row, the trigger button,
# and a results row showing the document next to the extracted JSON.
with gr.Blocks(css=css) as demo:
    gr.HTML(title)
    gr.HTML(paragraph1)
    gr.HTML(paragraph2)
    gr.HTML(paragraph3)

    # Input row: upload widget on the left, clickable samples on the right.
    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Image(label='Upload invoice here:')
        with gr.Column(scale=2):
            gr.Examples(
                [["example.jpg"], ["example_2.jpg"], ["example_3.jpg"]],
                inputs=[inp],
                label='Or use one of these examples:',
            )

    with gr.Row():
        btn = gr.Button("↓ Extract ↓")

    # NOTE: gr.Row has no `css` parameter; row styling (for example a
    # `background-image: url("background.gif")` rule) belongs in the
    # Blocks-level `css` string, targeted through an `elem_id`.
    with gr.Row():
        with gr.Column(scale=2):
            # elem_id "inp" is the hook used by the page-level CSS rule.
            imgout = gr.Image(label='Uploaded document:', elem_id="inp")
        with gr.Column(scale=1):
            jsonout = gr.JSON(label='Extracted information:')

    # process_document returns (json, image), matching the output order here.
    btn.click(fn=process_document, inputs=inp, outputs=[jsonout, imgout])

demo.launch()