|
import gradio as gr |
|
from transformers import LEDForConditionalGeneration, LEDTokenizer |
|
import torch |
|
from datasets import load_dataset |
|
import re |
|
|
|
|
|
# Checkpoint directory of the fine-tuned LED summarization model.
MODEL_DIR = "./summary_generation_Led_4"

# Prefer GPU when one is available; all tensors are moved here before inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load tokenizer and model from the same local checkpoint.
tokenizer = LEDTokenizer.from_pretrained(MODEL_DIR)
model = LEDForConditionalGeneration.from_pretrained(MODEL_DIR).to(device)
|
|
|
|
|
def normalize_text(text):
    """Lowercase *text*, strip punctuation, and collapse whitespace.

    Punctuation is removed *before* whitespace is normalized; doing it the
    other way round leaves double spaces wherever a punctuation mark sat
    between two spaces (e.g. "a , b" -> "a  b").

    Args:
        text: Raw input string.

    Returns:
        Normalized string: lowercase, no punctuation (anything outside
        ``\\w``/``\\s``), single-space separated, no leading/trailing space.
    """
    text = text.lower()
    # Drop everything that is neither a word character nor whitespace.
    text = re.sub(r'[^\w\s]', '', text)
    # Collapse runs of whitespace (including those created above) to one space.
    text = re.sub(r'\s+', ' ', text).strip()
    return text
|
|
|
|
|
def generate_summary(plot_synopsis):
    """Generate an abstractive summary of a plot synopsis with the LED model.

    Args:
        plot_synopsis: Free-form plot text; it is normalized and prefixed with
            the "summarize: " task prompt before tokenization.

    Returns:
        The decoded summary string (special tokens stripped).
    """
    inputs = tokenizer(
        "summarize: " + normalize_text(plot_synopsis),
        max_length=3000,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    ).to(device)

    # LED (Longformer) requires global attention on at least the first token;
    # without it generation quality degrades. 1 = global attention.
    global_attention_mask = torch.zeros_like(inputs["input_ids"])
    global_attention_mask[:, 0] = 1

    # Inference only: no_grad avoids building autograd graphs during beam
    # search, cutting memory use. attention_mask is passed explicitly so the
    # max_length padding tokens are ignored by the encoder.
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            global_attention_mask=global_attention_mask,
            max_length=315,
            min_length=20,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
# Wire the summarizer into a simple Gradio UI: one multiline text input for
# the synopsis, one text output for the generated summary.
synopsis_input = gr.Textbox(
    label="Plot Synopsis",
    lines=10,
    placeholder="Enter the plot synopsis here...",
)
summary_output = gr.Textbox(label="Generated Summary")

interface = gr.Interface(
    fn=generate_summary,
    inputs=synopsis_input,
    outputs=summary_output,
    title="Plot Summary Generator",
    description="This demo generates a plot summary based on the plot synopsis using a fine-tuned LED model.",
)

# Start the local Gradio server (blocks until interrupted).
interface.launch()
|
|