license: apache-2.0
language:
  - en

Nougat-LaTeX-based

Nougat-LaTeX-based is fine-tuned from facebook/nougat-base on im2latex-100k to improve its ability to generate LaTeX code from images. The original Nougat encoder input size is unsuitable for equation image segments: rescaling them to that size can introduce artifacts that degrade the quality of the generated LaTeX. To address this, Nougat-LaTeX-based sets the input resolution to a height of 224 pixels and a width of 560 pixels. In addition, an adaptive padding approach resizes equation image segments found in the wild so that they closely match the resolution of the training data.
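The adaptive padding step can be pictured as a "fit, then pad" computation. The sketch below is a hypothetical illustration, not the `NougatImageProcessor` implementation: it assumes the image is scaled by a single factor so that it fits inside the 224 × 560 canvas while keeping its aspect ratio, and that the leftover space is split evenly between opposite sides. The helper name `fit_and_pad_geometry` is made up for this example.

```python
# Hypothetical sketch of "resize to fit, then pad" for a 224 x 560 (H x W) input.
# Assumptions (not taken from nougat_latex): the aspect ratio is preserved and
# the padding is split evenly between opposite sides.

TARGET_H, TARGET_W = 224, 560


def fit_and_pad_geometry(width: int, height: int,
                         target_w: int = TARGET_W, target_h: int = TARGET_H):
    """Return (new_w, new_h, pad_left, pad_top, pad_right, pad_bottom)."""
    if width <= 0 or height <= 0:
        raise ValueError("image dimensions must be positive")
    # One scale factor for both axes keeps the equation from being distorted.
    scale = min(target_w / width, target_h / height)
    new_w = min(target_w, max(1, round(width * scale)))
    new_h = min(target_h, max(1, round(height * scale)))
    # Split the remaining space between the two sides of each axis.
    pad_left = (target_w - new_w) // 2
    pad_top = (target_h - new_h) // 2
    pad_right = target_w - new_w - pad_left
    pad_bottom = target_h - new_h - pad_top
    return new_w, new_h, pad_left, pad_top, pad_right, pad_bottom


if __name__ == "__main__":
    # A wide 1120 x 224 crop is halved to 560 x 112 and padded above and below.
    print(fit_and_pad_geometry(1120, 224))
```

The padding values are returned in (left, top, right, bottom) order, which is the order Pillow's `ImageOps.expand` accepts for a four-value border, so the numbers could drive a `resize` followed by `ImageOps.expand` in a quick experiment.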

Evaluation

The model is evaluated on an image-equation pair dataset collected from Wikipedia, arXiv, and im2latex-100k, curated by lukas-blecher.

| model | token_acc ↑ | normed edit distance ↓ |
|---|---|---|
| pix2tex* | 0.60 | 0.10 |
| nougat-latex-based | 0.623850 | 0.06180 |
pix2tex*: as reported by LaTeX-OCR. The nougat-latex-based results were generated with beam search.
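For readers unfamiliar with the two metric columns, the sketch below shows one common way to compute them. It is an assumption, not the evaluation script behind the table: token accuracy is taken as the share of reference positions whose predicted token matches, and the edit distance is the Levenshtein distance divided by the reference length. The tokenization and normalization used for the reported numbers are not specified here, so values from this sketch are not directly comparable to the table.

```python
# Hypothetical metric definitions; the evaluation behind the table may differ.
from typing import Sequence


def levenshtein(a: Sequence, b: Sequence) -> int:
    """Minimum number of insertions, deletions and substitutions turning a into b."""
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        cur = [i]
        for j, y in enumerate(b, start=1):
            cur.append(min(prev[j] + 1,               # delete x
                           cur[j - 1] + 1,            # insert y
                           prev[j - 1] + (x != y)))   # substitute (free if equal)
        prev = cur
    return prev[-1]


def normed_edit_distance(pred: Sequence, ref: Sequence) -> float:
    """Levenshtein distance normalized by the reference length (lower is better)."""
    if len(ref) == 0:
        return 0.0 if len(pred) == 0 else 1.0
    return levenshtein(pred, ref) / len(ref)


def token_accuracy(pred: Sequence, ref: Sequence) -> float:
    """Share of reference positions where the predicted token matches (higher is better)."""
    if len(ref) == 0:
        return 1.0 if len(pred) == 0 else 0.0
    matches = sum(p == r for p, r in zip(pred, ref))
    return matches / len(ref)


if __name__ == "__main__":
    # One substitution out of 11 reference characters.
    print(normed_edit_distance(r"\frac{a}{c}", r"\frac{a}{b}"))
    print(token_accuracy(["\\frac", "{", "a", "}"], ["\\frac", "{", "b", "}"]))
```

Both helpers accept either strings (character level) or lists of tokens, so the same code can score raw LaTeX strings or tokenizer output.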

Requirements

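```bash
pip install "transformers>=4.34.0"
```

The usage example below also imports `torch`, Pillow (`PIL`), and the `nougat_latex` package, which provides `NougatLaTexProcessor` and its image processor. `nougat_latex` is not part of transformers and has to be installed separately.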
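Before running the example, it can help to confirm that the environment meets these requirements. The following is a minimal sketch using only the standard library; the helper names `release_tuple` and `check_requirements` are hypothetical, and the version parsing deliberately ignores pre-release and dev suffixes (so `4.34.0rc1` counts as `4.34.0`). It also checks that `torch`, `PIL`, and `nougat_latex` are importable, because the usage example imports them.

```python
# Hypothetical environment check for the requirements listed above.
import importlib.util
import re
from importlib.metadata import PackageNotFoundError, version

MIN_TRANSFORMERS = (4, 34, 0)


def release_tuple(v: str) -> tuple:
    """Leading numeric release segment, padded to 3 parts: '4.35.0.dev0' -> (4, 35, 0)."""
    parts = []
    for piece in v.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.group() != piece:  # e.g. '0rc1': keep the 0, drop the suffix
            break
    parts += [0] * (3 - len(parts))
    return tuple(parts)


def check_requirements() -> list:
    """Return human-readable problems (an empty list when everything is in place)."""
    problems = []
    try:
        installed = version("transformers")
        if release_tuple(installed) < MIN_TRANSFORMERS:
            problems.append(f"transformers {installed} is older than 4.34.0")
    except PackageNotFoundError:
        problems.append("transformers is not installed")
    # Modules imported by the usage example (Pillow is imported as PIL).
    for module in ("torch", "PIL", "nougat_latex"):
        if importlib.util.find_spec(module) is None:
            problems.append(f"cannot import {module}")
    return problems


if __name__ == "__main__":
    for problem in check_requirements():
        print("missing requirement:", problem)
```

Comparing integer tuples rather than raw strings matters here: as strings, `"4.9"` would sort after `"4.34.0"`.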

Uses

```python
import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel
from transformers.models.nougat import NougatTokenizerFast

# NougatLaTexProcessor and the adapted NougatImageProcessor come from the
# nougat_latex package, not from transformers
from nougat_latex import NougatLaTexProcessor
from nougat_latex.image_processing_nougat import NougatImageProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

# load the fine-tuned encoder-decoder model
model = VisionEncoderDecoderModel.from_pretrained("Norm/nougat-latex-base").to(device)

# load the tokenizer and the image processor (224 x 560 input with adaptive padding)
tokenizer = NougatTokenizerFast.from_pretrained("Norm/nougat-latex-base")
image_processor = NougatImageProcessor.from_pretrained("Norm/nougat-latex-base")
latex_processor = NougatLaTexProcessor(image_processor=image_processor)

# load an equation image; the processor expects RGB input
image = Image.open("path/to/latex/image.png")
if image.mode != "RGB":
    image = image.convert("RGB")

# convert the image into the pixel_values tensor expected by the encoder
pixel_values = latex_processor(image, return_tensors="pt").pixel_values

# start decoding from the BOS token
decoder_input_ids = tokenizer(tokenizer.bos_token, add_special_tokens=False,
                              return_tensors="pt").input_ids

# beam search with 5 beams; bad_words_ids keeps the <unk> token out of the output
with torch.no_grad():
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_length,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        num_beams=5,
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

# decode the best beam and strip the special tokens
sequence = tokenizer.batch_decode(outputs.sequences)[0]
sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "").replace(tokenizer.bos_token, "")
print(sequence)
```