Spaces:
Sleeping
Sleeping
Pclanglais
committed on
Commit
•
cd9ce00
1
Parent(s):
d731e09
Update app.py
Browse files
app.py
CHANGED
@@ -16,16 +16,19 @@ import difflib
|
|
16 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
17 |
|
18 |
# OCR Correction Model
|
19 |
-
ocr_model_name = "PleIAs/OCRonos"
|
20 |
-
ocr_llm = LLM(ocr_model_name, max_model_len=8128)
|
21 |
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
|
|
|
|
27 |
|
28 |
-
|
|
|
|
|
29 |
|
30 |
# CSS for formatting
|
31 |
css = """
|
@@ -163,30 +166,25 @@ def split_text(text, max_tokens=500):
|
|
163 |
|
164 |
return chunks
|
165 |
|
166 |
-
def transform_chunks(marianne_segmentation):
|
167 |
-
marianne_segmentation = pd.DataFrame(marianne_segmentation)
|
168 |
-
marianne_segmentation = marianne_segmentation[marianne_segmentation['entity_group'] != 'separator']
|
169 |
-
marianne_segmentation['word'] = marianne_segmentation['word'].astype(str).str.replace('¶', '\n', regex=False)
|
170 |
-
marianne_segmentation['word'] = marianne_segmentation['word'].astype(str).apply(preprocess_text)
|
171 |
-
marianne_segmentation = marianne_segmentation[marianne_segmentation['word'].notna() & (marianne_segmentation['word'] != '') & (marianne_segmentation['word'] != ' ')]
|
172 |
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
html_output.append(f'<div class="manuscript"><div class="annotation">{result_entity}</div><div class="content paratext-content">{word}</div></div>')
|
185 |
-
else:
|
186 |
-
html_output.append(f'<div class="manuscript"><div class="annotation">{result_entity}</div><div class="content">{word}</div></div>')
|
187 |
|
188 |
-
|
189 |
-
|
|
|
|
|
|
|
|
|
190 |
|
191 |
# OCR Correction Class
|
192 |
class OCRCorrector:
|
@@ -195,53 +193,25 @@ class OCRCorrector:
|
|
195 |
|
196 |
def correct(self, user_message):
|
197 |
sampling_params = SamplingParams(temperature=0.9, top_p=0.95, max_tokens=4000, presence_penalty=0, stop=["#END#"])
|
198 |
-
detailed_prompt = f"###
|
199 |
-
|
200 |
-
outputs = ocr_llm.generate(prompts, sampling_params, use_tqdm=False)
|
201 |
-
generated_text = outputs[0].outputs[0].text
|
202 |
html_diff = generate_html_diff(user_message, generated_text)
|
203 |
return generated_text, html_diff
|
204 |
|
205 |
-
# Editorial Segmentation Class
|
206 |
-
class EditorialSegmenter:
|
207 |
-
def segment(self, text):
|
208 |
-
editorial_text = re.sub("\n", " ¶ ", text)
|
209 |
-
num_tokens = len(tokenizer.tokenize(editorial_text))
|
210 |
-
|
211 |
-
if num_tokens > 500:
|
212 |
-
batch_prompts = split_text(editorial_text, max_tokens=500)
|
213 |
-
else:
|
214 |
-
batch_prompts = [editorial_text]
|
215 |
-
|
216 |
-
out = token_classifier(batch_prompts)
|
217 |
-
classified_list = []
|
218 |
-
for classification in out:
|
219 |
-
df = pd.DataFrame(classification)
|
220 |
-
classified_list.append(df)
|
221 |
-
|
222 |
-
classified_list = pd.concat(classified_list)
|
223 |
-
out = transform_chunks(classified_list)
|
224 |
-
return out
|
225 |
-
|
226 |
# Combined Processing Class
|
227 |
class TextProcessor:
|
228 |
def __init__(self):
|
229 |
self.ocr_corrector = OCRCorrector()
|
230 |
-
self.editorial_segmenter = EditorialSegmenter()
|
231 |
|
232 |
@spaces.GPU(duration=120)
|
233 |
def process(self, user_message):
|
234 |
-
#
|
235 |
corrected_text, html_diff = self.ocr_corrector.correct(user_message)
|
236 |
|
237 |
-
# Step 2: Editorial Segmentation
|
238 |
-
segmented_text = self.editorial_segmenter.segment(corrected_text)
|
239 |
-
|
240 |
# Combine results
|
241 |
ocr_result = f'<h2 style="text-align:center">OCR Correction</h2>\n<div class="generation">{html_diff}</div>'
|
242 |
-
editorial_result = f'<h2 style="text-align:center">Editorial Segmentation</h2>\n<div class="generation">{segmented_text}</div>'
|
243 |
|
244 |
-
final_output = f"{css}{ocr_result}
|
245 |
return final_output
|
246 |
|
247 |
# Create the TextProcessor instance
|
@@ -249,7 +219,7 @@ text_processor = TextProcessor()
|
|
249 |
|
250 |
# Define the Gradio interface
|
251 |
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
|
252 |
-
gr.HTML("""<h1 style="text-align:center">
|
253 |
text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
|
254 |
process_button = gr.Button("Process Text")
|
255 |
text_output = gr.HTML(label="Processed text")
|
|
|
16 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
17 |
|
18 |
# OCR Correction Model
|
19 |
+
ocr_model_name = "PleIAs/OCRonos-Vintage"

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load pre-trained model and tokenizer.
# `model_name` is derived from `ocr_model_name` so the checkpoint id is
# written in exactly one place and the two names cannot drift apart.
model_name = ocr_model_name
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# Set the device to GPU if available, otherwise use CPU.
# This rebinds the earlier string-valued `device` to a `torch.device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
|
32 |
|
33 |
# CSS for formatting
|
34 |
css = """
|
|
|
166 |
|
167 |
return chunks
|
168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
+
# Function to generate text
|
171 |
+
def ocr_correction(prompt, max_new_tokens=600):
    """Return the model's corrected version of the OCR text in ``prompt``.

    The text is wrapped in the ``### Text ### / ### Correction ###``
    template, fed to the module-level ``model`` and only the continuation
    produced by the model is returned.

    Args:
        prompt: Raw (possibly noisy) text to correct.
        max_new_tokens: Upper bound on the number of tokens generated.

    Returns:
        The generated correction as a string, without the prompt.
    """
    marker = "### Correction ###\n"
    prompt = f"""### Text ###\n{prompt}\n\n\n{marker}"""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Generate text.
    # NOTE(review): `top_k` presumably only has an effect when sampling is
    # enabled (`do_sample=True`); confirm the intended decoding strategy.
    output = model.generate(input_ids,
                            max_new_tokens=max_new_tokens,
                            pad_token_id=tokenizer.eos_token_id,
                            top_k=50)

    # The returned sequence starts with the prompt tokens. Slice them off
    # instead of splitting the decoded string on the marker: splitting
    # returns the wrong segment whenever the incoming text itself already
    # contains "### Correction ###".
    generated_ids = output[0][input_ids.shape[-1]:]
    result = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # If the model starts a new correction section on its own, keep only the
    # first one (the original split-based extraction also stopped there).
    return result.split(marker)[0]
|
188 |
|
189 |
# OCR Correction Class
|
190 |
class OCRCorrector:
|
|
|
193 |
|
194 |
def correct(self, user_message):
    """Run OCR correction on ``user_message`` and diff it against the input.

    Args:
        user_message: The raw text submitted by the user.

    Returns:
        A ``(generated_text, html_diff)`` tuple: the corrected text from
        ``ocr_correction`` and the value of
        ``generate_html_diff(user_message, generated_text)``.
    """
    # ocr_correction() applies the "### Text ### / ### Correction ###"
    # template itself, so pass the raw text; templating it here as well
    # would feed the model a doubly wrapped prompt.
    generated_text = ocr_correction(user_message)
    html_diff = generate_html_diff(user_message, generated_text)
    return generated_text, html_diff
|
200 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
# Combined Processing Class
|
202 |
class TextProcessor:
    """Turn user text into the HTML fragment shown in the output panel."""

    def __init__(self):
        self.ocr_corrector = OCRCorrector()

    # NOTE(review): presumably reserves a GPU for up to 120 s per call;
    # confirm against the `spaces` package documentation.
    @spaces.GPU(duration=120)
    def process(self, user_message):
        """Correct ``user_message`` and return the styled HTML diff.

        The returned string is the module-level ``css`` followed by an
        "OCR Correction" heading and the diff wrapped in a
        ``generation`` div.
        """
        # OCR correction; only the HTML diff is displayed.
        _, html_diff = self.ocr_corrector.correct(user_message)

        # Assemble the output fragment.
        heading = '<h2 style="text-align:center">OCR Correction</h2>'
        diff_block = f'<div class="generation">{html_diff}</div>'
        return f"{css}{heading}\n{diff_block}"
|
216 |
|
217 |
# Create the TextProcessor instance
|
|
|
219 |
|
220 |
# Define the Gradio interface
|
221 |
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
|
222 |
+
gr.HTML("""<h1 style="text-align:center">Vintage OCR corrector</h1>""")
|
223 |
text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
|
224 |
process_button = gr.Button("Process Text")
|
225 |
text_output = gr.HTML(label="Processed text")
|