File size: 1,824 Bytes
966ee03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import gradio as gr
from transformers import LEDForConditionalGeneration, LEDTokenizer
import torch
from datasets import load_dataset
import re

# Pick the compute device once: a CUDA GPU when PyTorch reports one, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Directory containing the fine-tuned LED checkpoint; both the tokenizer and
# the model weights are loaded from it.
MODEL_DIR = "./summary_generation_Led_4"

tokenizer = LEDTokenizer.from_pretrained(MODEL_DIR)
model = LEDForConditionalGeneration.from_pretrained(MODEL_DIR).to(device)

# Normalize the input text (plot synopsis)
def normalize_text(text):
    """Lowercase *text*, strip punctuation, and collapse whitespace.

    Punctuation (any character that is neither a word character nor
    whitespace) is removed *before* whitespace is collapsed. That way,
    dropping a free-standing symbol such as the dash in ``"a - b"`` or a
    detached trailing ``" ."`` does not leave a double or trailing space.

    Args:
        text: Raw plot synopsis.

    Returns:
        The normalized string, or ``""`` if no word characters remain.
    """
    text = text.lower()
    # Drop punctuation first so the whitespace pass below cleans up any gaps
    # it leaves behind.
    text = re.sub(r"[^\w\s]", "", text)
    # Collapse runs of spaces/newlines/tabs to a single space and trim ends.
    return re.sub(r"\s+", " ", text).strip()

# Function to preprocess and generate summaries
def generate_summary(plot_synopsis):
    """Generate a summary of a plot synopsis with the fine-tuned LED model.

    Args:
        plot_synopsis: Raw synopsis text coming from the Gradio textbox.

    Returns:
        The decoded summary string, or ``""`` when the synopsis is empty or
        contains nothing left after normalization.
    """
    # Nothing to summarize: skip generation instead of forcing the model to
    # emit min_length tokens for a blank / punctuation-only input.
    if not plot_synopsis:
        return ""
    cleaned = normalize_text(plot_synopsis)
    if not cleaned:
        return ""

    # No fixed-length padding: a single sequence needs none, and LED pads
    # input internally to a multiple of its attention window. Padding every
    # request to 3000 tokens only made short inputs slower.
    inputs = tokenizer(
        "summarize: " + cleaned,
        max_length=3000,
        truncation=True,
        return_tensors="pt",
    ).to(device)

    # LED's documented setup for summarization: global attention on the first
    # token only, windowed local attention on all other tokens.
    global_attention_mask = torch.zeros_like(inputs["input_ids"])
    global_attention_mask[:, 0] = 1

    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        global_attention_mask=global_attention_mask,
        max_length=315,
        min_length=20,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Gradio interface to take plot synopsis and output a generated summary
interface = gr.Interface(
    fn=generate_summary,
    inputs=gr.Textbox(label="Plot Synopsis", lines=10, placeholder="Enter the plot synopsis here..."),
    outputs=gr.Textbox(label="Generated Summary"),
    title="Plot Summary Generator",
    description="This demo generates a plot summary based on the plot synopsis using a fine-tuned LED model.",
)

# Launch the web server only when this file is executed as a script
# (e.g. `python app.py`), so merely importing the module does not block.
if __name__ == "__main__":
    interface.launch()