Pclanglais committed on
Commit
61dc098
1 Parent(s): dfbcb2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -7
app.py CHANGED
@@ -10,6 +10,7 @@ import shutil
10
  import requests
11
  import pandas as pd
12
  import difflib
 
13
 
14
  # OCR Correction Model
15
  ocr_model_name = "PleIAs/OCRonos-Vintage"
@@ -162,22 +163,26 @@ def split_text(text, max_tokens=500):
162
 
163
 
164
  # Function to generate text
165
- @spaces.GPU
166
- def ocr_correction(prompt, max_new_tokens=500):
167
- model.to(device)
168
-
169
  prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
170
  input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
171
 
 
 
 
172
  # Generate text
173
- output = model.generate(input_ids,
 
 
 
174
  max_new_tokens=max_new_tokens,
175
  pad_token_id=tokenizer.eos_token_id,
176
  top_k=50,
177
  num_return_sequences=1,
178
- do_sample=True,
179
- temperature=0.7
180
  )
 
 
181
  # Decode and return the generated text
182
  result = tokenizer.decode(output[0], skip_special_tokens=True)
183
  print(result)
 
10
  import requests
11
  import pandas as pd
12
  import difflib
13
+ from concurrent.futures import ThreadPoolExecutor
14
 
15
  # OCR Correction Model
16
  ocr_model_name = "PleIAs/OCRonos-Vintage"
 
163
 
164
 
165
  # Function to generate text
166
+ def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
 
 
 
167
  prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
168
  input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
169
 
170
+ # Set the number of threads for PyTorch
171
+ torch.set_num_threads(num_threads)
172
+
173
  # Generate text
174
+ with ThreadPoolExecutor(max_workers=num_threads) as executor:
175
+ future = executor.submit(
176
+ model.generate,
177
+ input_ids,
178
  max_new_tokens=max_new_tokens,
179
  pad_token_id=tokenizer.eos_token_id,
180
  top_k=50,
181
  num_return_sequences=1,
182
+ do_sample=False
 
183
  )
184
+ output = future.result()
185
+
186
  # Decode and return the generated text
187
  result = tokenizer.decode(output[0], skip_special_tokens=True)
188
  print(result)