import functools
import re
import time
from urllib.parse import parse_qs, urlparse

import torch
from datasets import Dataset
from deepmultilingualpunctuation import PunctuationModel
from googletrans import Translator
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    TrainingArguments,
)
from youtube_transcript_api import YouTubeTranscriptApi

# Sentence boundary: whitespace that follows '.', '!' or '?'. Requiring the
# whitespace (\s+ rather than \s*) keeps decimals ("3.5") and ellipses ("...")
# intact instead of cutting them into separate "sentences".
_SENTENCE_BOUNDARY = re.compile(r'(?<=[.!?])\s+')


def _split_sentences(text):
    """Return the non-empty sentences of `text`, split after sentence-final punctuation."""
    return [sentence for sentence in _SENTENCE_BOUNDARY.split(text.strip()) if sentence]


def load_model(cp, tokenizer_name="VietAI/vit5-base"):
    """Load a tokenizer and a seq2seq model.

    Args:
        cp: Model name or checkpoint path for ``AutoModelForSeq2SeqLM.from_pretrained``.
        tokenizer_name: Tokenizer to load (defaults to the ViT5 base tokenizer).

    Returns:
        ``(tokenizer, model)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(cp)
    return tokenizer, model


def summarize(text, model, tokenizer, num_beams=4, device='cpu',
              max_input_length=1024, max_summary_length=256):
    """Generate a summary of `text` with beam search.

    Args:
        text: Input text; tokens beyond `max_input_length` are truncated.
        model: Seq2seq model exposing ``generate``; moved to `device`.
        tokenizer: Tokenizer matching `model`.
        num_beams: Beam width for generation.
        device: Torch device string for the model and the inputs.
        max_input_length: Maximum number of input tokens kept.
        max_summary_length: ``max_length`` passed to ``generate``.

    Returns:
        The decoded summary without special tokens.
    """
    model.to(device)
    inputs = tokenizer.encode(
        text,
        return_tensors="pt",
        max_length=max_input_length,
        truncation=True,
        padding=True,
    ).to(device)
    with torch.no_grad():  # inference only; no gradient bookkeeping
        summary_ids = model.generate(inputs, max_length=max_summary_length, num_beams=num_beams)
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


def processed(text):
    """Replace newlines with spaces and lower-case `text`."""
    return text.replace('\n', ' ').lower()


def _extract_video_id(video_url):
    """Return the YouTube video ID from a watch URL or a youtu.be short link.

    Raises:
        IndexError: if no ID can be found (same failure mode as plain ``split("v=")[1]``).
    """
    parsed = urlparse(video_url)
    # parse_qs stops the ID at '&', so "...watch?v=ID&t=42s" yields "ID", not "ID&t=42s".
    ids = parse_qs(parsed.query).get('v')
    if ids:
        return ids[0]
    if parsed.netloc in ('youtu.be', 'www.youtu.be') and parsed.path.strip('/'):
        return parsed.path.strip('/')
    return video_url.split("v=")[1]


def get_subtitles(video_url):
    """Fetch the English transcript of a YouTube video.

    Returns:
        ``(transcript, subs)``: the entries returned by
        ``YouTubeTranscriptApi.get_transcript`` and their ``'text'`` fields
        joined with spaces. On any failure returns
        ``([], "An error occurred: <error>")`` instead of raising.
    """
    try:
        video_id = _extract_video_id(video_url)
        transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en'])
        subs = " ".join(entry['text'] for entry in transcript)
        return transcript, subs
    except Exception as e:  # boundary: report the error as text to the caller
        return [], f"An error occurred: {e}"


@functools.lru_cache(maxsize=1)
def _punctuation_model():
    """Build the PunctuationModel once and reuse it (loading it is expensive)."""
    return PunctuationModel()


def restore_punctuation(text):
    """Return `text` with punctuation restored by ``PunctuationModel``."""
    return _punctuation_model().restore_punctuation(text)


def _split_oversized(sentence, limit):
    """Break a sentence longer than `limit` characters at word boundaries."""
    pieces = []
    current = ''
    for word in sentence.split():
        if current and len(current) + 1 + len(word) > limit:
            pieces.append(current)
            current = word
        else:
            current = f'{current} {word}' if current else word
    if current:
        pieces.append(current)
    return pieces


def translate_long(text, language='vi'):
    """Translate arbitrarily long `text` by sending it in sentence-aligned chunks.

    NOTE(review): the body of this function was lost when the file was
    flattened; only its header and the first few statements survived. The
    packing/translation logic below is a reconstruction — confirm against the
    original intent.

    Args:
        text: Source text.
        language: Destination language code passed to ``Translator.translate``.

    Returns:
        The translated chunks joined with single spaces ('' for empty input).
    """
    translator = Translator()
    limit = 4700  # characters per request; presumably kept below the service's per-request cap — confirm
    chunks = []
    current_chunk = ''
    for sentence in _split_sentences(text):
        pieces = _split_oversized(sentence, limit) if len(sentence) > limit else [sentence]
        for piece in pieces:
            if current_chunk and len(current_chunk) + 1 + len(piece) > limit:
                chunks.append(current_chunk)
                current_chunk = piece
            else:
                current_chunk = f'{current_chunk} {piece}' if current_chunk else piece
    if current_chunk:
        chunks.append(current_chunk)

    translated = []
    for index, chunk in enumerate(chunks):
        if index:
            # NOTE(review): assumed pause between requests (`time` is imported for it) — confirm.
            time.sleep(1)
        translated.append(translator.translate(chunk, dest=language).text)
    return ' '.join(translated)


def split_into_chunks(text, max_words=700, overlap_sentences=2):
    """Group sentences of `text` into chunks of roughly at most `max_words` words.

    Each new chunk starts with the last `overlap_sentences` sentences of the
    previous chunk, to keep context across chunk boundaries.

    NOTE(review): the head of this function was lost when the file was
    flattened; the parameter name `max_words` is assumed (the call site passes
    it positionally).

    Returns:
        A list of chunk strings; sentences inside a chunk are joined by spaces.
    """
    chunks = []
    current_chunk = []
    current_word_count = 0
    for sentence in _split_sentences(text):
        sentence_word_count = len(sentence.split())
        if current_chunk and current_word_count + sentence_word_count > max_words:
            chunks.append(' '.join(current_chunk))
            # Guard 0 explicitly: lst[-0:] is the WHOLE list, not an empty one.
            carried = current_chunk[-overlap_sentences:] if overlap_sentences > 0 else []
            current_chunk = carried + [sentence]
            current_word_count = sum(len(sent.split()) for sent in current_chunk)
        else:
            current_chunk.append(sentence)
            current_word_count += sentence_word_count
    if current_chunk:
        chunks.append(' '.join(current_chunk))
    return chunks


def post_processing(text):
    """Capitalize the first character of every sentence and rejoin them with single spaces."""
    return " ".join(sentence[0].upper() + sentence[1:] for sentence in _split_sentences(text))


def display(text):
    """Return the distinct sentences of `text`, in first-seen order, as "• " bullets."""
    unique_sentences = dict.fromkeys(_split_sentences(text))
    return [f"• {sentence}" for sentence in unique_sentences]


def pipeline(url, model, tokenizer):
    """Summarize a YouTube video's English subtitles as translated bullet points.

    Steps: fetch subtitles, restore punctuation, translate (``translate_long``
    default: Vietnamese), normalize, chunk, summarize each chunk, then format.

    Returns:
        A list of bullet strings. If subtitles could not be fetched, a
        single-item list with the error message (or [] if there was no text).
    """
    transcript, sub = get_subtitles(url)
    if not transcript:
        # get_subtitles signals failure with an empty transcript and a message in `sub`;
        # don't punctuate/translate/summarize the error message itself.
        return [sub] if sub else []
    sub = restore_punctuation(sub)
    vie_sub = processed(translate_long(sub))
    chunks = split_into_chunks(vie_sub, 700, 2)
    summaries = [summarize(chunk, model, tokenizer, num_beams=3) for chunk in chunks]
    # Space-join so a summary that lacks final punctuation isn't glued onto the next one.
    return display(post_processing(' '.join(summaries)))


def update(name):
    """Return a greeting for `name`."""
    return f"Welcome to Gradio, {name}!"