Update README.md
Browse files
README.md
CHANGED
@@ -1,3 +1,82 @@
|
|
1 |
---
|
2 |
license: apache-2.0
|
|
|
|
|
3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
license: apache-2.0
|
3 |
+
language:
|
4 |
+
- en
|
5 |
---
|
6 |
+
|
7 |
+
# Nougat-LaTeX-based
|
8 |
+
|
9 |
+
- **Model type:** [Donut](https://huggingface.co/docs/transformers/model_doc/donut)
|
10 |
+
- **Finetuned from:** [facebook/nougat-base](https://huggingface.co/facebook/nougat-base)
|
11 |
+
- **Repository:** [source code](https://github.com/NormXU/nougat-latext-ocr)
|
12 |
+
|
13 |
+
Nougat-LaTeX-based is fine-tuned from [facebook/nougat-base](https://huggingface.co/facebook/nougat-base) with [im2latex-100k](https://zenodo.org/record/56198#.V2px0jXT6eA) to boost its proficiency in generating LaTeX code from images.
|
14 |
+
Since the initial encoder input image size of nougat was unsuitable for equation image segments, leading to potential rescaling artifacts that degrades the generation quality of LaTeX code. To address this, Nougat-LaTeX-based adjusts the input resolution to a height of 224 and a width of 560.
|
15 |
+
Additionally, an adaptive padding approach is used to ensure that equation image segments in the wild are resized to closely match the resolution of the training data.
|
16 |
+
|
17 |
+
|
18 |
+
### Evaluation
|
19 |
+
Evaluated on an image-equation pair dataset collected from Wikipedia, arXiv, and im2latex-100k, curated by [lukas-blecher](https://github.com/lukas-blecher/LaTeX-OCR#data)
|
20 |
+
|
21 |
+
|model| token_acc ↑ | normed edit distance ↓ |
|
22 |
+
| --- | --- | --- |
|
23 |
+
|pix2tex*|0.60|0.10|
|
24 |
+
|nougat-latex-based| **0.623850** | **0.06180** |
|
25 |
+
pix2tex*: reported from [LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR); nougat-latex-based is evaluated on results generated with beam-search strategy.
|
26 |
+
|
27 |
+
## Requirements
|
28 |
+
```text
|
29 |
+
pip install transformers >= 4.34.0
|
30 |
+
```
|
31 |
+
|
32 |
+
## Uses
|
33 |
+
```python
|
34 |
+
import torch
|
35 |
+
from PIL import Image
|
36 |
+
from transformers import VisionEncoderDecoderModel
|
37 |
+
from transformers.models.nougat import NougatTokenizerFast
|
38 |
+
|
39 |
+
from nougat_latex import NougatLaTexProcessor
|
40 |
+
from nougat_latex.image_processing_nougat import NougatImageProcessor
|
41 |
+
|
42 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
43 |
+
# init model
|
44 |
+
model = VisionEncoderDecoderModel.from_pretrained("Norm/nougat-latex-base").to(device)
|
45 |
+
|
46 |
+
# init processor
|
47 |
+
tokenizer = NougatTokenizerFast.from_pretrained("Norm/nougat-latex-base")
|
48 |
+
|
49 |
+
image_processor = NougatImageProcessor.from_pretrained("Norm/nougat-latex-base")
|
50 |
+
latex_processor = NougatLaTexProcessor(image_processor=image_processor)
|
51 |
+
|
52 |
+
# run test
|
53 |
+
image = Image.open("path/to/latex/image.png")
|
54 |
+
if not image.mode == "RGB":
|
55 |
+
image = image.convert('RGB')
|
56 |
+
|
57 |
+
pixel_values = latex_processor(image)
|
58 |
+
|
59 |
+
decoder_input_ids = tokenizer(tokenizer.bos_token, add_special_tokens=False,
|
60 |
+
return_tensors="pt").input_ids
|
61 |
+
with torch.no_grad():
|
62 |
+
outputs = model.generate(
|
63 |
+
pixel_values.to(device),
|
64 |
+
decoder_input_ids=decoder_input_ids.to(device),
|
65 |
+
max_length=model.decoder.config.max_length,
|
66 |
+
early_stopping=True,
|
67 |
+
pad_token_id=tokenizer.pad_token_id,
|
68 |
+
eos_token_id=tokenizer.eos_token_id,
|
69 |
+
use_cache=True,
|
70 |
+
num_beams=5,
|
71 |
+
bad_words_ids=[[tokenizer.unk_token_id]],
|
72 |
+
return_dict_in_generate=True,
|
73 |
+
)
|
74 |
+
sequence = tokenizer.batch_decode(outputs.sequences)[0]
|
75 |
+
sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "").replace(tokenizer.bos_token, "")
|
76 |
+
print(sequence)
|
77 |
+
|
78 |
+
```
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
|