Spaces:
Sleeping
Sleeping
import spaces | |
import transformers | |
import re | |
import torch | |
import gradio as gr | |
import os | |
import ctranslate2 | |
import difflib | |
import shutil | |
import requests | |
from concurrent.futures import ThreadPoolExecutor | |
# Define the device: prefer CUDA whenever torch can see a GPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the CTranslate2-converted model from the local "ocronos_ct2" directory
# and the matching tokenizer from the PleIAs/OCRonos-Vintage checkpoint.
model_path = "ocronos_ct2"
generator = ctranslate2.Generator(model_path, device=device)
tokenizer = transformers.AutoTokenizer.from_pretrained("PleIAs/OCRonos-Vintage")
# Stylesheet prepended to every rendered result by TextProcessor.process.
# Within this file only the `.generation` and `h2` rules match generated
# markup; the remaining rules appear unused here (generate_html_diff uses an
# inline style rather than the `.inserted` class).
css = """
<style>
.generation {
margin-left: 2em;
margin-right: 2em;
font-size: 1.2em;
}
:target {
background-color: #CCF3DF;
}
.source {
float: left;
max-width: 17%;
margin-left: 2%;
}
.tooltip {
position: relative;
cursor: pointer;
font-variant-position: super;
color: #97999b;
}
.tooltip:hover::after {
content: attr(data-text);
position: absolute;
left: 0;
top: 120%;
white-space: pre-wrap;
width: 500px;
max-width: 500px;
z-index: 1;
background-color: #f9f9f9;
color: #000;
border: 1px solid #ddd;
border-radius: 5px;
padding: 5px;
display: block;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
.deleted {
background-color: #ffcccb;
text-decoration: line-through;
}
.inserted {
background-color: #90EE90;
}
.manuscript {
display: flex;
margin-bottom: 10px;
align-items: baseline;
}
.annotation {
width: 15%;
padding-right: 20px;
color: grey !important;
font-style: italic;
text-align: right;
}
.content {
width: 80%;
}
h2 {
margin: 0;
font-size: 1.5em;
}
.title-content h2 {
font-weight: bold;
}
.bibliography-content {
color: darkgreen !important;
margin-top: -5px;
}
.paratext-content {
color: #a4a4a4 !important;
margin-top: -5px;
}
</style>
"""
# Helper functions
def generate_html_diff(old_text, new_text):
    """Render *new_text* as HTML, highlighting words not present in *old_text*.

    Both texts are split on whitespace and compared word by word with
    ``difflib.Differ``. Words common to both are emitted as-is, words only in
    *new_text* are wrapped in a green-highlighted ``<span>``, and words only in
    *old_text* are dropped, so the output reads as the new text. Words are
    re-joined with single spaces; original line breaks and spacing are lost.

    Every word is HTML-escaped before being embedded, so characters such as
    ``<`` or ``&`` coming from the OCR input or the model output are displayed
    literally instead of being interpreted as markup.

    Args:
        old_text: The original (uncorrected) text.
        new_text: The corrected text.

    Returns:
        An HTML fragment string.
    """
    import html  # only needed here, for escaping word text

    pieces = []
    for entry in difflib.Differ().compare(old_text.split(), new_text.split()):
        # Differ prefixes each entry with a 2-char code: "  " unchanged,
        # "+ " only in new_text, "- " only in old_text, "? " intraline hints.
        code = entry[:2]
        if code == "  ":
            pieces.append(html.escape(entry[2:], quote=False))
        elif code == "+ ":
            word = html.escape(entry[2:], quote=False)
            pieces.append(f'<span style="background-color: #90EE90;">{word}</span>')
        # "- " and "? " entries are intentionally skipped.
    return " ".join(pieces)
def preprocess_text(text):
    """Strip HTML-like tags and collapse all whitespace to single spaces.

    Anything between ``<`` and ``>`` is removed, newlines become spaces, every
    whitespace run is reduced to one space, and the result is trimmed.
    """
    without_tags = re.sub(r'<[^>]+>', '', text)
    single_line = without_tags.replace('\n', ' ')
    collapsed = re.sub(r'\s+', ' ', single_line)
    return collapsed.strip()
def split_text(text, max_tokens=400):
    """Cut *text* into pieces of at most *max_tokens* tokenizer tokens.

    The whole text is encoded once with the module-level ``tokenizer``. The
    resulting id list is sliced into consecutive windows of ``max_tokens`` ids,
    and each window is decoded back to a string. Windows are cut purely on
    token positions, so a boundary can fall in the middle of a word.

    Args:
        text: Input text.
        max_tokens: Maximum number of token ids per piece.

    Returns:
        A list of decoded strings in input order (empty if the encoding is empty).
    """
    token_ids = tokenizer.encode(text)
    window_starts = range(0, len(token_ids), max_tokens)
    return [tokenizer.decode(token_ids[start:start + max_tokens]) for start in window_starts]
# Function to generate text using CTranslate2
def ocr_correction(prompt, max_new_tokens=600):
    """Correct OCR errors in *prompt* with the CTranslate2 OCRonos model.

    The text is cut into pieces of at most 400 tokens with ``split_text``. Each
    piece is wrapped in the "### Text ### ... ### Correction ###" prompt
    template, and all pieces are generated in a single
    ``generator.generate_batch`` call. The decoded corrections are joined with
    single spaces, in input order.

    Args:
        prompt: Raw (possibly noisy) input text.
        max_new_tokens: Forwarded to ``generate_batch`` as ``max_length``.

    Returns:
        The corrected text. Returns ``""`` without calling the model when
        *prompt* is empty, ``None`` or whitespace-only.
    """
    # Nothing to correct: don't run the model on an empty/blank prompt
    # (it would either get an empty batch or hallucinate from a blank template).
    if not prompt or not prompt.strip():
        return ""

    batch = []
    for chunk in split_text(prompt, max_tokens=400):
        full_prompt = f"### Text ###\n{chunk}\n\n\n### Correction ###\n"
        print(full_prompt)
        # generate_batch is fed token strings, so convert the encoded ids back.
        batch.append(tokenizer.convert_ids_to_tokens(tokenizer.encode(full_prompt)))
    if not batch:
        return ""

    results = generator.generate_batch(
        batch,
        max_length=max_new_tokens,
        sampling_temperature=0,
        sampling_topk=20,
        repetition_penalty=1.1,
        include_prompt_in_result=False,
    )
    # Keep only the first hypothesis of each result.
    return " ".join(tokenizer.decode(result.sequences_ids[0]) for result in results)
# OCR Correction Class
class OCRCorrector:
    """Run OCR correction on a message and build an HTML diff of the changes."""

    def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
        # Stored on the instance; ``correct`` does not read it.
        self.system_prompt = system_prompt

    def correct(self, user_message):
        """Return ``(corrected_text, html_diff)`` for *user_message*.

        ``corrected_text`` comes from ``ocr_correction``. ``html_diff`` is
        ``generate_html_diff`` applied to the original and corrected texts.
        """
        corrected = ocr_correction(user_message)
        return corrected, generate_html_diff(user_message, corrected)
# Combined Processing Class
class TextProcessor:
    """Build the HTML fragment displayed in the Gradio output component."""

    def __init__(self):
        self.ocr_corrector = OCRCorrector()

    def process(self, user_message):
        """Correct *user_message* and return the styled HTML result.

        The returned string is the module-level ``css`` stylesheet, followed by
        an "OCR Correction" heading and the word-level diff wrapped in a
        ``generation`` div.
        """
        _, diff_markup = self.ocr_corrector.correct(user_message)
        section = (
            '<h2 style="text-align:center">OCR Correction</h2>\n'
            f'<div class="generation">{diff_markup}</div>'
        )
        return f"{css}{section}"
# Create the TextProcessor instance shared by every click handled by the UI.
text_processor = TextProcessor()
# Define the Gradio interface.
# Gradio lays components out in creation order, so the statement order below
# defines the page layout.
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.HTML("""<h1 style="text-align:center">Vintage OCR corrector</h1>""")
    text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
    process_button = gr.Button("Process Text")
    text_output = gr.HTML(label="Processed text")
    # On click, the textbox value is passed to TextProcessor.process and the
    # returned HTML string is shown in text_output.
    process_button.click(text_processor.process, inputs=text_input, outputs=[text_output])
if __name__ == "__main__":
    # NOTE(review): queue() presumably lets slow model calls wait in a request
    # queue instead of timing out — confirm for the Gradio version in use.
    demo.queue().launch()