File size: 3,886 Bytes
0dfb412
f2019a4
e98a756
f2019a4
 
9626102
61dc098
e98a756
31559f1
58f99ee
9626102
 
 
 
 
 
e98a756
dfbcb2e
e98a756
ffbf266
 
9626102
 
 
 
 
 
0dfb412
e98a756
 
ffbf266
e98a756
ffbf266
e98a756
 
 
 
 
ffbf266
 
e98a756
ffbf266
 
 
c4873ef
ffbf266
 
 
61dc098
 
 
 
 
 
ffbf266
 
 
 
61dc098
ffbf266
61dc098
 
ffbf266
e98a756
 
 
 
 
 
 
 
 
 
 
9626102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e98a756
7468778
0dfb412
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import transformers
import re
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import gradio as gr
from difflib import Differ
from concurrent.futures import ThreadPoolExecutor
import os

# Markdown text for the header of the Gradio app (passed to the Markdown
# component when the UI is built). The continuation lines carry a four-space
# indent inside the string; NOTE(review): this relies on the Markdown component
# dedenting its value, otherwise those lines would render as an indented code block.
description = """# 🙋🏻‍♂️Welcome to Tonic's On-Device📲⌚🎅🏻OCR Corrector (CPU)
    📲⌚🎅🏻OCRonos-Vintage is a small specialized model for OCR correction of cultural heritage archives pre-trained with llm.c. OCRonos-Vintage is only 124 million parameters. It can run easily on CPU or provide correction at scale on GPUs (>10k tokens/seconds) while providing a quality of correction comparable to GPT-4 or the llama version of OCRonos for English-speaking cultural archives.

    ### Join us : 
    🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP) On 🤗Huggingface:[MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [Build Tonic](https://git.tonic-ai.com/contribute)🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
    """

# Hugging Face Hub checkpoint used for both the model weights and the tokenizer.
model_name = "PleIAs/OCRonos-Vintage"

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the GPT-2 style causal LM, then place its weights on the chosen device.
model = GPT2LMHeadModel.from_pretrained(model_name)
model = model.to(device)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

def diff_texts(text1, text2):
    """Return a word-level diff of *text1* against *text2*.

    Both texts are split on whitespace and compared with ``difflib.Differ``.
    The result is shaped for ``gr.HighlightedText``.

    Args:
        text1: The original text.
        text2: The text to compare against (e.g. the corrected version).

    Returns:
        A list of ``(word, label)`` tuples in order, where ``label`` is
        ``"-"`` for a word only in *text1*, ``"+"`` for a word only in
        *text2*, and ``None`` for a word present in both.
    """
    d = Differ()
    return [
        (token[2:], token[0] if token[0] != " " else None)
        for token in d.compare(text1.split(), text2.split())
        # When two compared words are similar, Differ also emits "? " guide
        # lines marking intraline changes; they are not words of either text
        # and would show up as garbage entries in the highlighted output.
        if not token.startswith("?")
    ]

def split_text(text, max_tokens=400):
    """Split *text* into consecutive chunks of at most *max_tokens* tokens.

    Tokenization and the conversion of each chunk back to a string use the
    module-level ``tokenizer``. When a chunk is full, the cut is moved back to
    the latest token in the chunk that begins with whitespace. This keeps a
    word from being split across two chunks; corrected chunks are later
    re-joined with a space, which would otherwise land inside the word. If the
    chunk contains no such token, it is cut at exactly *max_tokens*.

    Args:
        text: The text to split.
        max_tokens: Maximum number of tokens per chunk; values below 1 are
            treated as 1.

    Returns:
        A list of chunk strings; empty when *text* produces no tokens.
    """
    tokens = tokenizer.tokenize(text)
    limit = max(1, max_tokens)  # guard: a non-positive step would never advance
    chunks = []
    start = 0

    while start < len(tokens):
        end = min(start + limit, len(tokens))
        if end < len(tokens):
            # tokens[end] is the first token of the next chunk. Walk back to the
            # latest token in (start, end] that starts with whitespace and cut
            # right before it, so no word straddles the boundary.
            for cut in range(end, start, -1):
                if tokenizer.convert_tokens_to_string([tokens[cut]])[:1].isspace():
                    end = cut
                    break
        chunks.append(tokenizer.convert_tokens_to_string(tokens[start:end]))
        start = end

    return chunks

def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
    """Run OCRonos on one piece of OCR text and return the model's correction.

    Args:
        prompt: Raw OCR text to correct.
        max_new_tokens: Upper bound on the number of generated tokens.
        num_threads: Thread count passed to ``torch.set_num_threads``. The
            default is ``os.cpu_count()`` evaluated once at import time. A
            falsy value (e.g. ``None`` when the CPU count is unknown) leaves
            torch's current setting unchanged instead of raising.

    Returns:
        The generated correction text with surrounding whitespace stripped.
    """
    prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
    # Calling the tokenizer (rather than .encode) also returns an attention
    # mask. generate() cannot infer one here because pad_token_id == eos id.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    if num_threads:
        # NOTE: process-wide setting; it affects all subsequent torch CPU ops.
        torch.set_num_threads(num_threads)

    # Greedy decoding (do_sample=False), run in the calling thread; wrapping a
    # single blocking call in a thread pool added no parallelism.
    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=1,
        do_sample=False,
    )

    # Decode only the newly generated tokens, so OCR text that itself contains
    # the "### Correction ###" header cannot confuse the parsing.
    prompt_length = inputs["input_ids"].shape[1]
    result = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    # Stop at a repeated correction header, if the model emits one.
    return result.split("### Correction ###")[0].strip()

def process_text(user_message):
    """Correct *user_message* piece by piece and diff it against the result.

    The message is split with ``split_text``, and each piece is corrected with
    ``ocr_correction``. The corrections are joined with single spaces.

    Returns:
        The word-level diff ``diff_texts(user_message, corrected_text)``, in
        the ``(word, label)`` form consumed by ``gr.HighlightedText``.
    """
    pieces = split_text(user_message)
    corrected_text = " ".join(ocr_correction(piece) for piece in pieces)
    return diff_texts(user_message, corrected_text)

with gr.Blocks(theme=gr.themes.Base()) as demo:
    # Page header. Gradio's component is `gr.Markdown`; `gr.MarkDown` is not a
    # Gradio attribute and raised AttributeError while the UI was being built.
    gr.Markdown(description)
    # Input box for the raw OCR output. The `info` helper text appears to be a
    # sample of noisy OCR input.
    text_input = gr.Textbox(
        label="↘️Enter 👁️OCR'ed Text Outputs Here",
        info="""Hi there, ;fémy name à`gis tonic 45and i like to ride my vpotz""",
        lines=5,
    )
    process_button = gr.Button("Correct using 📲⌚🎅🏻OCRonos")
    # Displays the word-level diff produced by process_text: "+" marks words
    # introduced by the correction, "-" marks words it removed.
    text_output = gr.HighlightedText(
        label="📲⌚🎅🏻OCRonos Correction:",
        combine_adjacent=True,
        show_legend=True,
        color_map={"+": "green", "-": "red"}
    )
    # Clicking the button runs process_text on the textbox contents.
    process_button.click(process_text, inputs=text_input, outputs=[text_output])

if __name__ == "__main__":
    # Enable Gradio's request queue (useful for slow model calls), then serve the app.
    queued_app = demo.queue()
    queued_app.launch()