# Hugging Face file-viewer header (scraped): author Pclanglais,
# commit e103855 "Update app.py" (verified), file size 6.3 kB.
import spaces
import transformers
import re
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM, pipeline
from vllm import LLM, SamplingParams
import torch
import gradio as gr
import json
import os
import shutil
import requests
import pandas as pd
import difflib
# Run on the GPU when torch can see one, otherwise on the CPU. This single
# `torch.device` is used for the model below and for the inputs built in
# ocr_correction().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# OCR Correction Model
ocr_model_name = "PleIAs/OCRonos-Vintage"
from transformers import GPT2LMHeadModel, GPT2Tokenizer  # noqa: E402 (kept next to its only use)

# Load the pre-trained model and tokenizer with the GPT-2 classes. `model` and
# `tokenizer` are module-level globals read by split_text() and ocr_correction().
model_name = ocr_model_name
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model.to(device)
# CSS for formatting the HTML returned to the Gradio HTML output component.
# TextProcessor.process prepends this entire <style> block to every result,
# and `.generation` styles the <div> that wraps the diff there.
# NOTE(review): several classes below (e.g. `.deleted`, `.inserted`,
# `.manuscript`, `.tooltip`) are not referenced by any HTML built in this
# file (generate_html_diff uses an inline style) -- possibly leftovers from
# other views; confirm before removing.
css = """
<style>
.generation {
margin-left: 2em;
margin-right: 2em;
font-size: 1.2em;
}
:target {
background-color: #CCF3DF;
}
.source {
float: left;
max-width: 17%;
margin-left: 2%;
}
.tooltip {
position: relative;
cursor: pointer;
font-variant-position: super;
color: #97999b;
}
.tooltip:hover::after {
content: attr(data-text);
position: absolute;
left: 0;
top: 120%;
white-space: pre-wrap;
width: 500px;
max-width: 500px;
z-index: 1;
background-color: #f9f9f9;
color: #000;
border: 1px solid #ddd;
border-radius: 5px;
padding: 5px;
display: block;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
.deleted {
background-color: #ffcccb;
text-decoration: line-through;
}
.inserted {
background-color: #90EE90;
}
.manuscript {
display: flex;
margin-bottom: 10px;
align-items: baseline;
}
.annotation {
width: 15%;
padding-right: 20px;
color: grey !important;
font-style: italic;
text-align: right;
}
.content {
width: 80%;
}
h2 {
margin: 0;
font-size: 1.5em;
}
.title-content h2 {
font-weight: bold;
}
.bibliography-content {
color: darkgreen !important;
margin-top: -5px;
}
.paratext-content {
color: #a4a4a4 !important;
margin-top: -5px;
}
</style>
"""
# Helper functions
def generate_html_diff(old_text, new_text):
    """Render *new_text* as HTML, highlighting words absent from *old_text*.

    The comparison is word-level: both texts are split on whitespace and fed
    to ``difflib.Differ``.
    - Words common to both texts are emitted as-is.
    - Words only in *new_text* are wrapped in a light-green ``<span>``.
    - Words only in *old_text* are dropped, so the output reads as the new
      text with its changes marked.

    Every word is HTML-escaped. The result is inserted into an HTML component
    and both texts come from user input or model output, so raw ``<`` or
    ``&`` would otherwise be interpreted as markup.

    Returns:
        A single string of space-separated (possibly wrapped) words.
    """
    import html  # local import keeps this helper self-contained

    differ = difflib.Differ()
    html_diff = []
    for entry in differ.compare(old_text.split(), new_text.split()):
        # Differ prefixes each entry with a 2-character code: '  ' (common),
        # '+ ' (only in new), '- ' (only in old) or '? ' (intraline hint).
        code, word = entry[:2], html.escape(entry[2:])
        if code == '  ':
            html_diff.append(word)
        elif code == '+ ':
            html_diff.append(f'<span style="background-color: #90EE90;">{word}</span>')
        # '- ' (deleted) and '? ' (hint) entries are intentionally skipped.
    return ' '.join(html_diff)
def preprocess_text(text):
    """Strip tag-like markup and normalise whitespace.

    Anything of the form ``<...>`` is removed, line breaks become spaces,
    runs of whitespace collapse to one space, and the result is trimmed.
    """
    substitutions = (
        (r'<[^>]+>', ''),  # drop anything that looks like a tag
        (r'\n', ' '),      # flatten line breaks
        (r'\s+', ' '),     # squeeze whitespace runs
    )
    cleaned = text
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
def split_text(text, max_tokens=500, tokenize=None):
    """Split *text* into chunks of at most *max_tokens* tokens.

    Lines are packed greedily into chunks, re-joined with newlines. Any chunk
    still over the limit (i.e. a single line longer than *max_tokens*) is
    then split recursively. It is cut at the whitespace nearest its middle,
    or at its middle character when it contains no whitespace, until every
    piece fits or is a single character.

    Args:
        text: The text to split.
        max_tokens: Maximum number of tokens allowed per chunk.
        tokenize: Callable mapping a string to a list of tokens. Defaults to
            ``tokenizer.tokenize`` (the module-level model tokenizer).

    Returns:
        A list of chunk strings; empty when *text* is empty.
    """
    if tokenize is None:
        tokenize = tokenizer.tokenize

    def count(s):
        return len(tokenize(s))

    chunks = []
    current_chunk = ""
    for part in text.split("\n"):
        candidate = f"{current_chunk}\n{part}" if current_chunk else part
        if count(candidate) <= max_tokens:
            current_chunk = candidate
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = part
    if current_chunk:
        chunks.append(current_chunk)

    # Any chunk may be oversized (not only when there is a single chunk),
    # so every chunk goes through the recursive splitter.
    result = []
    for chunk in chunks:
        result.extend(_split_oversized(chunk, max_tokens, count))
    return result


def _split_oversized(text, max_tokens, count):
    """Recursively split *text* until each piece has at most *max_tokens* tokens."""
    # Single characters cannot be split further; this also guarantees
    # termination for degenerate limits such as max_tokens < 1.
    if count(text) <= max_tokens or len(text) < 2:
        return [text]
    split_point = _find_split_point(text)
    pieces = []
    # Both halves are strictly shorter than *text*, so the recursion terminates.
    for half in (text[:split_point].strip(), text[split_point:].strip()):
        if half:
            pieces.extend(_split_oversized(half, max_tokens, count))
    return pieces


def _find_split_point(text):
    """Return the whitespace index nearest the middle of *text*, else the middle."""
    middle = len(text) // 2
    positions = [match.start() for match in re.finditer(r'\s', text)]
    if not positions:
        return middle
    return min(positions, key=lambda pos: abs(pos - middle))
# Function to generate text
def ocr_correction(prompt, max_new_tokens=600):
    """Correct OCR errors in *prompt* with the module-level OCRonos model.

    The raw text is wrapped in the "### Text ### ... ### Correction ###"
    template and the model continues it. Only the newly generated tokens are
    decoded, and generation that runs into another "### Correction ###"
    marker is truncated there.

    Args:
        prompt: Raw (possibly noisy) text to correct -- NOT pre-wrapped in
            the template; this function adds it.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The corrected text as generated (not stripped).
    """
    prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_length = inputs["input_ids"].shape[-1]
    # NOTE(review): GPT-2-style models have a fixed context length
    # (config.n_positions); prompt_length + max_new_tokens may exceed it for
    # long inputs -- confirm for this checkpoint.
    # NOTE(review): top_k only matters when sampling is enabled; whether it is
    # depends on the model's generation config -- confirm.
    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        top_k=50,
    )
    # Decode only the continuation. Splitting the full decoded text on the
    # marker would break if the marker also appeared earlier in the text.
    result = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    print(result)
    return result.split("### Correction ###")[0]
# OCR Correction Class
class OCRCorrector:
    """Correct a text with ocr_correction() and build an HTML diff of the result."""

    def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
        # Default means "The following dialogue is a conversation". It is kept
        # for interface compatibility; correct() does not use it.
        self.system_prompt = system_prompt

    def correct(self, user_message):
        """Return ``(corrected_text, html_diff)`` for *user_message*.

        ocr_correction() applies the "### Text ### / ### Correction ###"
        template itself, so the raw message is passed through unchanged.
        Wrapping it here as well would nest the template twice.
        """
        generated_text = ocr_correction(user_message)
        html_diff = generate_html_diff(user_message, generated_text)
        return generated_text, html_diff
# Combined Processing Class
class TextProcessor:
    """Gradio-facing pipeline: correct the input text and render it as HTML."""

    def __init__(self):
        # The corrector does the model call and the diff rendering.
        self.ocr_corrector = OCRCorrector()

    @spaces.GPU(duration=120)
    def process(self, user_message):
        """Return an HTML fragment showing the corrected version of *user_message*.

        The fragment is the module-level ``css`` block, a centered
        "OCR Correction" heading, and the corrector's word-level diff.
        """
        _corrected_text, html_diff = self.ocr_corrector.correct(user_message)
        heading = '<h2 style="text-align:center">OCR Correction</h2>'
        body = f'<div class="generation">{html_diff}</div>'
        return f"{css}{heading}\n{body}"
# One shared processor; its bound `process` method is the button's click handler.
text_processor = TextProcessor()

# Gradio UI, in layout order: title, input textbox, trigger button, HTML output.
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.HTML("""<h1 style="text-align:center">Vintage OCR corrector</h1>""")
    text_input = gr.Textbox(
        label="Your (bad?) text",
        type="text",
        lines=5,
    )
    process_button = gr.Button("Process Text")
    text_output = gr.HTML(label="Processed text")
    process_button.click(
        fn=text_processor.process,
        inputs=text_input,
        outputs=[text_output],
    )

if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server.
    demo.queue().launch()