Tonic's picture
Update app.py
23e47c7 verified
raw
history blame contribute delete
No virus
3.89 kB
import transformers
import re
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import gradio as gr
from difflib import Differ
from concurrent.futures import ThreadPoolExecutor
import os
# Markdown header for the demo page; it is rendered via gr.Markdown(description)
# when the Blocks layout is built. The text is user-facing and emitted verbatim.
description = """# 🙋🏻‍♂️Welcome to Tonic's On-Device📲⌚🎅🏻OCR Corrector (CPU)
📲⌚🎅🏻OCRonos-Vintage is a small specialized model for OCR correction of cultural heritage archives pre-trained with llm.c. OCRonos-Vintage is only 124 million parameters. It can run easily on CPU or provide correction at scale on GPUs (>10k tokens/seconds) while providing a quality of correction comparable to GPT-4 or the llama version of OCRonos for English-speaking cultural archives.
### Join us :
🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP) On 🤗Huggingface:[MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [Build Tonic](https://git.tonic-ai.com/contribute)🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
"""
# Checkpoint identifier handed to both from_pretrained() calls below.
model_name = "PleIAs/OCRonos-Vintage"

# Run on CUDA when torch reports it as available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the model first, move it to the chosen device, then load the tokenizer.
model = GPT2LMHeadModel.from_pretrained(model_name)
model = model.to(device)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
def diff_texts(text1, text2):
    """Compute a word-level diff between two texts for highlighted display.

    Both texts are split on whitespace and the resulting word lists are
    compared with ``difflib.Differ``.

    Args:
        text1: The original text.
        text2: The revised text.

    Returns:
        A list of ``(word, label)`` tuples in diff order. ``label`` is ``"-"``
        for a word present only in ``text1``, ``"+"`` for a word present only
        in ``text2`` and ``None`` for a word common to both.
    """
    differ = Differ()
    return [
        # Each Differ line is a two-character code followed by the word.
        (line[2:], None if line[0] == " " else line[0])
        for line in differ.compare(text1.split(), text2.split())
        # Between two *similar* words Differ also emits "? " guide lines
        # (e.g. "?     ^\n") pointing at the differing characters. They are
        # not words from either text, so they must not reach the output.
        if line[0] != "?"
    ]
def split_text(text, max_tokens=400):
    """Cut ``text`` into consecutive pieces of at most ``max_tokens`` tokens.

    The text is tokenized with the module-level ``tokenizer``; each run of
    ``max_tokens`` tokens is turned back into a string with
    ``convert_tokens_to_string``. The final piece may be shorter.

    Args:
        text: The text to split.
        max_tokens: Integer number of tokens per piece.

    Returns:
        A list of strings, one per piece, in input order (empty when the
        text produces no tokens).
    """
    pieces = tokenizer.tokenize(text)
    # With a limit below 1 every token already fills a piece on its own,
    # so the effective window never drops under one token.
    window = max(max_tokens, 1)
    return [
        tokenizer.convert_tokens_to_string(pieces[start:start + window])
        for start in range(0, len(pieces), window)
    ]
def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
    """Run the OCR-correction model on one piece of text.

    The text is wrapped in the "### Text ### / ### Correction ###" template,
    encoded with the module-level ``tokenizer`` and completed by the
    module-level ``model`` using greedy decoding.

    Args:
        prompt: Raw OCR text to correct.
        max_new_tokens: Upper bound on the number of generated tokens.
        num_threads: Value passed to ``torch.set_num_threads``. ``None``
            (which ``os.cpu_count()`` can return) leaves torch's current
            thread setting untouched instead of raising.

    Returns:
        The generated correction with surrounding whitespace stripped, cut
        at the first "### Correction ###" marker the model itself emits,
        if any.
    """
    marker = "### Correction ###"
    prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    if num_threads is not None:
        torch.set_num_threads(num_threads)

    # generate() is called directly: wrapping a single blocking call in a
    # freshly created thread pool and waiting on it adds nothing.
    # Decoding is greedy (do_sample=False), so sampling-only arguments such
    # as top_k are not passed.
    output = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=1,
        do_sample=False,
    )

    # The returned sequence starts with the prompt tokens. Decode only what
    # follows them, so a marker occurring inside the user's own text cannot
    # be mistaken for the start of the correction.
    generated_ids = output[0][input_ids.shape[-1]:]
    generated = tokenizer.decode(generated_ids, skip_special_tokens=True)
    return generated.partition(marker)[0].strip()
def process_text(user_message):
    """Correct ``user_message`` piece by piece and diff it against the input.

    The message is split with ``split_text``, each piece is passed through
    ``ocr_correction``, and the corrected pieces are joined with single
    spaces.

    Args:
        user_message: The OCR text entered by the user.

    Returns:
        The result of ``diff_texts(user_message, corrected_text)``.
    """
    fixed_parts = [ocr_correction(piece) for piece in split_text(user_message)]
    return diff_texts(user_message, " ".join(fixed_parts))
with gr.Blocks(theme=gr.themes.Base()) as demo:
    # Page header.
    gr.Markdown(description)

    # Multi-line box where the user pastes the OCR output.
    text_input = gr.Textbox(
        lines=5,
        label="↘️Enter 👁️OCR'ed Text Outputs Here",
        info="""Hi there, ;fémy name à`gis tonic 45and i like to ride my vpotz""",
    )

    process_button = gr.Button("Correct using 📲⌚🎅🏻OCRonos")

    # Diff view: "+" entries are shown in green, "-" entries in red.
    highlight_colors = {"+": "green", "-": "red"}
    text_output = gr.HighlightedText(
        label="📲⌚🎅🏻OCRonos Correction:",
        color_map=highlight_colors,
        show_legend=True,
        combine_adjacent=True,
    )

    # Clicking the button runs process_text on the textbox content and
    # sends its return value to the highlighted-text component.
    process_button.click(
        process_text,
        inputs=text_input,
        outputs=[text_output],
    )
if __name__ == "__main__":
    # Script entry point: enable the request queue, then start the app.
    queued_demo = demo.queue()
    queued_demo.launch()