import os
from concurrent.futures import ThreadPoolExecutor
from difflib import Differ

import torch
import gradio as gr
from transformers import GPT2LMHeadModel, GPT2Tokenizer

description = """# 🙋🏻‍♂️Welcome to Tonic's On-Device📲⌚🎅🏻OCR Corrector (CPU)

📲⌚🎅🏻OCRonos-Vintage is a small specialized model for OCR correction of cultural heritage archives, pre-trained with llm.c. OCRonos-Vintage has only 124 million parameters. It runs easily on CPU or provides correction at scale on GPUs (>10k tokens/second), with correction quality comparable to GPT-4 or the Llama version of OCRonos for English-language cultural archives.

### Join us:

🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP) On 🤗Huggingface: [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to 🌟 [Build Tonic](https://git.tonic-ai.com/contribute). 🤗 Big thanks to Yuvi Sharma and all the folks at Hugging Face for the community grant 🤗
"""

# Load the 124M-parameter OCR-correction model once at startup.
model_name = "PleIAs/OCRonos-Vintage"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)


def diff_texts(text1, text2):
    """Word-level diff between the raw and corrected text.

    Returns (token, label) pairs for gr.HighlightedText: "+" for insertions,
    "-" for deletions, None for unchanged tokens.
    """
    d = Differ()
    return [
        (token[2:], token[0] if token[0] != " " else None)
        for token in d.compare(text1.split(), text2.split())
    ]


def split_text(text, max_tokens=400):
    """Split text into chunks of at most max_tokens model tokens,
    so each chunk fits within the model's context window."""
    tokens = tokenizer.tokenize(text)
    chunks = []
    current_chunk = []
    for token in tokens:
        current_chunk.append(token)
        if len(current_chunk) >= max_tokens:
            chunks.append(tokenizer.convert_tokens_to_string(current_chunk))
            current_chunk = []
    if current_chunk:
        chunks.append(tokenizer.convert_tokens_to_string(current_chunk))
    return chunks


def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
    """Correct one chunk of OCR output using the model's prompt format."""
    prompt = f"### Text ###\n{prompt}\n\n\n### Correction ###\n"
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # CPU parallelism comes from torch's intra-op threading; the executor
    # merely keeps the single blocking generate() call off the calling thread.
    torch.set_num_threads(num_threads)
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        future = executor.submit(
            model.generate,
            input_ids,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            top_k=50,  # no effect with do_sample=False (greedy decoding)
            num_return_sequences=1,
            do_sample=False,  # greedy decoding: deterministic corrections
        )
        output = future.result()

    result = tokenizer.decode(output[0], skip_special_tokens=True)
    # Everything after the marker is the model's correction.
    return result.split("### Correction ###")[1].strip()


def process_text(user_message):
    """Chunk, correct, and diff the user's text for the UI."""
    chunks = split_text(user_message)
    corrected_chunks = [ocr_correction(chunk) for chunk in chunks]
    corrected_text = " ".join(corrected_chunks)
    return diff_texts(user_message, corrected_text)


with gr.Blocks(theme=gr.themes.Base()) as demo:
    gr.Markdown(description)
    text_input = gr.Textbox(
        label="↘️Enter 👁️OCR'ed Text Outputs Here",
        # Intentionally garbled sample text demonstrating typical OCR noise.
        info="""Hi there, ;fémy name à`gis tonic 45and i like to ride my vpotz""",
        lines=5,
    )
    process_button = gr.Button("Correct using 📲⌚🎅🏻OCRonos")
    text_output = gr.HighlightedText(
        label="📲⌚🎅🏻OCRonos Correction:",
        combine_adjacent=True,
        show_legend=True,
        color_map={"+": "green", "-": "red"},
    )
    process_button.click(process_text, inputs=text_input, outputs=[text_output])

if __name__ == "__main__":
    demo.queue().launch()
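
# --- Hypothetical usage sketch (not part of the original demo) ---
# For batch correction without the Gradio UI, split_text() and ocr_correction()
# can be called directly; the file path and contents below are made-up examples,
# not assets shipped with this app.
#
#   with open("scans/page_001.txt") as f:   # hypothetical OCR output file
#       noisy = f.read()
#   corrected = " ".join(ocr_correction(chunk) for chunk in split_text(noisy))
#   print(corrected)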