import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

# Load the fine-tuned T5 OCR-correction model and its tokenizer,
# placing the model on the GPU when one is available.
model_name = "yashvoladoddi37/movie-title-OCR-corrector-t5"
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)


def correct_text(input_text):
    # Tokenize the raw OCR text and move the tensors to the model's device.
    inputs = tokenizer(input_text, return_tensors="pt", padding=True).to(device)

    # Generate the corrected sequence without tracking gradients.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=512
        )

    # Decode the generated token IDs back into plain text.
    corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return corrected_text


# Simple Gradio UI: one text box in, the corrected text out.
iface = gr.Interface(
    fn=correct_text,
    inputs=gr.Textbox(lines=2, placeholder="Enter text to correct"),
    outputs="text",
    title="OCR Correction Demo",
    description="Enter text with OCR errors, and the model will attempt to correct them."
)

iface.launch()