Spaces:
Sleeping
Sleeping
import transformers | |
import re | |
from transformers import GPT2LMHeadModel, GPT2Tokenizer | |
import torch | |
import gradio as gr | |
from difflib import Differ | |
from concurrent.futures import ThreadPoolExecutor | |
import os | |
# Markdown source for the page header; it is handed to the Markdown component
# when the Gradio UI is built further down in this file.
description = """# 🙋🏻♂️Welcome to Tonic's On-Device📲⌚🎅🏻OCR Corrector (CPU)
📲⌚🎅🏻OCRonos-Vintage is a small specialized model for OCR correction of cultural heritage archives pre-trained with llm.c. OCRonos-Vintage is only 124 million parameters. It can run easily on CPU or provide correction at scale on GPUs (>10k tokens/seconds) while providing a quality of correction comparable to GPT-4 or the llama version of OCRonos for English-speaking cultural archives.
### Join us :
🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP) On 🤗Huggingface:[MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [Build Tonic](https://git.tonic-ai.com/contribute)🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
"""
# Checkpoint identifier shared by the model and tokenizer loaders below.
model_name = "PleIAs/OCRonos-Vintage"

# Run on the GPU when torch reports one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the weights and move them to the selected device, then load the
# matching tokenizer from the same checkpoint.
model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
def diff_texts(text1, text2):
    """Return a word-level diff of two texts as ``(word, tag)`` pairs.

    Both texts are split on whitespace and the resulting word lists are
    compared with ``difflib.Differ``.

    Args:
        text1: The original text.
        text2: The revised text.

    Returns:
        A list of ``(word, tag)`` tuples in diff order, where ``tag`` is
        ``"-"`` for a word only in ``text1``, ``"+"`` for a word only in
        ``text2``, and ``None`` for a word present in both.
    """
    pairs = []
    for entry in Differ().compare(text1.split(), text2.split()):
        marker, word = entry[0], entry[2:]
        # For similar-but-unequal words Differ also emits "? " guide lines
        # (carets marking the differing characters). They are not words from
        # either text, so they must not reach the highlighted output.
        if marker == "?":
            continue
        pairs.append((word, None if marker == " " else marker))
    return pairs
def split_text(text, max_tokens=400):
    """Cut ``text`` into pieces of at most ``max_tokens`` tokenizer tokens.

    The text is tokenized with the module-level ``tokenizer``; tokens are
    gathered until the limit is reached and each full group is converted back
    to a string. A final, shorter group is kept as the last piece.

    Args:
        text: The text to split.
        max_tokens: Number of tokens at which a piece is closed.

    Returns:
        A list of strings, one per group of tokens, in input order.
    """
    pieces = []
    pending = []
    for tok in tokenizer.tokenize(text):
        pending.append(tok)
        if len(pending) < max_tokens:
            continue
        # The group is full: emit it and start collecting the next one.
        pieces.append(tokenizer.convert_tokens_to_string(pending))
        pending = []
    # Whatever is left over forms the last, shorter piece.
    if pending:
        pieces.append(tokenizer.convert_tokens_to_string(pending))
    return pieces
def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
    """Generate the model's correction for one piece of OCR text.

    The text is wrapped in the ``### Text ### / ### Correction ###`` template,
    encoded with the module-level ``tokenizer`` and completed by the
    module-level ``model`` with sampling disabled.

    Args:
        prompt: The raw OCR text to correct.
        max_new_tokens: Upper bound on the number of generated tokens.
        num_threads: Value passed to ``torch.set_num_threads``; skipped when
            it is ``None`` or 0 (``os.cpu_count()`` may return ``None``).

    Returns:
        The generated correction, stripped of surrounding whitespace and cut
        at any further ``### Correction ###`` marker the model emits.
    """
    prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Passing None to torch.set_num_threads raises, so only set it when a
    # usable count is available.
    if num_threads:
        torch.set_num_threads(num_threads)

    # Called directly: submitting a single job to a thread pool and blocking
    # on its result adds no parallelism.
    output = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        top_k=50,  # sampling option; kept from the original call
        num_return_sequences=1,
        do_sample=False,
    )

    # Decode only the tokens produced after the prompt. Splitting the full
    # decoded text on the marker returns the wrong span when the user's own
    # text already contains "### Correction ###".
    new_tokens = output[0][input_ids.shape[-1]:]
    completion = tokenizer.decode(new_tokens, skip_special_tokens=True)
    # Keep the previous truncation: stop at a marker the model generates.
    return completion.partition("### Correction ###")[0].strip()
def process_text(user_message):
    """Correct ``user_message`` piece by piece and diff it with the input.

    The message is cut up by ``split_text``, each piece is corrected by
    ``ocr_correction``, the corrected pieces are joined with single spaces,
    and the result is compared with the original message by ``diff_texts``.

    Args:
        user_message: The OCR text entered by the user.

    Returns:
        The ``(word, tag)`` list produced by ``diff_texts``.
    """
    fixed_parts = [ocr_correction(part) for part in split_text(user_message)]
    return diff_texts(user_message, ' '.join(fixed_parts))
# Build the UI: header text, an input box, a button, and a highlighted diff.
with gr.Blocks(theme=gr.themes.Base()) as demo:
    # The component class is spelled ``Markdown``; ``gr.MarkDown`` does not
    # exist and raised AttributeError while the UI was being built.
    gr.Markdown(description)
    text_input = gr.Textbox(
        label="↘️Enter 👁️OCR'ed Text Outputs Here",
        info="""Hi there, ;fémy name à`gis tonic 45and i like to ride my vpotz""",
        lines=5,
    )
    process_button = gr.Button("Correct using 📲⌚🎅🏻OCRonos")
    # The color_map keys match the "+" / "-" tags returned by diff_texts.
    text_output = gr.HighlightedText(
        label="📲⌚🎅🏻OCRonos Correction:",
        combine_adjacent=True,
        show_legend=True,
        color_map={"+": "green", "-": "red"}
    )
    # Clicking the button runs the input through process_text and shows the
    # returned pairs in the highlighted output.
    process_button.click(process_text, inputs=text_input, outputs=[text_output])

# Start the app only when the file is executed as a script.
if __name__ == "__main__":
    demo.queue().launch()