Pclanglais committed on
Commit
10a4171
1 Parent(s): d55b86a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -9
app.py CHANGED
@@ -129,19 +129,25 @@ def ocr_correction(prompt, max_new_tokens=600):
129
  splits = split_text(prompt, max_tokens=400)
130
  corrected_splits = []
131
 
 
 
132
  for split in splits:
133
  full_prompt = f"### Text ###\n{split}\n\n\n### Correction ###\n"
 
134
  encoded = tokenizer.encode(full_prompt)
135
  prompt_tokens = tokenizer.convert_ids_to_tokens(encoded)
136
-
137
- result = generator.generate_batch(
138
- [prompt_tokens],
139
- max_length=max_new_tokens,
140
- sampling_temperature=0.7,
141
- sampling_topk=20,
142
- include_prompt_in_result=False
143
- )[0]
144
-
 
 
 
145
  corrected_text = tokenizer.decode(result.sequences_ids[0])
146
  corrected_splits.append(corrected_text)
147
 
 
129
  splits = split_text(prompt, max_tokens=400)
130
  corrected_splits = []
131
 
132
+ list_prompts = []
133
+
134
  for split in splits:
135
  full_prompt = f"### Text ###\n{split}\n\n\n### Correction ###\n"
136
+ print(full_prompt)
137
  encoded = tokenizer.encode(full_prompt)
138
  prompt_tokens = tokenizer.convert_ids_to_tokens(encoded)
139
+ list_prompts.append(prompt_tokens)
140
+
141
+ results = generator.generate_batch(
142
+ list_prompts,
143
+ max_length=max_new_tokens,
144
+ sampling_temperature=0.7,
145
+ sampling_topk=20,
146
+ repetition_penalty=1.1,
147
+ include_prompt_in_result=False
148
+ )
149
+
150
+ for result in results:
151
  corrected_text = tokenizer.decode(result.sequences_ids[0])
152
  corrected_splits.append(corrected_text)
153