# `spaces` is kept as the first import, as in the original: Hugging Face ZeroGPU
# expects it to be imported before torch/CUDA are initialised.
import spaces
import transformers
import re
from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModel,
    AutoModelForCausalLM,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    pipeline,
)
# NOTE: vllm is imported but not used by the transformers-based pipeline below.
from vllm import LLM, SamplingParams
import torch
import gradio as gr
import json
import os
import shutil
import requests
import pandas as pd
import difflib
import html

# OCR correction model (GPT-2 architecture, loaded with the GPT-2 classes).
ocr_model_name = "PleIAs/OCRonos-Vintage"
model_name = ocr_model_name

# Use the GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained model and tokenizer once, at import time.
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model.to(device)

# Prompt format expected by the OCR-correction model.
PROMPT_TEMPLATE = "### Text ###\n{text}\n\n\n### Correction ###\n"

# CSS injected in front of the result HTML. The original stylesheet content was
# lost; this minimal rule makes the words inserted by the model visible.
css = """
<style>
.inserted { background-color: #d4f8d4; font-weight: bold; }
</style>
"""


# Helper functions
def generate_html_diff(old_text, new_text):
    """Render *new_text* as HTML, highlighting words not present in *old_text*.

    The diff is word-level (`difflib.Differ` over whitespace-split words).
    Unchanged words are emitted as-is. Inserted words are wrapped in
    ``<span class="inserted">``. Deleted words and Differ's ``"? "`` hint
    lines are dropped. Every word is HTML-escaped, because both inputs come
    from user input or model output.
    """
    html_diff = []
    for entry in difflib.Differ().compare(old_text.split(), new_text.split()):
        tag, word = entry[:2], entry[2:]
        if tag == "  ":
            html_diff.append(html.escape(word))
        elif tag == "+ ":
            html_diff.append(f'<span class="inserted">{html.escape(word)}</span>')
    return " ".join(html_diff)


def preprocess_text(text):
    """Remove HTML-like tags and collapse all whitespace (newlines included) to single spaces."""
    text = re.sub(r"<[^>]+>", "", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()


def _split_oversized(chunk, max_tokens):
    """Recursively split *chunk* at whitespace until every piece fits in *max_tokens* tokens."""
    if len(tokenizer.tokenize(chunk)) <= max_tokens:
        return [chunk]
    stripped = chunk.strip()
    if len(stripped) <= 1:
        # Empty or a single character: nothing left to split.
        return [stripped] if stripped else []
    if stripped != chunk:
        return _split_oversized(stripped, max_tokens)
    mid = len(chunk) // 2
    # Prefer the first whitespace at/after the midpoint, then the last one before
    # it, so words are not cut; fall back to a hard cut at the midpoint.
    split_at = next((i for i in range(mid, len(chunk)) if chunk[i].isspace()), None)
    if split_at is None:
        split_at = next((i for i in range(mid - 1, 0, -1) if chunk[i].isspace()), mid)
    left, right = chunk[:split_at], chunk[split_at:]
    if not left.strip() or not right.strip():
        left, right = chunk[:mid], chunk[mid:]
    return _split_oversized(left, max_tokens) + _split_oversized(right, max_tokens)


def split_text(text, max_tokens=500):
    """Split *text* into chunks of at most *max_tokens* tokenizer tokens.

    Lines are packed greedily, joined with ``"\\n"``. Any chunk that is still
    too long (e.g. one very long line) is further split at whitespace, so no
    returned chunk exceeds the budget. The only exception is a single
    character that alone exceeds it.
    """
    chunks = []
    current_chunk = ""
    for part in text.split("\n"):
        candidate = f"{current_chunk}\n{part}" if current_chunk else part
        if len(tokenizer.tokenize(candidate)) <= max_tokens:
            current_chunk = candidate
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = part
    if current_chunk:
        chunks.append(current_chunk)
    return [piece for chunk in chunks for piece in _split_oversized(chunk, max_tokens)]


# Function to generate text
def ocr_correction(prompt, max_new_tokens=600):
    """Run the OCR-correction model on raw text and return only the generated correction.

    Args:
        prompt: The text to correct. It must NOT already be wrapped in the
            prompt template; it is wrapped here exactly once.
        max_new_tokens: Upper bound on generated tokens. It is additionally
            capped so that prompt + generation fit in the model's context
            window.

    Returns:
        The newly generated text (prompt tokens excluded), stripped.
    """
    formatted = PROMPT_TEMPLATE.format(text=prompt)
    inputs = tokenizer(formatted, return_tensors="pt").to(device)
    prompt_length = inputs["input_ids"].shape[-1]
    # GPT-2 has a fixed context window; don't request more tokens than fit after the prompt.
    context_size = getattr(model.config, "n_positions", 1024)
    max_new_tokens = max(1, min(max_new_tokens, context_size - prompt_length))
    # Greedy decoding, as before: the previous `top_k=50` had no effect without
    # `do_sample=True`, so it is omitted.
    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the new tokens instead of splitting the full text on the marker.
    result = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    print(result)
    return result.strip()


# OCR Correction Class
class OCRCorrector:
    """Correct OCR errors in free text using the module-level OCRonos model."""

    def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
        # Kept for backward compatibility; the current prompt format does not use it.
        self.system_prompt = system_prompt

    def correct(self, user_message, max_tokens=500):
        """Return ``(corrected_text, html_diff)`` for *user_message*.

        Long inputs are split with `split_text` into chunks of at most
        *max_tokens* tokens, so each model call fits in the context window.
        The chunk corrections are joined with newlines. Whitespace-only
        chunks are skipped.
        """
        corrected_chunks = [
            ocr_correction(chunk)
            for chunk in split_text(user_message, max_tokens)
            if chunk.strip()
        ]
        generated_text = "\n".join(corrected_chunks)
        html_diff = generate_html_diff(user_message, generated_text)
        return generated_text, html_diff


# Combined Processing Class
class TextProcessor:
    """Gradio-facing wrapper: runs OCR correction and renders the result as HTML."""

    def __init__(self):
        self.ocr_corrector = OCRCorrector()

    @spaces.GPU(duration=120)
    def process(self, user_message):
        """Correct *user_message* and return an HTML fragment (CSS + highlighted diff)."""
        _corrected_text, html_diff = self.ocr_corrector.correct(user_message)
        # The original markup of this fragment was lost; minimal reconstruction.
        ocr_result = (
            '<h2 style="text-align:center">OCR Correction</h2>\n'
            f'<div class="generation">{html_diff}</div>'
        )
        return f"{css}{ocr_result}"


# Create the TextProcessor instance
text_processor = TextProcessor()

# Define the Gradio interface
with gr.Blocks(theme="JohnSmith9982/small_and_pretty") as demo:
    gr.HTML("""\n\nVintage OCR corrector\n\n""")
    text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
    process_button = gr.Button("Process Text")
    text_output = gr.HTML(label="Processed text")
    process_button.click(text_processor.process, inputs=text_input, outputs=[text_output])

if __name__ == "__main__":
    demo.queue().launch()