import spaces
import gradio as gr
from huggingface_hub import list_models
from typing import List
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import json
import re
import logging
from datasets import load_dataset
import os
import numpy as np
from datetime import datetime

# Local helper utilities for rendering saliency overlays
import utils

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Paths to the static image and GIF
README_IMAGE_PATH = os.path.join("figs", "saliencies-merit-dataset.png")
GIF_PATH = os.path.join("figs", "demo-samples.gif")

# Global dataset handle, loaded lazily on first use
dataset = None


def load_merit_dataset():
    global dataset
    if dataset is None:
        dataset = load_dataset(
            "de-Rodrigo/merit", name="en-digital-seq", split="test", num_proc=8
        )
    return dataset


def get_image_from_dataset(index):
    global dataset
    if dataset is None:
        dataset = load_merit_dataset()
    image_data = dataset[int(index)]["image"]
    return image_data


def get_collection_models(tag: str) -> List[str]:
    """Get the models from the de-Rodrigo namespace that carry the given tag."""
    models = list_models(author="de-Rodrigo")
    return [model.modelId for model in models if tag in model.tags]


def initialize_donut():
    try:
        donut_model = VisionEncoderDecoderModel.from_pretrained(
            "de-Rodrigo/donut-merit"
        )
        donut_processor = DonutProcessor.from_pretrained("de-Rodrigo/donut-merit")
        donut_model = donut_model.to("cuda")
        logger.info("Donut model loaded successfully on GPU")
        return donut_model, donut_processor
    except Exception as e:
        logger.error(f"Error loading Donut model: {str(e)}")
        raise


def compute_saliency(outputs, pixels, donut_p, image):
    token_logits = torch.stack(outputs.scores, dim=1)
    token_probs = torch.softmax(token_logits, dim=-1)

    token_texts = []
    saliency_images = []

    for token_index in range(len(token_probs[0])):
        # NOTE: this assumes scores[i] aligns with sequences[0, i]; if generate
        # prepends a decoder start token, an offset would be needed here.
        target_token_prob = token_probs[
            0, token_index, outputs.sequences[0, token_index]
        ]

        # Clear gradients from the previous token before backpropagating again
        if pixels.grad is not None:
            pixels.grad.zero_()
        target_token_prob.backward(retain_graph=True)

        # Saliency: mean absolute input gradient across channels
        saliency = pixels.grad.data.abs().squeeze().mean(dim=0)

        token_id = outputs.sequences[0][token_index].item()
        token_text = donut_p.tokenizer.decode([token_id])
        logger.info(f"Considered sequence token: {token_text}")

        # Filesystem-safe, timestamped name (kept for optionally saving
        # frames to disk; currently unused)
        safe_token_text = re.sub(r'[<>:"/\\|?*]', "_", token_text)
        current_datetime = datetime.now().strftime("%Y%m%d%H%M%S")
        unique_safe_token_text = f"{safe_token_text}_{current_datetime}"
        file_name = f"saliency_{unique_safe_token_text}.png"

        saliency = utils.convert_tensor_to_rgba_image(saliency)

        # Overlay the saliency map on the document twice for stronger contrast
        saliency = utils.add_transparent_image(np.array(image), saliency)
        saliency = utils.convert_rgb_to_rgba_image(saliency)
        saliency = utils.add_transparent_image(np.array(image), saliency, 0.7)

        saliency = utils.label_frame(saliency, token_text)

        saliency_images.append(saliency)
        token_texts.append(token_text)

    return saliency_images, token_texts


@spaces.GPU(duration=300)
def process_image_donut(image):
    try:
        model, processor = initialize_donut()

        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        pixel_values = processor(image, return_tensors="pt").pixel_values.to("cuda")
        # Track gradients w.r.t. the input pixels for the saliency maps
        pixel_values.requires_grad = True

        # NOTE: an empty prompt yields empty decoder_input_ids; Donut models
        # usually expect their task start token here.
        task_prompt = ""
        decoder_input_ids = processor.tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        )["input_ids"].to("cuda")

        # transformers wraps generate in torch.no_grad, so call the undecorated
        # function to let gradients flow back to pixel_values
        outputs = model.generate.__wrapped__(
            model,
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
            output_scores=True,
        )

        saliency_images, token_texts = compute_saliency(
            outputs, pixel_values, processor, image
        )

        # Strip special tokens and the first task tag, then parse to JSON
        sequence = processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
            processor.tokenizer.pad_token, ""
        )
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()

        result = processor.token2json(sequence)
        return saliency_images, json.dumps(result, indent=2)

    except Exception as e:
        logger.error(f"Error processing image with Donut: {str(e)}")
        return None, f"Error: {str(e)}"


@spaces.GPU(duration=300)
def process_image(model_name, image=None, dataset_image_index=None):
    # Prefer an uploaded image; otherwise fall back to the selected dataset
    # sample (the slider always supplies an index, so it must not override
    # an explicit upload)
    if image is None and dataset_image_index is not None:
        image = get_image_from_dataset(dataset_image_index)

    if model_name == "de-Rodrigo/donut-merit":
        saliency_images, result = process_image_donut(image)
    else:
        # Processing for other models should be implemented here
        saliency_images, result = None, f"Processing for model {model_name} not implemented"
    return saliency_images, result


def update_image(dataset_image_index):
    return get_image_from_dataset(dataset_image_index)


if __name__ == "__main__":
    # Load the dataset
    load_merit_dataset()

    models = get_collection_models("saliency")
    models.append("de-Rodrigo/donut-merit")

    with gr.Blocks() as demo:
        gr.Markdown("# Saliency Maps with the MERIT Dataset 🎒📃🏆")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Image(value=README_IMAGE_PATH, height=400)
            with gr.Column(scale=1):
                gr.Image(
                    value=GIF_PATH, label="Dataset samples you can process", height=400
                )

        with gr.Tab("Introduction"):
            gr.Markdown(
                """
                ## Welcome to Saliency Maps with the [MERIT Dataset](https://huggingface.co/datasets/de-Rodrigo/merit) 🎒📃🏆

                This space demonstrates the capabilities of different Vision-Language models for document understanding tasks.

                ### Key Features:
                - Process images from the [MERIT Dataset](https://huggingface.co/datasets/de-Rodrigo/merit) or upload your own image.
                - Use fine-tuned versions of the available models to extract grades from documents.
                - Visualize saliency maps to understand where the model is looking (WIP 🛠️).
                """
            )

        with gr.Tab("Try It Yourself"):
            gr.Markdown(
                "Select a model and an image from the dataset, or upload your own image."
            )

            with gr.Row():
                with gr.Column():
                    model_dropdown = gr.Dropdown(choices=models, label="Select Model")
                    dataset_slider = gr.Slider(
                        minimum=0,
                        maximum=len(dataset) - 1,
                        step=1,
                        label="Dataset Image Index",
                    )
                    upload_image = gr.Image(
                        type="pil", label="Or Upload Your Own Image"
                    )

                preview_image = gr.Image(label="Selected/Uploaded Image")

            process_button = gr.Button("Process Image")

            with gr.Row():
                output_image = gr.Gallery(label="Processed Saliency Images")
                output_text = gr.Textbox(label="Result")

            # Update the preview when the slider changes
            dataset_slider.change(
                fn=update_image, inputs=[dataset_slider], outputs=[preview_image]
            )

            # Update the preview when an image is uploaded
            upload_image.change(
                fn=lambda x: x, inputs=[upload_image], outputs=[preview_image]
            )

            # Process the image when the button is clicked
            process_button.click(
                fn=process_image,
                inputs=[model_dropdown, upload_image, dataset_slider],
                outputs=[output_image, output_text],
            )

    demo.launch()