File size: 6,092 Bytes
cb4c41f
 
 
 
 
f50ef7b
dea123e
1592bb3
93fe32c
 
 
cb4c41f
789ac0d
 
cb4c41f
 
 
 
 
6b45627
5c58f13
6b45627
f50ef7b
21ab7aa
 
 
1592bb3
 
9485a97
1592bb3
4fa7e16
 
 
dea123e
cb4c41f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7221a3
cb4c41f
f4ccbdc
93fee22
 
e45a37b
 
93fee22
e45a37b
df954f0
7384818
 
 
da509d5
7384818
 
 
f4ccbdc
 
93fee22
 
 
b7221a3
8732834
 
b7221a3
bc05740
 
b7221a3
bc05740
93fee22
bc05740
b7221a3
 
 
8732834
b7221a3
cb4c41f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import json
import logging
import os
import re
from io import BytesIO

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel


# Local directory holding the fine-tuned Donut checkpoint (both the processor
# files and the model weights are loaded from it).
_CHECKPOINT_DIR = "./donut-base-finetuned-inv"

processor = DonutProcessor.from_pretrained(_CHECKPOINT_DIR)
model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT_DIR)

# Prefer the GPU when torch can see one; otherwise run inference on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)

logger = logging.getLogger(__name__)

# Caption attached to every document forwarded to Telegram.
_TELEGRAM_CAPTION = "New doc is being tried out:"


def _notify_telegram(image):
    """Best-effort: forward an uploaded document to a Telegram chat as a JPEG photo.

    The bot token and target chat are read from the ``TELEGRAM_BOT_TOKEN`` and
    ``TELEGRAM_CHANNEL_ID`` environment variables; when either is unset the
    notification is skipped. Network/HTTP failures are logged and swallowed so
    they never prevent the extraction itself from running.

    NOTE(review): this sends user uploads to a third-party service; confirm the
    UI discloses that.

    Args:
        image: the uploaded image as a numpy array (converted to PIL here).
    """
    token = os.getenv('TELEGRAM_BOT_TOKEN')
    chat_id = os.getenv('TELEGRAM_CHANNEL_ID')
    if not token or not chat_id:
        return

    # Can't save the upload locally; convert nparray -> PIL and encode in memory.
    # JPEG supports neither alpha nor palettes, so normalize to RGB first.
    pil_image = Image.fromarray(image).convert('RGB')
    with BytesIO() as bio:
        bio.name = 'image.jpeg'
        pil_image.save(bio, 'JPEG')
        bio.seek(0)
        try:
            # sendPhoto takes chat_id/caption as plain form fields (the JSON
            # "media" object is only understood by sendMediaGroup).
            response = requests.post(
                f'https://api.telegram.org/bot{token}/sendPhoto',
                data={'chat_id': chat_id, 'caption': _TELEGRAM_CAPTION},
                files={'photo': bio},
                timeout=10,
            )
            response.raise_for_status()
        except requests.RequestException as exc:
            # Log only the exception type: requests' messages embed the request
            # URL, which contains the bot token.
            logger.warning("Telegram notification failed (%s)", type(exc).__name__)


def process_document(image):
    """Extract invoice header fields from an uploaded document image.

    Args:
        image: the value of the ``gr.Image`` input, a numpy array, or ``None``
            when the button is clicked without an upload.

    Returns:
        A ``(fields, image)`` tuple: the decoded model output parsed by
        ``processor.token2json`` and the input image echoed back for display.
        ``(None, None)`` when no image was supplied.
    """
    if image is None:
        return None, None

    _notify_telegram(image)

    # prepare encoder inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # prepare decoder inputs: generation is seeded with this task start token
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

    # generate answer (greedy decoding; the unknown token is banned from the output)
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # postprocess: drop eos/pad tokens, then the leading task start token
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token

    return processor.token2json(sequence), image

# HTML fragments rendered at the top of the page.
title = '<h1 style="text-align:center"><img alt="" src="circling_small.gif" />Welcome<img alt="" src="circling2_small.gif" /></h1>'

paragraph1 = (
    '<p>Basic idea of this 🍩 model is to give it an image as input and extract indexes as text. '
    'No bounding boxes or confidences are generated.<br /> '
    'For more info, see the <a href="https://arxiv.org/abs/2111.15664">original paper</a>&nbsp;'
    'and the 🤗&nbsp;<a href="https://huggingface.co/naver-clova-ix/donut-base">model</a>.</p>'
)

# Index (field) names included in the training set; rendered as a bullet list.
_TRAINED_INDEXES = (
    "DocType",
    "Currency",
    "DocumentDate",
    "GrossAmount",
    "InvoiceNumber",
    "NetAmount",
    "TaxAmount",
    "OrderNumber",
    "CreditorCountry",
)

paragraph2 = (
    '<p><strong>Training</strong>:<br />'
    'The model was trained with a few thousand annotated invoices and non-invoices '
    '(for those the doctype will be &#39;Other&#39;). They span across different countries and languages. '
    'They are always one page only. The dataset is proprietary unfortunately.&nbsp;'
    'Model is set to input resolution of 1280x1920 pixels. '
    'So any sample you want to try with higher dpi than 150 has no added value.<br />'
    'It was trained for about 4 hours on a&nbsp;NVIDIA RTX A4000 for 20k steps '
    'with a val_metric of&nbsp;0.03413819904382196 at the end.<br />'
    'The <u>following indexes</u> were included in the train set:</p>'
    '<ul>'
    + ''.join(
        f'<li><span style="font-family:Calibri"><span style="color:black">{name}</span></span></li>'
        for name in _TRAINED_INDEXES
    )
    + '</ul>'
)

# The instructions name the button by its actual label ("Extract").
paragraph3 = (
    '<p><strong>Try it out:</strong><br />'
    'To use it, simply upload your image and click &#39;Extract&#39;, '
    'or click one of the examples to load them.<br />'
    '<em>(because this is running on the free cpu tier, it will take about 40 secs before you see a result)</em></p>'
    '<p>&nbsp;</p><p>Have fun&nbsp;😎</p><p>Toon Beerten</p>'
)

# Page-level CSS: lets the element with elem_id="inp" size its height to its content.
css = "#inp {height: auto !important; width: 100% !important;}"


# Page layout: intro HTML, upload + examples, the extract button, then results.
with gr.Blocks(css=css) as demo:

    gr.HTML(title)
    gr.HTML(paragraph1)
    gr.HTML(paragraph2)
    gr.HTML(paragraph3)

    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Image(label='Upload invoice here:')
        with gr.Column(scale=2):
            gr.Examples(
                [["example.jpg"], ["example_2.jpg"], ["example_3.jpg"]],
                inputs=[inp],
                label='Or use one of these examples:',
            )
    with gr.Row():
        btn = gr.Button("↓   Extract   ↓")
    # Results row: the echoed document next to the extracted fields.
    with gr.Row():
        with gr.Column(scale=2):
            # elem_id "inp" is the selector used by the page-level `css` string.
            imgout = gr.Image(label='Uploaded document:', elem_id="inp")
        with gr.Column(scale=1):
            jsonout = gr.JSON(label='Extracted information:')

    btn.click(fn=process_document, inputs=inp, outputs=[jsonout, imgout])

demo.launch()