File size: 3,702 Bytes
bc12901
 
 
 
1af0b6d
bc12901
 
 
 
bc6a638
 
bc12901
 
 
 
 
 
 
bc6a638
8171e8e
 
 
 
 
 
bc6a638
8171e8e
1af0b6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc12901
bc6a638
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2919076
bc6a638
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc12901
 
bc6a638
 
 
1af0b6d
 
bc12901
 
1af0b6d
 
8171e8e
bc12901
1af0b6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc12901
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

from PIL import ImageDraw
import streamlit as st

import torch
from docquery.pipeline import get_pipeline
from docquery.document import load_bytes, load_document


def ensure_list(x):
    """Wrap *x* in a one-element list unless it is already a list.

    A list argument is returned as-is (the same object, not a copy). Anything
    else, including tuples and ``None``, becomes ``[x]``.
    """
    return x if isinstance(x, list) else [x]


@st.experimental_singleton
def construct_pipeline():
    """Build the DocQuery pipeline, placed on CUDA when torch reports a GPU.

    Decorated with ``st.experimental_singleton`` so the constructed object is
    reused rather than rebuilt on every Streamlit rerun.
    """
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"
    return get_pipeline(device=target_device)


@st.cache
def run_pipeline(question, document, top_k):
    """Ask *question* of *document* using the shared pipeline.

    The document's ``context`` mapping is expanded into keyword arguments, and
    *top_k* is passed straight through to the pipeline. Callers normalise the
    result with ``ensure_list``, so it may be a single prediction or a list.

    NOTE(review): ``st.cache`` hashes every argument, including ``document``;
    confirm that hashing a docquery document is cheap and supported.
    """
    pipeline = construct_pipeline()
    return pipeline(question=question, **document.context, top_k=top_k)


# TODO: Move into docquery
# TODO: Support words past the first page (or window?)
def lift_word_boxes(document):
    """Return the word boxes stored with the first entry of ``context["image"]``.

    Only the entry at index 0 (the first page, per the TODO above) is used.
    Its second element holds the word boxes. Presumably these are
    ``(word, [x0, y0, x1, y1])`` pairs, as ``expand_bbox`` reads ``entry[1]``
    as the box; confirm against docquery.
    """
    image_entries = document.context["image"]
    first_entry = image_entries[0]
    return first_entry[1]


def expand_bbox(word_boxes):
    """Return the smallest ``[x0, y0, x1, y1]`` box enclosing every word box.

    Each entry of *word_boxes* carries its box at index 1. The union keeps the
    minimum of the first two coordinates and the maximum of the last two.
    Returns ``None`` when *word_boxes* is empty.
    """
    if len(word_boxes) == 0:
        return None

    boxes = [entry[1] for entry in word_boxes]
    left = min(b[0] for b in boxes)
    top = min(b[1] for b in boxes)
    right = max(b[2] for b in boxes)
    bottom = max(b[3] for b in boxes)
    return [left, top, right, bottom]


# LayoutLM boxes are normalized to 0, 1000
def normalize_bbox(box, width, height):
    """Scale a 0-1000 normalised ``[x0, y0, x1, y1]`` box to pixel coordinates.

    X coordinates (indices 0 and 2) are scaled by *width*, and y coordinates
    (indices 1 and 3) by *height*. Each coordinate is divided by 1000 before
    multiplying.
    """
    fractions = [coord / 1000 for coord in box]
    left, top, right, bottom = fractions[0], fractions[1], fractions[2], fractions[3]
    return [left * width, top * height, right * width, bottom * height]


st.markdown("# DocQuery: Query Documents w/ NLP")

# Streamlit session state survives reruns. "document" holds the currently
# loaded document, or None until a load callback (load_file_cb / load_url_cb)
# stores one.
if "document" not in st.session_state:
    st.session_state["document"] = None

# Decides which input widget (file uploader or URL box) is rendered below.
input_type = st.radio("Pick an input type", ["Upload", "URL"], horizontal=True)


def load_file_cb():
    """``on_change`` callback for the file uploader: load and process the upload.

    Returns without changing anything when the uploader is empty (``None``).
    Otherwise the bytes are parsed with ``load_bytes`` and the result is stored
    in ``st.session_state.document``. All of this happens inside
    ``loading_placeholder`` under a "Processing..." spinner.
    """
    uploaded = st.session_state.file_input
    if uploaded is not None:
        with loading_placeholder, st.spinner("Processing..."):
            document = load_bytes(uploaded, uploaded.name)
            # Touch ``context`` so its (presumably lazy) computation runs while
            # the spinner is visible, not later during querying.
            _ = document.context
            st.session_state.document = document


def load_url_cb():
    """``on_change`` callback for the URL box: download and process the document.

    Does nothing when the URL field is empty or contains only whitespace, for
    example after the user clears it. Otherwise the URL (surrounding whitespace
    removed) is fetched with ``load_document``, and its ``context`` is computed.
    The result is stored in ``st.session_state.document``.
    """
    # ``st.text_input`` yields a str ("" when cleared), never None, so test for
    # emptiness. An ``is None`` check would let load_document("") run.
    url = (st.session_state.url_input or "").strip()
    if not url:
        return

    with loading_placeholder:
        with st.spinner("Downloading..."):
            document = load_document(url)
        with st.spinner("Processing..."):
            # Touch ``context`` so its (presumably lazy) computation runs under
            # this spinner rather than during querying.
            _ = document.context
        st.session_state.document = document


if input_type == "Upload":
    file = st.file_uploader(
        "Upload a PDF or Image document", key="file_input", on_change=load_file_cb
    )

elif input_type == "URL":
    url = st.text_input("URL", "", key="url_input", on_change=load_url_cb)

question = st.text_input("QUESTION", "")

document = st.session_state.document
# The load callbacks render their spinners into this module-level placeholder.
loading_placeholder = st.empty()
if document is not None:
    col1, col2 = st.columns([3, 1])
    image = document.preview


# Outline colours for successive predictions; cycled if there are more
# predictions than colours.
colors = ["blue", "red", "green"]
if document is not None and question:
    col2.header("Answers")

    predictions = run_pipeline(question=question, document=document, top_k=1)

    word_boxes = lift_word_boxes(document)
    # Draw on a copy so the document's own preview image is left untouched.
    image = image.copy()
    draw = ImageDraw.Draw(image)
    for i, p in enumerate(ensure_list(predictions)):
        col2.markdown(f"#### { p['answer'] }: ({round(p['score'] * 100, 1)}%)")
        # p["end"] is treated as inclusive, hence the +1 on the slice.
        bbox = expand_bbox(word_boxes[p["start"] : p["end"] + 1])
        if bbox is None:
            # No word boxes cover this span (presumably it lies past the first
            # page, which is all lift_word_boxes returns), so nothing to outline.
            continue
        x1, y1, x2, y2 = normalize_bbox(bbox, image.width, image.height)
        draw.rectangle(((x1, y1), (x2, y2)), outline=colors[i % len(colors)])

if document is not None:
    col1.image(image, use_column_width=True)

# NOTE(review): this help text mentions a 'submit' button and clickable
# examples, but this script renders neither; consider updating the copy.
"DocQuery uses LayoutLMv1 fine-tuned on DocVQA, a document visual question answering dataset, as well as SQuAD, which boosts its English-language comprehension. To use it, simply upload an image or PDF, type a question, and click 'submit', or click one of the examples to load them."

"[Github Repo](https://github.com/impira/docquery)"