{ "cells": [ { "cell_type": "markdown", "source": [ "
\n", " \n", "

Sharif University of Technology

\n", "

Natural Language Processing

\n", "

Final Project

\n", "

Spoiler classification and summary generation

\n", "

Authors: Parnian Razavipour, Mobina Salimipanah

\n", "

(Equal Contribution)

\n", "
\n", "
\n" ], "metadata": { "id": "3kZY57mYHNea" } }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "VtuWHaKEQdEq", "outputId": "0aaf7811-598b-48ba-80d9-623db23d6f0e" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.44.2)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.16.1)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.24.7)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.26.4)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (24.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.2)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2024.9.11)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.32.3)\n", "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.5)\n", "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.5)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers) (2024.6.1)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers) (4.12.2)\n", "Requirement already satisfied: 
charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.2.3)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.8.30)\n", "Collecting datasets\n", " Downloading datasets-3.0.1-py3-none-any.whl.metadata (20 kB)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.16.1)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.4)\n", "Collecting pyarrow>=15.0.0 (from datasets)\n", " Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)\n", "Collecting dill<0.3.9,>=0.3.0 (from datasets)\n", " Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.1.4)\n", "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.32.3)\n", "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.5)\n", "Collecting xxhash (from datasets)\n", " Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)\n", "Collecting multiprocess (from datasets)\n", " Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)\n", "Requirement already satisfied: fsspec<=2024.6.1,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.6.1,>=2023.1.0->datasets) (2024.6.1)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.10.5)\n", "Requirement 
already satisfied: huggingface-hub>=0.22.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.24.7)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (24.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.2)\n", "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.0)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.11.1)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.22.0->datasets) (4.12.2)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2.2.3)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2024.8.30)\n", "Requirement already 
satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n", "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n", "Downloading datasets-3.0.1-py3-none-any.whl (471 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m471.6/471.6 kB\u001b[0m \u001b[31m15.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m5.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m39.9/39.9 MB\u001b[0m \u001b[31m14.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m4.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m8.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hInstalling collected packages: xxhash, pyarrow, dill, multiprocess, datasets\n", " Attempting uninstall: pyarrow\n", " Found existing installation: pyarrow 14.0.2\n", " Uninstalling pyarrow-14.0.2:\n", " Successfully uninstalled 
pyarrow-14.0.2\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", "cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 17.0.0 which is incompatible.\u001b[0m\u001b[31m\n", "\u001b[0mSuccessfully installed datasets-3.0.1 dill-0.3.8 multiprocess-0.70.16 pyarrow-17.0.0 xxhash-3.5.0\n", "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.4.1+cu121)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.16.1)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.3)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.6.1)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.5)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n" ] } ], "source": [ "!pip install transformers\n", "!pip install datasets\n", "!pip install torch\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "1vXYwybdRDD_", "outputId": "2a471522-0bd6-47a3-d924-4bf30899a6d7" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.44.2)\n", "Requirement already satisfied: datasets in 
/usr/local/lib/python3.10/dist-packages (3.0.1)\n", "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.4.1+cu121)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.16.1)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.24.7)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.26.4)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (24.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.2)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2024.9.11)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.32.3)\n", "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.5)\n", "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.5)\n", "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (17.0.0)\n", "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.1.4)\n", "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.5.0)\n", "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16)\n", "Requirement already satisfied: 
fsspec<=2024.6.1,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.6.1,>=2023.1.0->datasets) (2024.6.1)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.10.5)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.3)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n", "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.0)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.11.1)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from 
requests->transformers) (2.2.3)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.8.30)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.5)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n", "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n" ] } ], "source": [ "!pip install transformers datasets torch\n" ] }, { "cell_type": "markdown", "source": [ "##**Download and Load Dataset**" ], "metadata": { "id": "wJx5vLJFHlxT" } }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "5tiNer2wNmKd", "outputId": "d76c300d-1387-4491-f2e7-426a8a375652" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: kaggle in /usr/local/lib/python3.10/dist-packages (1.6.17)\n", "Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.10/dist-packages (from kaggle) (1.16.0)\n", "Requirement already satisfied: certifi>=2023.7.22 in /usr/local/lib/python3.10/dist-packages (from kaggle) (2024.8.30)\n", "Requirement already satisfied: python-dateutil in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.8.2)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.32.3)\n", "Requirement already satisfied: tqdm in 
/usr/local/lib/python3.10/dist-packages (from kaggle) (4.66.5)\n", "Requirement already satisfied: python-slugify in /usr/local/lib/python3.10/dist-packages (from kaggle) (8.0.4)\n", "Requirement already satisfied: urllib3 in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.2.3)\n", "Requirement already satisfied: bleach in /usr/local/lib/python3.10/dist-packages (from kaggle) (6.1.0)\n", "Requirement already satisfied: webencodings in /usr/local/lib/python3.10/dist-packages (from bleach->kaggle) (0.5.1)\n", "Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.10/dist-packages (from python-slugify->kaggle) (1.3)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (3.10)\n", "cp: cannot stat 'kaggle.json': No such file or directory\n", "chmod: cannot access '/root/.kaggle/kaggle.json': No such file or directory\n", "Dataset URL: https://www.kaggle.com/datasets/rmisra/imdb-spoiler-dataset\n", "License(s): Attribution 4.0 International (CC BY 4.0)\n", "Downloading imdb-spoiler-dataset.zip to /content\n", "100% 330M/331M [00:20<00:00, 20.6MB/s]\n", "100% 331M/331M [00:20<00:00, 16.7MB/s]\n", "Archive: imdb-spoiler-dataset.zip\n", " inflating: IMDB_movie_details.json \n", " inflating: IMDB_reviews.json \n", "IMDB_movie_details.json IMDB_reviews.json imdb-spoiler-dataset.zip sample_data\n" ] } ], "source": [ "!pip install kaggle\n", "\n", "!mkdir -p ~/.kaggle\n", "!cp kaggle.json ~/.kaggle/\n", "\n", "!chmod 600 ~/.kaggle/kaggle.json\n", "\n", "!kaggle datasets download -d rmisra/imdb-spoiler-dataset\n", "\n", "!unzip imdb-spoiler-dataset.zip\n", "\n", "!ls" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "uXLdvqgVOL99", "outputId": 
"658bd35a-2208-4869-b6f5-ea22d6eece6d" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " movie_id plot_summary duration \\\n", "0 tt0105112 Former CIA analyst, Jack Ryan is in England wi... 1h 57min \n", "1 tt1204975 Billy (Michael Douglas), Paddy (Robert De Niro... 1h 45min \n", "2 tt0243655 The setting is Camp Firewood, the year 1981. I... 1h 37min \n", "3 tt0040897 Fred C. Dobbs and Bob Curtin, both down on the... 2h 6min \n", "4 tt0126886 Tracy Flick is running unopposed for this year... 1h 43min \n", "\n", " genre rating release_date \\\n", "0 [Action, Thriller] 6.9 1992-06-05 \n", "1 [Comedy] 6.6 2013-11-01 \n", "2 [Comedy, Romance] 6.7 2002-04-11 \n", "3 [Adventure, Drama, Western] 8.3 1948-01-24 \n", "4 [Comedy, Drama, Romance] 7.3 1999-05-07 \n", "\n", " plot_synopsis \n", "0 Jack Ryan (Ford) is on a \"working vacation\" in... \n", "1 Four boys around the age of 10 are friends in ... \n", "2 \n", "3 Fred Dobbs (Humphrey Bogart) and Bob Curtin (T... \n", "4 Jim McAllister (Matthew Broderick) is a much-a... \n", " review_date movie_id user_id is_spoiler \\\n", "0 10 February 2006 tt0111161 ur1898687 True \n", "1 6 September 2000 tt0111161 ur0842118 True \n", "2 3 August 2001 tt0111161 ur1285640 True \n", "3 1 September 2002 tt0111161 ur1003471 True \n", "4 20 May 2004 tt0111161 ur0226855 True \n", "\n", " review_text rating \\\n", "0 In its Oscar year, Shawshank Redemption (writt... 10 \n", "1 The Shawshank Redemption is without a doubt on... 10 \n", "2 I believe that this film is the best story eve... 8 \n", "3 **Yes, there are SPOILERS here**This film has ... 10 \n", "4 At the heart of this extraordinary movie is a ... 8 \n", "\n", " review_summary \n", "0 A classic piece of unforgettable film-making. \n", "1 Simply amazing. The best film of the 90's. \n", "2 The best story ever told on film \n", "3 Busy dying or busy living? 
# %% Load the raw IMDB spoiler dataset
import pandas as pd
from sklearn.model_selection import train_test_split

# Both files are read as JSON Lines (one record per line), hence lines=True.
movie_details = pd.read_json('IMDB_movie_details.json', lines=True)
reviews = pd.read_json('IMDB_reviews.json', lines=True)
print(movie_details.head())
print(reviews.head())

# %% Drop movies without usable text
# A missing synopsis/summary can be stored as an empty (or whitespace-only)
# string rather than NaN, so dropna() alone leaves such rows in and the model
# would later be trained on empty documents. Treat blank text as missing too.
TEXT_COLUMNS = ['plot_synopsis', 'plot_summary']
has_text = (
    movie_details[TEXT_COLUMNS]
    .fillna('')
    .apply(lambda column: column.astype(str).str.strip().ne(''))
    .all(axis=1)
)
# Rebind instead of mutating in place so re-running this cell is idempotent;
# .copy() lets later cells add columns without a SettingWithCopyWarning.
movie_details = movie_details.loc[has_text].copy()

# %% Split the data into training, validation and test sets
# 70% for training, 30% held out ...
train_data, temp_data = train_test_split(movie_details, test_size=0.3, random_state=42)

# ... and the held-out 30% is split in half: 15% validation, 15% test.
validation_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42)

# Every cleaned row must land in exactly one split.
assert len(train_data) + len(validation_data) + len(test_data) == len(movie_details)

# %%
train_data.shape
# %% Example exploration: word-count statistics of plot_synopsis
# Length is measured in whitespace-separated words, not tokenizer tokens.
synopsis_word_counts = movie_details['plot_synopsis'].apply(
    lambda synopsis: len(synopsis.split())
)
movie_details['synopsis_length'] = synopsis_word_counts
print(movie_details['synopsis_length'].describe())

# %% Persist the three splits as JSON Lines files
for split_frame, split_path in (
    (train_data, '/content/train_data.json'),
    (validation_data, '/content/val_data.json'),
    (test_data, '/content/test_data.json'),
):
    split_frame.to_json(split_path, orient='records', lines=True)

# %% Load the splits back through the `datasets` library
from datasets import load_dataset

data_files = dict(
    train='/content/train_data.json',
    test='/content/test_data.json',
    val='/content/val_data.json',
)

# Each requested split maps onto the data_files entry of the same name.
split_names = {split: split for split in data_files}

# Loading the dataset from the JSON files
dataset = load_dataset('json', data_files=data_files, split=split_names)
"4655a7b4746d4dbb94f11bc4bf649cd2", "824f9d74bd9043d69c5fbd7b8670f12b", "c270c9de8d1e4f1ba69fc2d11f41c5da", "8a2f0bf77c5244db861a91df114fbfe4", "7967f8faea82454782f928f5953171aa", "1e252e01d9d24f928370ec596c1d2dce", "0a2b31a22ea64959b4be31046c927710", "d9cbfca658cd42c6bf4b272ae586212d", "9ef8c17ef4b9440db46a2fa02f7366c5", "c761198e022d4c32a7d1b79dc04967fa", "33bc1b435a9e42b7a8e40538e4f4f670", "ceeaf991d8c44123b758e290bfcdd9d1", "46de1b040239439aaa1819972a5f854a", "f8f00b01f89b46f4841c4b7b72adf7e4", "ab495b67bf124b9da726fc12c68aace9", "d2446351d77f4b27a06de442f5186a5c", "52735ac507c04d2bbaa4c60461cfeef3", "553fe0d427154df38dedf9c93ecfed0e", "6decf9a706a14e7b8df71e6bcab55a06", "1e0cc3e5eab8408aa2506edf11daa581", "9ff6312e60c84aca9e8f3d1478f6b665" ] }, "id": "GOe61CxiYkgv", "outputId": "b5f95681-6609-4620-e64e-7b0309f2f14c" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", "You will be able to reuse this secret in all of your notebooks.\n", "Please note that authentication is recommended but still optional to access public models or datasets.\n", " warnings.warn(\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "tokenizer_config.json: 0%| | 0.00/2.32k [00:00. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. 
# %% Token-length statistics for plot_synopsis
from transformers import T5Tokenizer


import pandas as pd

# Re-reads the raw file, so `movie_details` here is the unfiltered table and
# replaces whatever an earlier cell bound to that name.
movie_details = pd.read_json('/content/IMDB_movie_details.json', lines=True)

# NOTE(review): lengths are measured with the T5 tokenizer, while the
# summarization cell of this notebook loads LEDTokenizer
# ('allenai/led-base-16384'). Counts from the two tokenizers can differ, so
# confirm the max_length choices against the tokenizer actually used.
tokenizer = T5Tokenizer.from_pretrained('t5-small')


def calculate_token_length(text):
    """Return how many tokens ``tokenizer.tokenize`` produces for ``text``.

    The module-level ``tokenizer`` is looked up at call time, so the result
    depends on whichever tokenizer is bound to that name when this runs.
    """
    return len(tokenizer.tokenize(text))


movie_details['token_length'] = movie_details['plot_synopsis'].apply(calculate_token_length)

stats = movie_details['token_length'].describe()
print(stats)

# %% Token-length statistics for plot_summary
# Apply the function to the plot_summary column. This overwrites the
# 'token_length' column that held the synopsis lengths computed above.
movie_details['token_length'] = movie_details['plot_summary'].apply(calculate_token_length)

# Display statistics about the token lengths
stats = movie_details['token_length'].describe()
print(stats)

# %% Select the compute device
import torch

# Fall back to the CPU when no GPU is attached, so that later `.to(device)`
# calls do not fail on a CPU-only runtime.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
"code", "source": [ "import re\n", "import torch\n", "from transformers import LEDForConditionalGeneration, LEDTokenizer\n", "\n", "# Load tokenizer and model\n", "tokenizer = LEDTokenizer.from_pretrained('allenai/led-base-16384')\n", "model = LEDForConditionalGeneration.from_pretrained('allenai/led-base-16384')\n", "\n", "model = model.to(device)\n", "\n", "# Function to normalize text\n", "def normalize_text(text):\n", " text = text.lower() # Lowercase the text\n", " text = re.sub(r'\\s+', ' ', text).strip() # Remove extra spaces and newlines\n", " text = re.sub(r'[^\\w\\s]', '', text) # Remove non-alphanumeric characters\n", " return text\n", "\n", "# Preprocess function with normalization\n", "def preprocess_function(examples):\n", " # Normalize the plot_synopsis and plot_summary\n", " inputs = [\"summarize: \" + normalize_text(doc) for doc in examples[\"plot_synopsis\"]]\n", " model_inputs = tokenizer(inputs, max_length=3000, truncation=True, padding=\"max_length\", return_tensors=\"pt\")\n", "\n", " # Normalize labels (plot_summary)\n", " with tokenizer.as_target_tokenizer():\n", " labels = tokenizer([normalize_text(doc) for doc in examples[\"plot_summary\"]], max_length=1024, truncation=True, padding=\"max_length\", return_tensors=\"pt\")\n", "\n", " # Replace -100 for padding tokens in labels\n", " labels[\"input_ids\"] = [\n", " [(label if label != tokenizer.pad_token_id else -100) for label in lab]\n", " for lab in labels[\"input_ids\"]\n", " ]\n", "\n", " model_inputs[\"labels\"] = labels[\"input_ids\"]\n", " return model_inputs" ], "metadata": { "id": "eBxLkLaUHJLI", "colab": { "base_uri": "https://localhost:8080/", "height": 296, "referenced_widgets": [ "4f57295463b143ec981c5b1d3a8e5a49", "30bb9b323bb34b25b7bd94a57a08f947", "f8f5ba727c104114b1aacbaf18be3b92", "d787d652d4044ccb9529144d49e10c43", "05fb81fc15954e0da2be232105a5fe34", "06c3259f4c334ba19f8f3ec16c94dc71", "ff91264eb09e470dbfab792e011b7c9d", "66b831532b4e407687b49409b3f87767", 
"3ef18fd318d04fb8b4c91b42707b4bd2", "2885b8cfc2b741209c0caa0a8c9fa4c7", "72f23b4514bf42b0a2475b98dfb7d7a7", "daebf1aac7f54e05a22954dfc9ac8475", "3502ffed2d904193a7f5cfbf9e46f941", "052f3cd1f15d43d684b3f4bf1b315c46", "b6000f50c0ae4023960afe1f54463203", "58bfcde727374cb5b27e0eea0eae6579", "cff8944496a9449a9d1e9269cbf6534f", "c259b071ad53489d9dfa69283afc30e4", "6820b5075d70492c8137a998c089ac05", "0e6224d4897345e3be15d07e0a672de1", "501ffe823844428085168dbb2adc9df8", "af5d02cdb35a4061b8b45562c1fdc11e", "63cc6928f2324ced805d8e94a94e16dc", "f9578e0721f347d7aa26fbbb41c9e50a", "08bc210b294f46718f9e2b05470c905d", "93f990215b9645ff93075b963bcb012d", "155b5d91def24487aab1cbd31112965a", "0d57abcd6dcd41aeb92cdd86615f181a", "4e782e5b17eb46d38b15c46de5260999", "a5a7ad1bbad34587b312aa140ca09bb9", "647ccaad654543f59d77d4a3ad789631", "03f48376b86b4c2bbe698dbe6f908ac1", "3e410bc98f834174ae3108a4b6eb6d0d", "fd0f7478cf4946c3bdf5b597dc06b9a6", "3c238acc24aa468fb18cddffe70beabf", "54a06f05f47743d4ba4c030f7971163f", "d18b831bcf614445a4988b98f190642e", "d51f1868f20846cea882f30b799df9cc", "1e8757c3bc244d6990c1e57f235bfa83", "1622081b48d44e639c1e632724cb68e6", "53257d25bb7d4daa92271af9a188206d", "5d33d79ee7c14218a93dbea6a0ace598", "8ab30d4ac61a45ca9731931080f1e5af", "328440cf8dbf4f0396210c35b194d065", "dda889cd884341e99555a04ca260c8d6", "8308cf1b24a34c7ba35f84b1c23e2d9d", "28b07b90d6494cccb9d77a37477e26d3", "9a751a2507df4d2d938ffb7a0324a28f", "23c64079b95245bba6b96d89bc73ec64", "09c54c8238764d2fbbea9df2c60c750c", "f273c230aea34e9eab790ac98dcf56a7", "a86891ddccfa469288f12143c1630305", "a5fea7e1c24448ba8de4db809cb8c535", "fb7b6579ec8d41c4880225d75724e370", "2a63c7005a504f068d143d9f7a80940d", "03e69634d3364440b466a6135df1b6c5", "fbef44cbcaa84d738e261c24d9eefd9d", "1e3c017585d54afe96e0e2d03f7155e0", "da966352362e45e88a18a8b04d4590d7", "591e88da46b540e88b010d558e2e027b", "2ea129aeb4ad4bae96bb490fc217eb2a", "b28bb46030194be58630881f69524b9f", "b7768d74494b47e2acfa6e153544783a", 
"f4d5e8921d244f9dba213914b70b589e", "ce38c13bfcff4214822873a92817004e", "c6eec72b32b8426d959a704038c2312b", "aff4182f85c44cd3828fd2e35321a5f1", "bce088936ddf4e8abfe78e1191e96ff0", "2ebd0b5ab9bd4c5a8c6e642c411f065d", "5b9cd84929ed49b7b304bc5712c28a6f", "072446b10e1045589106f92fbd95eafc", "376ffd83dbbc4b23b71e838fe0713b94", "16f244a4438a439f8df8803fb584cfab", "106b6d087e5d4377a0a16b811f007015", "5bf46c40aea848db8dde59773bdd53fa", "0bd2fd1f6fc84f7e81291f558832706b", "12489aa45cfd4f7ba0c471b8f5a09d58" ] }, "outputId": "8b840812-4822-433b-d36d-b8430fd33e5f" }, "execution_count": null, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "tokenizer_config.json: 0%| | 0.00/27.0 [00:00\n", " \n", " \n", " [2001/3300 47:25 < 30:49, 0.70 it/s, Epoch 1.82/3]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation Loss
12.9253002.873685

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "

\n", " \n", " \n", " [3300/3300 1:23:26, Epoch 3/3]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation Loss
12.9253002.873685
22.4096002.883492
32.1653002.920285

" ] }, "metadata": {} }, { "output_type": "execute_result", "data": { "text/plain": [ "TrainOutput(global_step=3300, training_loss=2.429696747750947, metrics={'train_runtime': 5007.437, 'train_samples_per_second': 0.659, 'train_steps_per_second': 0.659, 'total_flos': 6526373990400000.0, 'train_loss': 2.429696747750947, 'epoch': 3.0})" ] }, "metadata": {}, "execution_count": 21 } ] }, { "cell_type": "markdown", "source": [ "##**Save the Model (3 epochs trained)**" ], "metadata": { "id": "uBp1rPWNI8l9" } }, { "cell_type": "code", "source": [ "model_save_path = '/content/drive/MyDrive/summary_generation_Led_3'\n", "trainer.save_model(model_save_path)" ], "metadata": { "id": "kFuNP98wGaa2" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "from google.colab import drive\n", "drive.mount('/content/drive')\n", "\n", "# Load the trained model from Google Drive\n", "model_save_path = '/content/drive/MyDrive/summary_generation_Led_3'\n", "model = LEDForConditionalGeneration.from_pretrained(model_save_path)\n", "\n", "model = model.to(device)\n" ], "metadata": { "id": "yopwoSEJgNkT" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "##**Continue training for 1 additional epoch**" ], "metadata": { "id": "T8XIu4hFJSQk" } }, { "cell_type": "code", "source": [ "from transformers import Trainer, TrainingArguments\n", "\n", "# Reload the model from the saved checkpoint in Google Drive\n", "model_save_path = '/content/drive/MyDrive/summary_generation' # Your saved path\n", "model = LEDForConditionalGeneration.from_pretrained(model_save_path)\n", "model = model.to(device) # Ensure the model is on the right device (GPU or CPU)\n", "\n", "# Define the training arguments for continuing training\n", "training_args = TrainingArguments(\n", " output_dir=\"./results\",\n", " eval_strategy=\"epoch\",\n", " save_strategy=\"epoch\",\n", " learning_rate=2e-5,\n", " per_device_train_batch_size=1, # LED is memory-intensive\n", " 
per_device_eval_batch_size=1,\n", " weight_decay=0.01,\n", " save_total_limit=3, # Only keep the last 3 models saved\n", " num_train_epochs=1, # Continue training for 1 additional epoch\n", " report_to=\"none\",\n", " logging_dir='./logs', # Directory for storing logs\n", " logging_steps=500,\n", " load_best_model_at_end=True,\n", "\n", "# Initialize the optimizer (AdamW) again for the new training run\n", "optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate)\n", "\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=tokenized_datasets[\"train\"],\n", " eval_dataset=tokenized_datasets[\"val\"],\n", " tokenizer=tokenizer,\n", " optimizers=(optimizer, None)\n", ")\n", "\n", "trainer.train()\n", "\n", "trainer.save_model(\"/content/drive/MyDrive/summary_generation_Led_4\") # Save the new checkpoint\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 144 }, "id": "DnDXbarJm1YG", "outputId": "90c5c5a8-bf06-4631-e868-17c6e7dc2f4a" }, "execution_count": null, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "

\n", " \n", " \n", " [1100/1100 26:32, Epoch 1/1]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation Loss
12.1405002.964694

" ] }, "metadata": {} }, { "output_type": "stream", "name": "stderr", "text": [ "There were missing keys in the checkpoint model loaded: ['led.encoder.embed_tokens.weight', 'led.decoder.embed_tokens.weight', 'lm_head.weight'].\n" ] } ] }, { "cell_type": "markdown", "source": [ "##**Evaluation on test data**" ], "metadata": { "id": "ewfMVjdEJoWF" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7S7rMmsRzUDK" }, "outputs": [], "source": [ "pip install rouge_score" ] }, { "cell_type": "code", "source": [ "pip install nltk\n" ], "metadata": { "id": "OFJmH2KUtdod" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "import torch, gc\n", "\n", "gc.collect()\n", "torch.cuda.empty_cache()\n" ], "metadata": { "id": "SEGn9IQ4to0c" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "import nltk\n", "\n", "nltk.download('punkt')\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "FliiA5j_-bK6", "outputId": "a1784ce0-e243-4f06-b2dc-554300ec8726" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "[nltk_data] Downloading package punkt to /root/nltk_data...\n", "[nltk_data] Unzipping tokenizers/punkt.zip.\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "True" ] }, "metadata": {}, "execution_count": 20 } ] }, { "cell_type": "code", "source": [ "from datasets import load_metric\n", "import nltk\n", "\n", "metric_rouge = load_metric(\"rouge\")\n", "\n", "def generate_summary(batch):\n", " inputs = tokenizer(batch[\"plot_synopsis\"], max_length=3000, truncation=True, padding=\"max_length\", return_tensors=\"pt\")\n", " inputs = inputs.to(device)\n", " outputs = model.generate(inputs[\"input_ids\"], max_length=315, min_length=20, length_penalty=2.0, num_beams=4, early_stopping=True)\n", "\n", " batch[\"pred_summary\"] = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n", " return batch\n", "\n", 
"results = dataset[\"test\"].map(generate_summary, batched=True, batch_size=8)\n", "rouge_score = metric_rouge.compute(predictions=results[\"pred_summary\"], references=results[\"plot_summary\"])\n", "print(\"ROUGE scores:\")\n", "print(rouge_score)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 104, "referenced_widgets": [ "ad570df7408d48d99eb69a50ad8f3201", "265654d5426b4616b530e7fd1c7cec04", "ccb396e3506e487284decd31875fa33d", "2873376fa2264b26ab8dcebe41d46af0", "13b89ba59e80499f9d6926ee97d9f519", "edf3f110599943ab822847b15bcfd0d4", "8d275ca50cce4c48889ddc50769ada21", "d89b8b1072d9408c96d2f7b2e735c785", "0e6ac91d7b284b9b82abdedf36e91924", "e0b2c384ca364c49869ae498683a6ef3", "a84982a5e3ff40a18ab9b416cd85a111" ] }, "id": "I1wd38l3sYAy", "outputId": "4a6e0d20-74b4-4916-edf3-c7b0ef4b5117" }, "execution_count": null, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ad570df7408d48d99eb69a50ad8f3201", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/236 [00:00