Razavipour's picture
Upload 11 files
966ee03 verified
raw
history blame
1.82 kB
import gradio as gr
from transformers import LEDForConditionalGeneration, LEDTokenizer
import torch
from datasets import load_dataset
import re
# Pick the compute device once at import time: CUDA when available, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The fine-tuned LED weights and their tokenizer are read from the same
# local checkpoint directory; keep the path in one place.
_CHECKPOINT_DIR = "./summary_generation_Led_4"
model = LEDForConditionalGeneration.from_pretrained(_CHECKPOINT_DIR).to(device)
tokenizer = LEDTokenizer.from_pretrained(_CHECKPOINT_DIR)
# Normalize the input text (plot synopsis)
def normalize_text(text):
    """Lowercase *text*, strip punctuation, and collapse whitespace.

    Punctuation is removed *before* whitespace is collapsed, so deleting a
    standalone symbol (e.g. the "-" in "a - b", or a leading "!") cannot
    leave a double space or a leading/trailing space behind.

    Args:
        text: Raw plot synopsis. ``None`` or ``""`` is treated as empty.

    Returns:
        Lowercased text containing only word characters (letters, digits,
        underscore) separated by single spaces, with no leading or trailing
        whitespace.
    """
    if not text:
        return ""
    text = text.lower()
    # Drop every character that is neither a word character nor whitespace.
    text = re.sub(r"[^\w\s]", "", text)
    # Collapse runs of whitespace (including newlines) into one space, then trim.
    return re.sub(r"\s+", " ", text).strip()
# Function to preprocess and generate summaries
def generate_summary(plot_synopsis):
    """Summarize a plot synopsis with the fine-tuned LED model.

    The synopsis is normalized with ``normalize_text``, prefixed with
    ``"summarize: "``, truncated to at most 3000 tokens, and decoded with
    4-beam search under ``min_length=20`` / ``max_length=315`` generation
    limits.

    Args:
        plot_synopsis: Raw synopsis text from the UI. ``None``, or text that
            normalizes to nothing, yields an empty summary.

    Returns:
        The decoded summary with special tokens removed, or ``""`` when there
        is nothing to summarize.
    """
    cleaned = normalize_text(plot_synopsis or "")
    if not cleaned.strip():
        # Nothing left to summarize; running the model here would still emit
        # at least min_length tokens that have nothing to do with the input.
        return ""

    # A single sequence needs no padding. Padding to max_length only made the
    # model process up to 3000 pad positions on every request.
    inputs = tokenizer(
        "summarize: " + cleaned,
        max_length=3000,
        truncation=True,
        return_tensors="pt",
    ).to(device)

    # NOTE(review): LED summarization checkpoints are commonly fine-tuned with
    # global attention on the first token (``global_attention_mask``); confirm
    # whether this checkpoint was trained that way before adding it here.
    outputs = model.generate(
        inputs["input_ids"],
        # Pass the mask explicitly rather than letting generate() infer which
        # positions are real tokens.
        attention_mask=inputs["attention_mask"],
        max_length=315,
        min_length=20,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Build the Gradio UI: a multi-line synopsis textbox feeds generate_summary,
# whose return value is shown in a second textbox.
_synopsis_box = gr.Textbox(label="Plot Synopsis", lines=10, placeholder="Enter the plot synopsis here...")
_summary_box = gr.Textbox(label="Generated Summary")

interface = gr.Interface(
    fn=generate_summary,
    inputs=_synopsis_box,
    outputs=_summary_box,
    title="Plot Summary Generator",
    description="This demo generates a plot summary based on the plot synopsis using a fine-tuned LED model.",
)

# Start the Gradio server for the interface defined above.
interface.launch()