diff --git "a/final_LED_model.ipynb" "b/final_LED_model.ipynb" new file mode 100644--- /dev/null +++ "b/final_LED_model.ipynb" @@ -0,0 +1,7628 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "
\n", + " \n", + "

Sharif University of Technology

\n", + "

Natural Language Processing

\n", + "

Final Project

\n", + "

Spoiler classification and summary generation

\n", + "

Authors: Parnian Razavipour, Mobina Salimipanah

\n", + "

(Equal Contribution)

\n", + "
\n", + "
\n" + ], + "metadata": { + "id": "3kZY57mYHNea" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "VtuWHaKEQdEq", + "outputId": "0aaf7811-598b-48ba-80d9-623db23d6f0e" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.44.2)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.16.1)\n", + "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.24.7)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.26.4)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (24.1)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.2)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2024.9.11)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.32.3)\n", + "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.5)\n", + "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1)\n", + "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.5)\n", + "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers) (2024.6.1)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from 
huggingface-hub<1.0,>=0.23.2->transformers) (4.12.2)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.10)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.2.3)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.8.30)\n", + "Collecting datasets\n", + " Downloading datasets-3.0.1-py3-none-any.whl.metadata (20 kB)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.16.1)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.4)\n", + "Collecting pyarrow>=15.0.0 (from datasets)\n", + " Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)\n", + "Collecting dill<0.3.9,>=0.3.0 (from datasets)\n", + " Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.1.4)\n", + "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.32.3)\n", + "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.5)\n", + "Collecting xxhash (from datasets)\n", + " Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)\n", + "Collecting multiprocess (from datasets)\n", + " Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)\n", + "Requirement already satisfied: fsspec<=2024.6.1,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.6.1,>=2023.1.0->datasets) (2024.6.1)\n", 
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.10.5)\n", + "Requirement already satisfied: huggingface-hub>=0.22.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.24.7)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (24.1)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.2)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.0)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.11.1)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.22.0->datasets) (4.12.2)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.10)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2.2.3)\n", + 
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2024.8.30)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n", + "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n", + "Downloading datasets-3.0.1-py3-none-any.whl (471 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m471.6/471.6 kB\u001b[0m \u001b[31m15.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m5.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m39.9/39.9 MB\u001b[0m \u001b[31m14.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m4.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m8.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: xxhash, 
pyarrow, dill, multiprocess, datasets\n", + " Attempting uninstall: pyarrow\n", + " Found existing installation: pyarrow 14.0.2\n", + " Uninstalling pyarrow-14.0.2:\n", + " Successfully uninstalled pyarrow-14.0.2\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 17.0.0 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed datasets-3.0.1 dill-0.3.8 multiprocess-0.70.16 pyarrow-17.0.0 xxhash-3.5.0\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.4.1+cu121)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.16.1)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.3)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.6.1)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.5)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n" + ] + } + ], + "source": [ + "!pip install transformers\n", + "!pip install datasets\n", + "!pip install torch\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "1vXYwybdRDD_", + "outputId": 
"2a471522-0bd6-47a3-d924-4bf30899a6d7" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.44.2)\n", + "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (3.0.1)\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.4.1+cu121)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.16.1)\n", + "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.24.7)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.26.4)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (24.1)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.2)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2024.9.11)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.32.3)\n", + "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.5)\n", + "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1)\n", + "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.5)\n", + "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (17.0.0)\n", + "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8)\n", + "Requirement already satisfied: pandas in 
/usr/local/lib/python3.10/dist-packages (from datasets) (2.1.4)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.5.0)\n", + "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16)\n", + "Requirement already satisfied: fsspec<=2024.6.1,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.6.1,>=2023.1.0->datasets) (2024.6.1)\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.10.5)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.3)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.0)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.11.1)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", + "Requirement already 
satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.10)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.2.3)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.8.30)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.5)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n", + "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n" + ] + } + ], + "source": [ + "!pip install transformers datasets torch\n" + ] + }, + { + "cell_type": "markdown", + "source": [ + "##**Download and Load Dataset**" + ], + "metadata": { + "id": "wJx5vLJFHlxT" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5tiNer2wNmKd", + "outputId": "d76c300d-1387-4491-f2e7-426a8a375652" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: kaggle in /usr/local/lib/python3.10/dist-packages (1.6.17)\n", + "Requirement already satisfied: six>=1.10 in 
/usr/local/lib/python3.10/dist-packages (from kaggle) (1.16.0)\n", + "Requirement already satisfied: certifi>=2023.7.22 in /usr/local/lib/python3.10/dist-packages (from kaggle) (2024.8.30)\n", + "Requirement already satisfied: python-dateutil in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.8.2)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.32.3)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from kaggle) (4.66.5)\n", + "Requirement already satisfied: python-slugify in /usr/local/lib/python3.10/dist-packages (from kaggle) (8.0.4)\n", + "Requirement already satisfied: urllib3 in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.2.3)\n", + "Requirement already satisfied: bleach in /usr/local/lib/python3.10/dist-packages (from kaggle) (6.1.0)\n", + "Requirement already satisfied: webencodings in /usr/local/lib/python3.10/dist-packages (from bleach->kaggle) (0.5.1)\n", + "Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.10/dist-packages (from python-slugify->kaggle) (1.3)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (3.10)\n", + "cp: cannot stat 'kaggle.json': No such file or directory\n", + "chmod: cannot access '/root/.kaggle/kaggle.json': No such file or directory\n", + "Dataset URL: https://www.kaggle.com/datasets/rmisra/imdb-spoiler-dataset\n", + "License(s): Attribution 4.0 International (CC BY 4.0)\n", + "Downloading imdb-spoiler-dataset.zip to /content\n", + "100% 330M/331M [00:20<00:00, 20.6MB/s]\n", + "100% 331M/331M [00:20<00:00, 16.7MB/s]\n", + "Archive: imdb-spoiler-dataset.zip\n", + " inflating: IMDB_movie_details.json \n", + " inflating: IMDB_reviews.json \n", + "IMDB_movie_details.json 
IMDB_reviews.json imdb-spoiler-dataset.zip sample_data\n" + ] + } + ], + "source": [ + "!pip install kaggle\n", + "\n", + "!mkdir -p ~/.kaggle\n", + "!cp kaggle.json ~/.kaggle/\n", + "\n", + "!chmod 600 ~/.kaggle/kaggle.json\n", + "\n", + "!kaggle datasets download -d rmisra/imdb-spoiler-dataset\n", + "\n", + "!unzip imdb-spoiler-dataset.zip\n", + "\n", + "!ls" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "uXLdvqgVOL99", + "outputId": "658bd35a-2208-4869-b6f5-ea22d6eece6d" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " movie_id plot_summary duration \\\n", + "0 tt0105112 Former CIA analyst, Jack Ryan is in England wi... 1h 57min \n", + "1 tt1204975 Billy (Michael Douglas), Paddy (Robert De Niro... 1h 45min \n", + "2 tt0243655 The setting is Camp Firewood, the year 1981. I... 1h 37min \n", + "3 tt0040897 Fred C. Dobbs and Bob Curtin, both down on the... 2h 6min \n", + "4 tt0126886 Tracy Flick is running unopposed for this year... 1h 43min \n", + "\n", + " genre rating release_date \\\n", + "0 [Action, Thriller] 6.9 1992-06-05 \n", + "1 [Comedy] 6.6 2013-11-01 \n", + "2 [Comedy, Romance] 6.7 2002-04-11 \n", + "3 [Adventure, Drama, Western] 8.3 1948-01-24 \n", + "4 [Comedy, Drama, Romance] 7.3 1999-05-07 \n", + "\n", + " plot_synopsis \n", + "0 Jack Ryan (Ford) is on a \"working vacation\" in... \n", + "1 Four boys around the age of 10 are friends in ... \n", + "2 \n", + "3 Fred Dobbs (Humphrey Bogart) and Bob Curtin (T... \n", + "4 Jim McAllister (Matthew Broderick) is a much-a... 
\n", + " review_date movie_id user_id is_spoiler \\\n", + "0 10 February 2006 tt0111161 ur1898687 True \n", + "1 6 September 2000 tt0111161 ur0842118 True \n", + "2 3 August 2001 tt0111161 ur1285640 True \n", + "3 1 September 2002 tt0111161 ur1003471 True \n", + "4 20 May 2004 tt0111161 ur0226855 True \n", + "\n", + " review_text rating \\\n", + "0 In its Oscar year, Shawshank Redemption (writt... 10 \n", + "1 The Shawshank Redemption is without a doubt on... 10 \n", + "2 I believe that this film is the best story eve... 8 \n", + "3 **Yes, there are SPOILERS here**This film has ... 10 \n", + "4 At the heart of this extraordinary movie is a ... 8 \n", + "\n", + " review_summary \n", + "0 A classic piece of unforgettable film-making. \n", + "1 Simply amazing. The best film of the 90's. \n", + "2 The best story ever told on film \n", + "3 Busy dying or busy living? \n", + "4 Great story, wondrously told and acted \n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "movie_details = pd.read_json('IMDB_movie_details.json', lines=True)\n", + "reviews = pd.read_json('IMDB_reviews.json', lines=True)\n", + "print(movie_details.head())\n", + "print(reviews.head())\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "44kpxJ0iOjLz" + }, + "outputs": [], + "source": [ + "# Drop rows where 'plot_synopsis' or 'plot_summary' is missing\n", + "movie_details.dropna(subset=['plot_synopsis', 'plot_summary'], inplace=True)\n" + ] + }, + { + "cell_type": "markdown", + "source": [ + "##**training and test sets**" + ], + "metadata": { + "id": "2OUyhuoiH_QV" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qaAPgonAPRGx" + }, + "outputs": [], + "source": [ + "\n", + "# Split the data into training and test sets\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "train_data, temp_data = train_test_split(movie_details, test_size=0.3, random_state=42) # 70% for training, 30% for 
temp\n", + "\n", + "validation_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42) # Split the 30% into 15% validation and 15% test\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "gewp_maVVU6g", + "outputId": "884c7f25-053a-4106-feba-960386884107" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(1100, 7)" + ] + }, + "metadata": {}, + "execution_count": 7 + } + ], + "source": [ + "train_data.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "6Y-syQ3ZPVDR", + "outputId": "85ed66a3-8bb4-4d33-b394-8c293b55ffdd" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "count 1572.000000\n", + "mean 1439.085242\n", + "std 1496.392929\n", + "min 0.000000\n", + "25% 493.750000\n", + "50% 1073.500000\n", + "75% 1920.000000\n", + "max 11396.000000\n", + "Name: synopsis_length, dtype: float64\n" + ] + } + ], + "source": [ + "# Example exploration: Average length of plot_synopsis\n", + "movie_details['synopsis_length'] = movie_details['plot_synopsis'].apply(lambda x: len(x.split()))\n", + "print(movie_details['synopsis_length'].describe())\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6lRJv7A7PX8x" + }, + "outputs": [], + "source": [ + "train_data.to_json('/content/train_data.json', orient='records', lines=True)\n", + "validation_data.to_json('/content/val_data.json', orient='records', lines=True)\n", + "test_data.to_json('/content/test_data.json', orient='records', lines=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "referenced_widgets": [ + "f98bee5f48254d9eabd10431adf5aa9b", + "6da7932ca9d749ce860e69c2a38c2541", + 
"abf3f9738aa944f1b1cd615b0e3f4a25", + "dfbc0cc3e4624d3d9cb25f063548d43a", + "4eccac116b3c4175a328b28ed4d63e13", + "8c6f5315e19c42369cec6cab6af10f92", + "8e681408f4574d7c96f6403475c82a6f", + "0ed0e435a6494dc896534036f84513bd", + "bcb2b97307814249b005f05b74edcee9", + "76e9bb11a1e64b02a191b6c415a90c41", + "fafce82ad5504cf5ab6f49180ee12522", + "2e568c6632ee43819924eeeb7c8c5739", + "acd1d2b5fa534f72968f02110462b118", + "0ece36fed4174eb198b1b9882c67fd58", + "fe8439b6f5bc4b228d2a6a446614a0cb", + "83628ad7a12242778b8bc106fad8c40c", + "320973fd444141d7b78fcef4c03a174e", + "828a7ff0740a482fb0034fa8d92fee81", + "fbe47b01fead44c099e12920dc8efe78", + "9085fbc279254b35a1e3b0e95bfb9ac9", + "b59a14bf014b4365a9dbd16da1b0ca3e", + "0411f712b615428fa6377aa47eeb8f12", + "5870843299214a8988fd4caaaf61ab28", + "7e1bb2578c2e4aa7a0c6b5fcdb997e6e", + "03ded09b8b5a4ebb869aeb65c87aa10c", + "130bc93a5870428ea6ee440cb6e21830", + "fca1ce2fe9a6497d903866a566c85a9c", + "537e0076644145b889196d291bde4374", + "5e80aefc2bb9415cb9c25c5e5219e38d", + "3334c22eeadd426085be9ae64b34aa46", + "b5574fe8a4cb47fca826b5dae823eb11", + "c8409f30e0bb42f892ec697783caaeeb", + "ca114842ea594e8291bd0a58e4652b0c" + ] + }, + "id": "yj5INXEXPhsr", + "outputId": "97f90268-e839-4106-9306-a221c4ca1b22" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Generating train split: 0 examples [00:00, ? examples/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "f98bee5f48254d9eabd10431adf5aa9b" + } + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Generating test split: 0 examples [00:00, ? 
from datasets import load_dataset

# Map each split name to the JSON-lines file written in the previous cell.
data_files = {
    'train': '/content/train_data.json',
    'test': '/content/test_data.json',
    'val': '/content/val_data.json',
}

# Load all three splits in one call; passing split= as a dict keeps the
# split names train / test / val on the resulting DatasetDict.
dataset = load_dataset(
    'json',
    data_files=data_files,
    split={'train': 'train', 'test': 'test', 'val': 'val'},
)
"9ef8c17ef4b9440db46a2fa02f7366c5", + "c761198e022d4c32a7d1b79dc04967fa", + "33bc1b435a9e42b7a8e40538e4f4f670", + "ceeaf991d8c44123b758e290bfcdd9d1", + "46de1b040239439aaa1819972a5f854a", + "f8f00b01f89b46f4841c4b7b72adf7e4", + "ab495b67bf124b9da726fc12c68aace9", + "d2446351d77f4b27a06de442f5186a5c", + "52735ac507c04d2bbaa4c60461cfeef3", + "553fe0d427154df38dedf9c93ecfed0e", + "6decf9a706a14e7b8df71e6bcab55a06", + "1e0cc3e5eab8408aa2506edf11daa581", + "9ff6312e60c84aca9e8f3d1478f6b665" + ] + }, + "id": "GOe61CxiYkgv", + "outputId": "b5f95681-6609-4620-e64e-7b0309f2f14c" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "tokenizer_config.json: 0%| | 0.00/2.32k [00:00. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. 
from transformers import T5Tokenizer

import pandas as pd

# NOTE(review): this reloads the RAW file and rebinds movie_details,
# silently discarding the cleaned frame produced by the earlier dropna
# cell — confirm that is intended (hidden-state hazard on re-run order).
movie_details = pd.read_json('/content/IMDB_movie_details.json', lines=True)

tokenizer = T5Tokenizer.from_pretrained('t5-small')

def calculate_token_length(text):
    """Return the number of T5 subword tokens in `text`."""
    return len(tokenizer.tokenize(text))

# Token-length distribution of the long inputs (plot_synopsis) — used to
# justify the long-input LED model and the max_length choices later.
movie_details['token_length'] = movie_details['plot_synopsis'].apply(calculate_token_length)
print(movie_details['token_length'].describe())

# ...and of the targets (plot_summary). The original comment here said
# "plot_synopsis", but the column actually measured is plot_summary.
# Note this overwrites the token_length column computed just above.
movie_details['token_length'] = movie_details['plot_summary'].apply(calculate_token_length)

stats = movie_details['token_length'].describe()
print(stats)
import re
import torch
from transformers import LEDForConditionalGeneration, LEDTokenizer

# Pick the GPU only when one exists. The original hard-coded
# device = 'cuda', which crashes on a CPU-only runtime.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# LED (Longformer Encoder-Decoder) supports inputs up to 16384 tokens,
# which covers the longest synopses seen in the EDA (~18k T5 tokens is
# the extreme; most are well under 3000).
tokenizer = LEDTokenizer.from_pretrained('allenai/led-base-16384')
model = LEDForConditionalGeneration.from_pretrained('allenai/led-base-16384')

model = model.to(device)


def normalize_text(text):
    """Lowercase, collapse whitespace, and strip non-alphanumeric chars.

    NOTE(review): this also strips punctuation from the TARGET summaries,
    so the model is trained to emit unpunctuated text — confirm intended.
    """
    text = text.lower()                      # lowercase the text
    text = re.sub(r'\s+', ' ', text).strip()  # collapse spaces/newlines
    text = re.sub(r'[^\w\s]', '', text)       # drop non-alphanumerics
    return text


def preprocess_function(examples):
    """Tokenize normalized synopses (inputs) and summaries (labels).

    Inputs are padded/truncated to 3000 tokens, labels to 1024; label
    positions holding the pad token are replaced with -100 so the loss
    ignores them.
    """
    inputs = ["summarize: " + normalize_text(doc) for doc in examples["plot_synopsis"]]
    model_inputs = tokenizer(inputs, max_length=3000, truncation=True,
                             padding="max_length", return_tensors="pt")

    # text_target= replaces the deprecated as_target_tokenizer() context
    # manager for encoding labels.
    labels = tokenizer(
        text_target=[normalize_text(doc) for doc in examples["plot_summary"]],
        max_length=1024, truncation=True, padding="max_length",
        return_tensors="pt",
    )

    # Mask padding in the labels so it does not contribute to the loss.
    labels["input_ids"] = [
        [(tok if tok != tokenizer.pad_token_id else -100) for tok in seq]
        for seq in labels["input_ids"]
    ]

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
"4f57295463b143ec981c5b1d3a8e5a49", + "30bb9b323bb34b25b7bd94a57a08f947", + "f8f5ba727c104114b1aacbaf18be3b92", + "d787d652d4044ccb9529144d49e10c43", + "05fb81fc15954e0da2be232105a5fe34", + "06c3259f4c334ba19f8f3ec16c94dc71", + "ff91264eb09e470dbfab792e011b7c9d", + "66b831532b4e407687b49409b3f87767", + "3ef18fd318d04fb8b4c91b42707b4bd2", + "2885b8cfc2b741209c0caa0a8c9fa4c7", + "72f23b4514bf42b0a2475b98dfb7d7a7", + "daebf1aac7f54e05a22954dfc9ac8475", + "3502ffed2d904193a7f5cfbf9e46f941", + "052f3cd1f15d43d684b3f4bf1b315c46", + "b6000f50c0ae4023960afe1f54463203", + "58bfcde727374cb5b27e0eea0eae6579", + "cff8944496a9449a9d1e9269cbf6534f", + "c259b071ad53489d9dfa69283afc30e4", + "6820b5075d70492c8137a998c089ac05", + "0e6224d4897345e3be15d07e0a672de1", + "501ffe823844428085168dbb2adc9df8", + "af5d02cdb35a4061b8b45562c1fdc11e", + "63cc6928f2324ced805d8e94a94e16dc", + "f9578e0721f347d7aa26fbbb41c9e50a", + "08bc210b294f46718f9e2b05470c905d", + "93f990215b9645ff93075b963bcb012d", + "155b5d91def24487aab1cbd31112965a", + "0d57abcd6dcd41aeb92cdd86615f181a", + "4e782e5b17eb46d38b15c46de5260999", + "a5a7ad1bbad34587b312aa140ca09bb9", + "647ccaad654543f59d77d4a3ad789631", + "03f48376b86b4c2bbe698dbe6f908ac1", + "3e410bc98f834174ae3108a4b6eb6d0d", + "fd0f7478cf4946c3bdf5b597dc06b9a6", + "3c238acc24aa468fb18cddffe70beabf", + "54a06f05f47743d4ba4c030f7971163f", + "d18b831bcf614445a4988b98f190642e", + "d51f1868f20846cea882f30b799df9cc", + "1e8757c3bc244d6990c1e57f235bfa83", + "1622081b48d44e639c1e632724cb68e6", + "53257d25bb7d4daa92271af9a188206d", + "5d33d79ee7c14218a93dbea6a0ace598", + "8ab30d4ac61a45ca9731931080f1e5af", + "328440cf8dbf4f0396210c35b194d065", + "dda889cd884341e99555a04ca260c8d6", + "8308cf1b24a34c7ba35f84b1c23e2d9d", + "28b07b90d6494cccb9d77a37477e26d3", + "9a751a2507df4d2d938ffb7a0324a28f", + "23c64079b95245bba6b96d89bc73ec64", + "09c54c8238764d2fbbea9df2c60c750c", + "f273c230aea34e9eab790ac98dcf56a7", + "a86891ddccfa469288f12143c1630305", + 
"a5fea7e1c24448ba8de4db809cb8c535", + "fb7b6579ec8d41c4880225d75724e370", + "2a63c7005a504f068d143d9f7a80940d", + "03e69634d3364440b466a6135df1b6c5", + "fbef44cbcaa84d738e261c24d9eefd9d", + "1e3c017585d54afe96e0e2d03f7155e0", + "da966352362e45e88a18a8b04d4590d7", + "591e88da46b540e88b010d558e2e027b", + "2ea129aeb4ad4bae96bb490fc217eb2a", + "b28bb46030194be58630881f69524b9f", + "b7768d74494b47e2acfa6e153544783a", + "f4d5e8921d244f9dba213914b70b589e", + "ce38c13bfcff4214822873a92817004e", + "c6eec72b32b8426d959a704038c2312b", + "aff4182f85c44cd3828fd2e35321a5f1", + "bce088936ddf4e8abfe78e1191e96ff0", + "2ebd0b5ab9bd4c5a8c6e642c411f065d", + "5b9cd84929ed49b7b304bc5712c28a6f", + "072446b10e1045589106f92fbd95eafc", + "376ffd83dbbc4b23b71e838fe0713b94", + "16f244a4438a439f8df8803fb584cfab", + "106b6d087e5d4377a0a16b811f007015", + "5bf46c40aea848db8dde59773bdd53fa", + "0bd2fd1f6fc84f7e81291f558832706b", + "12489aa45cfd4f7ba0c471b8f5a09d58" + ] + }, + "outputId": "8b840812-4822-433b-d36d-b8430fd33e5f" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "tokenizer_config.json: 0%| | 0.00/27.0 [00:00\n", + " \n", + " \n", + " [2001/3300 47:25 < 30:49, 0.70 it/s, Epoch 1.82/3]\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
EpochTraining LossValidation Loss
12.9253002.873685

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "

\n", + " \n", + " \n", + " [3300/3300 1:23:26, Epoch 3/3]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
EpochTraining LossValidation Loss
12.9253002.873685
22.4096002.883492
32.1653002.920285

" + ] + }, + "metadata": {} + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "TrainOutput(global_step=3300, training_loss=2.429696747750947, metrics={'train_runtime': 5007.437, 'train_samples_per_second': 0.659, 'train_steps_per_second': 0.659, 'total_flos': 6526373990400000.0, 'train_loss': 2.429696747750947, 'epoch': 3.0})" + ] + }, + "metadata": {}, + "execution_count": 21 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "##**Save the Model (3 epochs trained)**" + ], + "metadata": { + "id": "uBp1rPWNI8l9" + } + }, + { + "cell_type": "code", + "source": [ + "model_save_path = '/content/drive/MyDrive/summary_generation_Led_3'\n", + "trainer.save_model(model_save_path)" + ], + "metadata": { + "id": "kFuNP98wGaa2" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from google.colab import drive\n", + "drive.mount('/content/drive')\n", + "\n", + "# Load the trained model from Google Drive\n", + "model_save_path = '/content/drive/MyDrive/summary_generation_Led_3'\n", + "model = LEDForConditionalGeneration.from_pretrained(model_save_path)\n", + "\n", + "model = model.to(device)\n" + ], + "metadata": { + "id": "yopwoSEJgNkT" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "##**Continue training for 1 additional epoch**" + ], + "metadata": { + "id": "T8XIu4hFJSQk" + } + }, + { + "cell_type": "code", + "source": [ + "from transformers import Trainer, TrainingArguments\n", + "\n", + "# Reload the model from the saved checkpoint in Google Drive\n", + "model_save_path = '/content/drive/MyDrive/summary_generation' # Your saved path\n", + "model = LEDForConditionalGeneration.from_pretrained(model_save_path)\n", + "model = model.to(device) # Ensure the model is on the right device (GPU or CPU)\n", + "\n", + "# Define the training arguments for continuing training\n", + "training_args = TrainingArguments(\n", + " output_dir=\"./results\",\n", + 
" eval_strategy=\"epoch\",\n", + " save_strategy=\"epoch\",\n", + " learning_rate=2e-5,\n", + " per_device_train_batch_size=1, # LED is memory-intensive\n", + " per_device_eval_batch_size=1,\n", + " weight_decay=0.01,\n", + " save_total_limit=3, # Only keep the last 3 models saved\n", + " num_train_epochs=1, # Continue training for 1 additional epoch\n", + " report_to=\"none\",\n", + " logging_dir='./logs', # Directory for storing logs\n", + " logging_steps=500,\n", + " load_best_model_at_end=True,\n", + "\n", + "# Initialize the optimizer (AdamW) again for the new training run\n", + "optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate)\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=tokenized_datasets[\"train\"],\n", + " eval_dataset=tokenized_datasets[\"val\"],\n", + " tokenizer=tokenizer,\n", + " optimizers=(optimizer, None)\n", + ")\n", + "\n", + "trainer.train()\n", + "\n", + "trainer.save_model(\"/content/drive/MyDrive/summary_generation_Led_4\") # Save the new checkpoint\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 144 + }, + "id": "DnDXbarJm1YG", + "outputId": "90c5c5a8-bf06-4631-e868-17c6e7dc2f4a" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "

\n", + " \n", + " \n", + " [1100/1100 26:32, Epoch 1/1]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
EpochTraining LossValidation Loss
12.1405002.964694

" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "There were missing keys in the checkpoint model loaded: ['led.encoder.embed_tokens.weight', 'led.decoder.embed_tokens.weight', 'lm_head.weight'].\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "##**Evaluation on test data**" + ], + "metadata": { + "id": "ewfMVjdEJoWF" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7S7rMmsRzUDK" + }, + "outputs": [], + "source": [ + "pip install rouge_score" + ] + }, + { + "cell_type": "code", + "source": [ + "pip install nltk\n" + ], + "metadata": { + "id": "OFJmH2KUtdod" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import torch, gc\n", + "\n", + "gc.collect()\n", + "torch.cuda.empty_cache()\n" + ], + "metadata": { + "id": "SEGn9IQ4to0c" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import nltk\n", + "\n", + "nltk.download('punkt')\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FliiA5j_-bK6", + "outputId": "a1784ce0-e243-4f06-b2dc-554300ec8726" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "[nltk_data] Downloading package punkt to /root/nltk_data...\n", + "[nltk_data] Unzipping tokenizers/punkt.zip.\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "True" + ] + }, + "metadata": {}, + "execution_count": 20 + } + ] + }, + { + "cell_type": "code", + "source": [ + "from datasets import load_metric\n", + "import nltk\n", + "\n", + "metric_rouge = load_metric(\"rouge\")\n", + "\n", + "def generate_summary(batch):\n", + " inputs = tokenizer(batch[\"plot_synopsis\"], max_length=3000, truncation=True, padding=\"max_length\", return_tensors=\"pt\")\n", + " inputs = inputs.to(device)\n", + " outputs = model.generate(inputs[\"input_ids\"], 
max_length=315, min_length=20, length_penalty=2.0, num_beams=4, early_stopping=True)\n", + "\n", + " batch[\"pred_summary\"] = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n", + " return batch\n", + "\n", + "results = dataset[\"test\"].map(generate_summary, batched=True, batch_size=8)\n", + "rouge_score = metric_rouge.compute(predictions=results[\"pred_summary\"], references=results[\"plot_summary\"])\n", + "print(\"ROUGE scores:\")\n", + "print(rouge_score)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 104, + "referenced_widgets": [ + "ad570df7408d48d99eb69a50ad8f3201", + "265654d5426b4616b530e7fd1c7cec04", + "ccb396e3506e487284decd31875fa33d", + "2873376fa2264b26ab8dcebe41d46af0", + "13b89ba59e80499f9d6926ee97d9f519", + "edf3f110599943ab822847b15bcfd0d4", + "8d275ca50cce4c48889ddc50769ada21", + "d89b8b1072d9408c96d2f7b2e735c785", + "0e6ac91d7b284b9b82abdedf36e91924", + "e0b2c384ca364c49869ae498683a6ef3", + "a84982a5e3ff40a18ab9b416cd85a111" + ] + }, + "id": "I1wd38l3sYAy", + "outputId": "4a6e0d20-74b4-4916-edf3-c7b0ef4b5117" + }, + "execution_count": null, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ad570df7408d48d99eb69a50ad8f3201", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Map: 0%| | 0/236 [00:00