Pclanglais committed on
Commit ffbf266
Parent: 2814dfb

Update app.py

Files changed (1)
app.py  +87 -43
app.py CHANGED
@@ -1,24 +1,36 @@
 import spaces
 import transformers
 import re
+from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM, pipeline
+from vllm import LLM, SamplingParams
 import torch
 import gradio as gr
+import json
 import os
-import ctranslate2
-import difflib
 import shutil
 import requests
+import pandas as pd
+import difflib
 from concurrent.futures import ThreadPoolExecutor
 
 # Define the device
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# Load CTranslate2 model and tokenizer
-model_path = "ocronos_ct2"
-generator = ctranslate2.Generator(model_path, device=device)
-tokenizer = transformers.AutoTokenizer.from_pretrained("PleIAs/OCRonos-Vintage")
+# OCR Correction Model
+ocr_model_name = "PleIAs/OCRonos-Vintage"
+
+import torch
+from transformers import GPT2LMHeadModel, GPT2Tokenizer
+
+# Load pre-trained model and tokenizer
+model_name = "PleIAs/OCRonos-Vintage"
+model = GPT2LMHeadModel.from_pretrained(model_name)
+tokenizer = GPT2Tokenizer.from_pretrained(model_name)
+
+# Set the device to GPU if available, otherwise use CPU
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model.to(device)
 
-# CSS for formatting (unchanged)
 # CSS for formatting
 css = """
 <style>
@@ -117,41 +129,73 @@ def preprocess_text(text):
     return text.strip()
 
 def split_text(text, max_tokens=500):
-    encoded = tokenizer.encode(text)
-    splits = []
-    for i in range(0, len(encoded), max_tokens):
-        split = encoded[i:i+max_tokens]
-        splits.append(tokenizer.decode(split))
-    return splits
-
-# Function to generate text using CTranslate2
-def ocr_correction(prompt, max_new_tokens=500):
-    splits = split_text(prompt, max_tokens=500)
-    corrected_splits = []
-
-    list_prompts = []
-
-    for split in splits:
-        full_prompt = f"### Text ###\n{split}\n\n\n### Correction ###\n"
-        print(full_prompt)
-        encoded = tokenizer.encode(full_prompt)
-        prompt_tokens = tokenizer.convert_ids_to_tokens(encoded)
-        list_prompts.append(prompt_tokens)
-
-    results = generator.generate_batch(
-        list_prompts,
-        max_length=max_new_tokens,
-        sampling_temperature=0,
-        sampling_topk=20,
-        repetition_penalty=1.1,
-        include_prompt_in_result=False
-    )
-
-    for result in results:
-        corrected_text = tokenizer.decode(result.sequences_ids[0])
-        corrected_splits.append(corrected_text)
-
-    return " ".join(corrected_splits)
+    parts = text.split("\n")
+    chunks = []
+    current_chunk = ""
+
+    for part in parts:
+        if current_chunk:
+            temp_chunk = current_chunk + "\n" + part
+        else:
+            temp_chunk = part
+
+        num_tokens = len(tokenizer.tokenize(temp_chunk))
+
+        if num_tokens <= max_tokens:
+            current_chunk = temp_chunk
+        else:
+            if current_chunk:
+                chunks.append(current_chunk)
+            current_chunk = part
+
+    if current_chunk:
+        chunks.append(current_chunk)
+
+    if len(chunks) == 1 and len(tokenizer.tokenize(chunks[0])) > max_tokens:
+        long_text = chunks[0]
+        chunks = []
+        while len(tokenizer.tokenize(long_text)) > max_tokens:
+            split_point = len(long_text) // 2
+            while split_point < len(long_text) and not re.match(r'\s', long_text[split_point]):
+                split_point += 1
+            if split_point >= len(long_text):
+                split_point = len(long_text) - 1
+            chunks.append(long_text[:split_point].strip())
+            long_text = long_text[split_point:].strip()
+        if long_text:
+            chunks.append(long_text)
+
+    return chunks
+
+
+# Function to generate text
+def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
+    prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
+    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
+
+    # Set the number of threads for PyTorch
+    torch.set_num_threads(num_threads)
+
+    # Generate text
+    with ThreadPoolExecutor(max_workers=num_threads) as executor:
+        future = executor.submit(
+            model.generate,
+            input_ids,
+            max_new_tokens=max_new_tokens,
+            pad_token_id=tokenizer.eos_token_id,
+            top_k=50,
+            num_return_sequences=1,
+            do_sample=True,
+            temperature=0.7
+        )
+        output = future.result()
+
+    # Decode and return the generated text
+    result = tokenizer.decode(output[0], skip_special_tokens=True)
+    print(result)
+
+    result = result.split("### Correction ###")[1]
+    return result
 
 # OCR Correction Class
 class OCRCorrector:
@@ -170,7 +214,7 @@ class TextProcessor:
 
     @spaces.GPU(duration=120)
     def process(self, user_message):
-        # OCR Correction
+        #OCR Correction
         corrected_text, html_diff = self.ocr_corrector.correct(user_message)
 
         # Combine results
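
A quick way to sanity-check the newly loaded model is a single short generation with the prompt template this commit uses (### Text ### followed by ### Correction ###). A minimal sketch, assuming a made-up OCR sample; none of this is code from the commit:

# Illustrative sanity check: one greedy correction pass with the loaded model.
# The sample string is invented OCR noise, not data from the repository.
sample = "### Text ###\nTaritfs are a form of trade regulacion.\n\n\n### Correction ###\n"
ids = tokenizer.encode(sample, return_tensors="pt").to(device)
out = model.generate(ids, max_new_tokens=40, do_sample=False, pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(out[0], skip_special_tokens=True))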
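Note that the rewritten ocr_correction no longer chunks its input (the removed CTranslate2 version called split_text internally and batched the chunks), so callers must now chunk long documents themselves. A minimal caller-side sketch, assuming the corrected chunks are simply rejoined with newlines; the helper name correct_document is hypothetical:

# Hypothetical caller-side pipeline: chunk, correct each piece, reassemble.
def correct_document(raw_ocr_text):
    chunks = split_text(raw_ocr_text, max_tokens=500)
    return "\n".join(ocr_correction(chunk) for chunk in chunks)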
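One behavioral difference worth keeping in mind when reading the diff: the CTranslate2 path excluded the prompt from results (include_prompt_in_result=False), whereas model.generate in transformers returns the prompt tokens followed by the continuation, which is why the new code splits the decoded string on "### Correction ###". A sketch of an alternative that slices by token position instead of by delimiter (an assumption, not what the commit does):

# Hypothetical variant: strip the prompt by token count so the result does
# not depend on the delimiter string appearing exactly once in the output.
def strip_prompt(output_ids, prompt_ids):
    generated = output_ids[0][prompt_ids.shape[1]:]
    return tokenizer.decode(generated, skip_special_tokens=True)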