diff --git "a/final_summary_generation_LED (1).ipynb" "b/final_summary_generation_LED (1).ipynb" new file mode 100644--- /dev/null +++ "b/final_summary_generation_LED (1).ipynb" @@ -0,0 +1,5954 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "VtuWHaKEQdEq", + "outputId": "9f28174a-c296-4af7-a700-e143970403e1" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.42.4)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.15.4)\n", + "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.23.5)\n", + "Requirement already satisfied: numpy<2.0,>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.26.4)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (24.1)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.2)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2024.5.15)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.32.3)\n", + "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.4)\n", + "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1)\n", + "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.5)\n", + "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from 
huggingface-hub<1.0,>=0.23.2->transformers) (2024.6.1)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers) (4.12.2)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.7)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.7.4)\n", + "Collecting datasets\n", + " Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.15.4)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.4)\n", + "Collecting pyarrow>=15.0.0 (from datasets)\n", + " Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)\n", + "Collecting dill<0.3.9,>=0.3.0 (from datasets)\n", + " Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.1.4)\n", + "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.32.3)\n", + "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.5)\n", + "Collecting xxhash (from datasets)\n", + " Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)\n", + "Collecting multiprocess (from datasets)\n", + " Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)\n", + 
"Requirement already satisfied: fsspec<=2024.6.1,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.6.1,>=2023.1.0->datasets) (2024.6.1)\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.10.2)\n", + "Requirement already satisfied: huggingface-hub>=0.21.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.23.5)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (24.1)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.2)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.3.5)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.5)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.4)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.2->datasets) (4.12.2)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from 
requests>=2.32.2->datasets) (3.7)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2024.7.4)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n", + "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n", + "Downloading datasets-2.21.0-py3-none-any.whl (527 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m527.3/527.3 kB\u001b[0m \u001b[31m13.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m9.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m39.9/39.9 MB\u001b[0m \u001b[31m19.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n", + "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m15.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: xxhash, pyarrow, dill, multiprocess, datasets\n", + " Attempting uninstall: pyarrow\n", + " Found existing installation: pyarrow 14.0.2\n", + " Uninstalling pyarrow-14.0.2:\n", + " Successfully uninstalled pyarrow-14.0.2\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 17.0.0 which is incompatible.\n", + "ibis-framework 8.0.0 requires pyarrow<16,>=2, but you have pyarrow 17.0.0 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed datasets-2.21.0 dill-0.3.8 multiprocess-0.70.16 pyarrow-17.0.0 xxhash-3.4.1\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.3.1+cu121)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.15.4)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.1)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.6.1)\n", + "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)\n", + " Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)\n", + " Using cached 
nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)\n", + " Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n", + "Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)\n", + " Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n", + "Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)\n", + " Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)\n", + " Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-curand-cu12==10.3.2.106 (from torch)\n", + " Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n", + "Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch)\n", + " Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n", + "Collecting nvidia-cusparse-cu12==12.1.0.106 (from torch)\n", + " Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n", + "Collecting nvidia-nccl-cu12==2.20.5 (from torch)\n", + " Using cached nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl.metadata (1.8 kB)\n", + "Collecting nvidia-nvtx-cu12==12.1.105 (from torch)\n", + " Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.7 kB)\n", + "Requirement already satisfied: triton==2.3.1 in /usr/local/lib/python3.10/dist-packages (from torch) (2.3.1)\n", + "Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch)\n", + " Using cached nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.5)\n", + "Requirement already satisfied: 
mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n", + "Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n", + "Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n", + "Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n", + "Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n", + "Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\n", + "Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n", + "Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n", + "Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n", + "Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n", + "Using cached nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)\n", + "Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n", + "Using cached nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_x86_64.whl (19.7 MB)\n", + "Installing collected packages: nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12\n", + "Successfully installed nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.6.20 nvidia-nvtx-cu12-12.1.105\n" + ] + } + ], + "source": [ + "!pip install transformers\n", + "!pip install datasets\n", + "!pip install torch\n" + ] + }, + { + 
"cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "1vXYwybdRDD_", + "outputId": "19a76d2e-3b8e-4aba-a6a1-2e410b5806bb" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.42.4)\n", + "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.21.0)\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.3.1+cu121)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.15.4)\n", + "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.23.5)\n", + "Requirement already satisfied: numpy<2.0,>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.26.4)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (24.1)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.2)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2024.5.15)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.32.3)\n", + "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.4)\n", + "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1)\n", + "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.5)\n", + "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (17.0.0)\n", + "Requirement already 
satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.1.4)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n", + "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16)\n", + "Requirement already satisfied: fsspec<=2024.6.1,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.6.1,>=2023.1.0->datasets) (2024.6.1)\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.10.2)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.1)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch) 
(11.0.2.54)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch) (11.4.5.107)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.0.106)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.10/dist-packages (from torch) (2.20.5)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", + "Requirement already satisfied: triton==2.3.1 in /usr/local/lib/python3.10/dist-packages (from torch) (2.3.1)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch) (12.6.20)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.3.5)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.5)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.4)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in 
/usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.7)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.7.4)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.5)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n", + "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n" + ] + } + ], + "source": [ + "!pip install transformers datasets torch\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5tiNer2wNmKd", + "outputId": "9cb6a807-b20b-4807-c64e-71bd89b09f3a" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: kaggle in /usr/local/lib/python3.10/dist-packages (1.6.17)\n", + "Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.10/dist-packages (from kaggle) (1.16.0)\n", + "Requirement already satisfied: certifi>=2023.7.22 in /usr/local/lib/python3.10/dist-packages (from kaggle) (2024.7.4)\n", + 
"Requirement already satisfied: python-dateutil in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.8.2)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.32.3)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from kaggle) (4.66.5)\n", + "Requirement already satisfied: python-slugify in /usr/local/lib/python3.10/dist-packages (from kaggle) (8.0.4)\n", + "Requirement already satisfied: urllib3 in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.0.7)\n", + "Requirement already satisfied: bleach in /usr/local/lib/python3.10/dist-packages (from kaggle) (6.1.0)\n", + "Requirement already satisfied: webencodings in /usr/local/lib/python3.10/dist-packages (from bleach->kaggle) (0.5.1)\n", + "Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.10/dist-packages (from python-slugify->kaggle) (1.3)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (3.7)\n", + "cp: cannot stat 'kaggle.json': No such file or directory\n", + "chmod: cannot access '/root/.kaggle/kaggle.json': No such file or directory\n", + "Dataset URL: https://www.kaggle.com/datasets/rmisra/imdb-spoiler-dataset\n", + "License(s): Attribution 4.0 International (CC BY 4.0)\n", + "Downloading imdb-spoiler-dataset.zip to /content\n", + " 98% 325M/331M [00:02<00:00, 136MB/s]\n", + "100% 331M/331M [00:02<00:00, 138MB/s]\n", + "Archive: imdb-spoiler-dataset.zip\n", + " inflating: IMDB_movie_details.json \n", + " inflating: IMDB_reviews.json \n", + "IMDB_movie_details.json IMDB_reviews.json imdb-spoiler-dataset.zip sample_data\n" + ] + } + ], + "source": [ + "!pip install kaggle\n", + "\n", + "!mkdir -p ~/.kaggle\n", + "!cp kaggle.json ~/.kaggle/\n", + "\n", + "!chmod 600 
~/.kaggle/kaggle.json\n", + "\n", + "!kaggle datasets download -d rmisra/imdb-spoiler-dataset\n", + "\n", + "!unzip imdb-spoiler-dataset.zip\n", + "\n", + "!ls" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "uXLdvqgVOL99", + "outputId": "d6231cf9-2178-4c04-b880-53ef40842eea" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " movie_id plot_summary duration \\\n", + "0 tt0105112 Former CIA analyst, Jack Ryan is in England wi... 1h 57min \n", + "1 tt1204975 Billy (Michael Douglas), Paddy (Robert De Niro... 1h 45min \n", + "2 tt0243655 The setting is Camp Firewood, the year 1981. I... 1h 37min \n", + "3 tt0040897 Fred C. Dobbs and Bob Curtin, both down on the... 2h 6min \n", + "4 tt0126886 Tracy Flick is running unopposed for this year... 1h 43min \n", + "\n", + " genre rating release_date \\\n", + "0 [Action, Thriller] 6.9 1992-06-05 \n", + "1 [Comedy] 6.6 2013-11-01 \n", + "2 [Comedy, Romance] 6.7 2002-04-11 \n", + "3 [Adventure, Drama, Western] 8.3 1948-01-24 \n", + "4 [Comedy, Drama, Romance] 7.3 1999-05-07 \n", + "\n", + " plot_synopsis \n", + "0 Jack Ryan (Ford) is on a \"working vacation\" in... \n", + "1 Four boys around the age of 10 are friends in ... \n", + "2 \n", + "3 Fred Dobbs (Humphrey Bogart) and Bob Curtin (T... \n", + "4 Jim McAllister (Matthew Broderick) is a much-a... \n", + " review_date movie_id user_id is_spoiler \\\n", + "0 10 February 2006 tt0111161 ur1898687 True \n", + "1 6 September 2000 tt0111161 ur0842118 True \n", + "2 3 August 2001 tt0111161 ur1285640 True \n", + "3 1 September 2002 tt0111161 ur1003471 True \n", + "4 20 May 2004 tt0111161 ur0226855 True \n", + "\n", + " review_text rating \\\n", + "0 In its Oscar year, Shawshank Redemption (writt... 10 \n", + "1 The Shawshank Redemption is without a doubt on... 10 \n", + "2 I believe that this film is the best story eve... 
8 \n", + "3 **Yes, there are SPOILERS here**This film has ... 10 \n", + "4 At the heart of this extraordinary movie is a ... 8 \n", + "\n", + " review_summary \n", + "0 A classic piece of unforgettable film-making. \n", + "1 Simply amazing. The best film of the 90's. \n", + "2 The best story ever told on film \n", + "3 Busy dying or busy living? \n", + "4 Great story, wondrously told and acted \n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "# Load JSON data by reading lines if the data is stored in JSON Lines format\n", + "movie_details = pd.read_json('IMDB_movie_details.json', lines=True)\n", + "reviews = pd.read_json('IMDB_reviews.json', lines=True)\n", + "\n", + "# Check the first few entries to ensure data is loaded correctly\n", + "print(movie_details.head())\n", + "print(reviews.head())\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "44kpxJ0iOjLz" + }, + "outputs": [], + "source": [ + "# Drop rows where 'plot_synopsis' or 'plot_summary' is missing\n", + "movie_details.dropna(subset=['plot_synopsis', 'plot_summary'], inplace=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qaAPgonAPRGx" + }, + "outputs": [], + "source": [ + "from sklearn.model_selection import train_test_split\n", + "\n", + "# Split the data into training and test sets\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "# First split: split into training and temp data (which will become validation and test sets)\n", + "train_data, temp_data = train_test_split(movie_details, test_size=0.3, random_state=42) # 70% for training, 30% for temp\n", + "\n", + "# Second split: split the temp data into validation and test sets\n", + "validation_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42) # Split the 30% into 15% validation and 15% test\n", + "\n", + "# Now, train_data holds 70% of the data, validation_data holds 15%, and test_data also 
holds 15%\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "gewp_maVVU6g", + "outputId": "281cc145-2564-466c-b526-1ad7f58957c1" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(1100, 7)" + ] + }, + "metadata": {}, + "execution_count": 7 + } + ], + "source": [ + "train_data.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "6Y-syQ3ZPVDR", + "outputId": "544030c8-9652-4ec6-96f3-b3dbd03a82a4" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "count 1572.000000\n", + "mean 1439.085242\n", + "std 1496.392929\n", + "min 0.000000\n", + "25% 493.750000\n", + "50% 1073.500000\n", + "75% 1920.000000\n", + "max 11396.000000\n", + "Name: synopsis_length, dtype: float64\n" + ] + } + ], + "source": [ + "# Example exploration: Average length of plot_synopsis\n", + "movie_details['synopsis_length'] = movie_details['plot_synopsis'].apply(lambda x: len(x.split()))\n", + "print(movie_details['synopsis_length'].describe())\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6lRJv7A7PX8x" + }, + "outputs": [], + "source": [ + "train_data.to_json('/content/train_data.json', orient='records', lines=True)\n", + "validation_data.to_json('/content/val_data.json', orient='records', lines=True)\n", + "test_data.to_json('/content/test_data.json', orient='records', lines=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 113, + "referenced_widgets": [ + "7e20cd41d4d7401284eae05380531c86", + "1f7a1424873e47f9882ce9aadd360be9", + "3ba09083cc8f403c8231adb1dddeffcd", + "bf60ee240ad141e5b7c14a31caf459f6", + "46535b3c982d45aeadcaa00f2e024ad1", + "e19c05d1e02546ddad481a7d5a5917cd", + 
"ab60068105e149eb9318b3a8b3662075", + "6ec9693194af46ffbd8310d83917803b", + "5f88faea4b894a6e8cb44ce00084e2a2", + "eeecd29be4284e0cb9b3ddf9d83de029", + "43830d9d88dd47bfb4ce27020a2be9f3", + "481fda13f40a4ea19853b2b0cf0fa99e", + "a16b6c84f2c74226a46b1e9724965cca", + "2e55a4a146454d2faacfbfe352ecd917", + "c8b01fd70f364abdb2da85e135dfc5d2", + "974c5fa2dc1c47ae95b8debec2640825", + "e98b2b076dba4a8ab1942dc0e23984d8", + "2feee173fbe44801bbc9dc05952e1038", + "ffb79557f11948d69d0b67f2a71121e8", + "91ef540aa06b43d28c6bd371665dcb68", + "e76db42f1f9e4edda5ff102e8cdd7363", + "9330377c22224352b99bba7cd4d635b2", + "3de6d461f83a4c5f8924eb02b9f13d84", + "d4adf688973c48a693a7a4aee8e0b23e", + "756a26793f594047b99f5642a59b6dd9", + "924058735d454e47a76f36c558a7ba15", + "edfc526ced1b4f799f4609a4289e369b", + "b2ebfcadca8f4017ac1606caeba957ed", + "ec119eb9de4742d180db9862b5369b70", + "8a32a28a12464bb683c62d10a9663f1c", + "5317f74358914b1783f2bb812006cd8d", + "9293e46df3f14358b87af38eeef7f109", + "24c79207d28d495eaeabd38db4f6007b" + ] + }, + "id": "yj5INXEXPhsr", + "outputId": "c1d93ee4-7dbc-45ed-c12b-f8cca44398ed" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Generating train split: 0 examples [00:00, ? examples/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "7e20cd41d4d7401284eae05380531c86" + } + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Generating test split: 0 examples [00:00, ? examples/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "481fda13f40a4ea19853b2b0cf0fa99e" + } + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Generating val split: 0 examples [00:00, ? 
examples/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "3de6d461f83a4c5f8924eb02b9f13d84" + } + }, + "metadata": {} + } + ], + "source": [ + "from datasets import load_dataset\n", + "\n", + "data_files = {\n", + " 'train': '/content/train_data.json',\n", + " 'test': '/content/test_data.json',\n", + " 'val': '/content/val_data.json'\n", + "\n", + "\n", + "}\n", + "\n", + "# Loading the dataset from the JSON files\n", + "dataset = load_dataset('json', data_files=data_files, split={'train': 'train', 'test': 'test', 'val': 'val'})\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 428, + "referenced_widgets": [ + "57fcbc7a221e4441a1908c756e3b2ca0", + "b733e988bb9542759d8d834ff6b1d5f3", + "a7e4c7ab193e4300acc49050c2004b1b", + "ae8003c16eca4d59962125d599f79f31", + "ef6d361ff5e34115b55443e499685937", + "f73ebe2ab28242efbb4450a6e329086a", + "2181e80c121d4da8bd422e1d2be8d5b8", + "beac2ecdbef440e799f4a1364f7e8856", + "bd75bd7a93084f39b1cbe9d2b789c3fc", + "43a6048276494abb814a078fe05c439f", + "80cdedbc980d490fae438a6fb92e0e66", + "f542f154196d4ef2bf67ffbdc488b8ad", + "83772cf85c604db19217acb073ff4f78", + "d42835fe1966412cbaa21333ecd361cc", + "222c0334528541d5b4470801a75b8d0e", + "4107eaad71e04e968c9d4123b7e70312", + "5274a20615a949c68eee15c066f845c1", + "5193d8adb6124549a4d9471c2475f1ec", + "11dc881125294407b118b31d01b92d70", + "14435f717b2c43e59861a7b404d11b53", + "77f710ea78b04ba3831b6f95f156fe31", + "9ede3c68111047d484edfbd67c82c50d", + "1d5b88a9e165423a815479dd4138bf2b", + "6b93a696e83e401787113078df83baa7", + "654e846c19a441d18ea81a12350c9c51", + "0f751527e05e4293be2b2312cc52d754", + "a54d5c1d98cb46a68a38da67d9118658", + "9c374f90a9b74b208c54f4958d40a12c", + "21c8478187654918ab9c689fb31ee16e", + "63ede400669443c78551c04691c422e6", + "fc64e38fac154cc3b48989a126d73699", + "f285da915ded4694aa9b41c03f3bd289", + 
"be05ebd4527c4bb08da9e0ff2e7b892e" + ] + }, + "id": "GOe61CxiYkgv", + "outputId": "52095422-f4f9-4053-caf3-40ac92c981ae" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "tokenizer_config.json: 0%| | 0.00/2.32k [00:00. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. 
import re

import pandas as pd
import torch
from transformers import LEDForConditionalGeneration, LEDTokenizer, T5Tokenizer

movie_details = pd.read_json('/content/IMDB_movie_details.json', lines=True)

# Tokenizer used ONLY for the length statistics below. It is kept under its own
# name so that re-running the statistics after the LED tokenizer is loaded does
# not silently switch tokenizers (the global `tokenizer` is rebound further down).
# NOTE(review): lengths are measured in t5-small tokens while the max_length
# limits in preprocess_function are applied with the LED tokenizer; the counts
# may differ, so treat these statistics as an approximation.
stats_tokenizer = T5Tokenizer.from_pretrained('t5-small')


def calculate_token_length(text, length_tokenizer=None):
    """Return the number of tokens `text` is split into.

    Args:
        text: String to tokenize.
        length_tokenizer: Optional tokenizer exposing `.tokenize(text)`.
            Defaults to `stats_tokenizer` (t5-small).

    Returns:
        int: Length of the token list produced by the tokenizer.
    """
    active_tokenizer = stats_tokenizer if length_tokenizer is None else length_tokenizer
    return len(active_tokenizer.tokenize(text))


# Token-length statistics for the plot_synopsis column (model input).
movie_details['token_length'] = movie_details['plot_synopsis'].apply(calculate_token_length)
stats = movie_details['token_length'].describe()
print(stats)

# Token-length statistics for the plot_summary column (model target).
# The 'token_length' column is overwritten, as in the original notebook.
movie_details['token_length'] = movie_details['plot_summary'].apply(calculate_token_length)
stats = movie_details['token_length'].describe()
print(stats)

# Fall back to CPU when no GPU is attached instead of failing on `.to('cuda')`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ## Preprocess the Data

# Load tokenizer and model
tokenizer = LEDTokenizer.from_pretrained('allenai/led-base-16384')
model = LEDForConditionalGeneration.from_pretrained('allenai/led-base-16384')

model = model.to(device)


def normalize_text(text):
    """Lowercase `text`, drop punctuation, and collapse whitespace.

    Punctuation is removed BEFORE whitespace is collapsed; doing it the other
    way round leaves double spaces behind (e.g. "a - b" -> "a  b") and can
    leave a leading/trailing space where a punctuation mark used to be.

    Args:
        text: Raw string.

    Returns:
        str: Normalized string with single spaces and no surrounding whitespace.
    """
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)       # keep only word characters and whitespace
    text = re.sub(r'\s+', ' ', text).strip()  # collapse runs of spaces/newlines
    return text


def preprocess_function(examples, max_input_length=3000, max_target_length=1024):
    """Tokenize a batch of synopsis/summary pairs for seq2seq training.

    Args:
        examples: Batch mapping with "plot_synopsis" and "plot_summary" lists
            of strings (the shape passed by `Dataset.map(..., batched=True)`).
        max_input_length: Inputs are truncated/padded to this many tokens.
        max_target_length: Targets are truncated/padded to this many tokens.

    Returns:
        The tokenizer output for the inputs, with an added "labels" entry in
        which padding positions are replaced by -100.
    """
    # The "summarize: " prefix is added after normalization, so its colon is kept.
    # Any generation code must apply the same prefix + normalization.
    inputs = ["summarize: " + normalize_text(doc) for doc in examples["plot_synopsis"]]
    targets = [normalize_text(doc) for doc in examples["plot_summary"]]

    # Plain Python lists are returned (no return_tensors): the label masking
    # below then works on ints instead of iterating over 0-d tensors.
    model_inputs = tokenizer(
        inputs, max_length=max_input_length, truncation=True, padding="max_length"
    )
    # `text_target=` replaces the deprecated `as_target_tokenizer()` context manager.
    labels = tokenizer(
        text_target=targets, max_length=max_target_length, truncation=True, padding="max_length"
    )

    # Replace padding token ids with -100 in the labels.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [(token_id if token_id != pad_id else -100) for token_id in label_ids]
        for label_ids in labels["input_ids"]
    ]
    # NOTE(review): no global_attention_mask is produced here; if one is added
    # for LED, the generation code has to pass a matching mask as well.
    return model_inputs
"ec1275ef768f43698c342a9dc81fafcc", + "a1328d66ceeb4e3c90ee715a562750de", + "e92c2c6741384fe38ce8a7ec42344fbb", + "4c212586a87d43b09d7085dee3cd1f98", + "3ad8fcdbf3cf4234bef3294087c250bf", + "bff77597434e408cb5a74dfb49121408", + "2b257a4d028145e284cf91ccdb341002", + "5818389eeb52497ea30392c67f4ee098", + "6ea4fd420cf24e0eb06597d90dfa5a6f", + "5cae3942022845bca747d27c148f8013", + "b803692080204ebc923925dbeee49728", + "8c70de273ecf42d1ab0541d613122974", + "82382da94cf14bbbb08f7879dbe8ee7b", + "8a4b4e7b68fa44df8c793e01961e11b2", + "29e4073fe6a94b62b3d10419c2a76a52", + "06a0ac0b2ac847f39fc4123df7b6c0cb", + "ce391154f73d4c03bf5ca5a770f4528d", + "0ec1f324dcd046fb8c6e1115087184bd", + "57d530b1bd464a5da1c6fa3b05b24fc7", + "aa463fb674994d449a5a8fc7ccde4a4d" + ] + }, + "outputId": "d98eb8e3-3eb7-42ef-86b5-f75c0723d804" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "83ee5845ab11475c822bef2a97d9985e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "tokenizer_config.json: 0%| | 0.00/27.0 [00:00\n", + " \n", + " \n", + " [2001/3300 47:25 < 30:49, 0.70 it/s, Epoch 1.82/3]\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
EpochTraining LossValidation Loss
12.9253002.873685

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "

\n", + " \n", + " \n", + " [3300/3300 1:23:26, Epoch 3/3]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
EpochTraining LossValidation Loss
12.9253002.873685
22.4096002.883492
32.1653002.920285
import torch
from google.colab import drive
from transformers import LEDForConditionalGeneration, Trainer, TrainingArguments

# Mount Google Drive once; every path below lives under /content/drive.
drive.mount('/content/drive')

# Persist the model held by `trainer` after the 3-epoch run. It is written to
# both locations the notebook originally used.
model_save_path = '/content/drive/MyDrive/summary_generation'
trainer.save_model(model_save_path)
trainer.save_model('/content/drive/MyDrive/summary_generation_Led_3')

# Reload the saved weights (a single load; the notebook previously did this twice).
model = LEDForConditionalGeneration.from_pretrained(model_save_path)
model = model.to(device)

# Arguments for one additional epoch of fine-tuning.
# NOTE(review): the recorded validation loss rose on every epoch of the previous
# run (2.874 -> 2.883 -> 2.920) and again after this extra epoch (2.965), so
# further training looks like overfitting; confirm it is wanted.
training_args = TrainingArguments(
    output_dir="./results",          # where checkpoints are stored
    eval_strategy="epoch",
    save_strategy="epoch",           # must match eval_strategy for load_best_model_at_end
    learning_rate=2e-5,
    per_device_train_batch_size=1,   # LED is memory-intensive
    per_device_eval_batch_size=1,
    weight_decay=0.01,
    save_total_limit=3,              # keep at most 3 checkpoints
    num_train_epochs=1,
    report_to="none",
    logging_dir='./logs',
    logging_steps=500,
    load_best_model_at_end=True,     # with a single epoch this is that epoch's checkpoint
)

# A hand-built optimizer bypasses the one Trainer would create from
# `training_args`, so the configured weight decay is passed explicitly;
# otherwise editing `weight_decay` above would have no effect.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=training_args.learning_rate,
    weight_decay=training_args.weight_decay,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["val"],
    tokenizer=tokenizer,
    optimizers=(optimizer, None),    # None: Trainer builds the LR scheduler itself
)

# This is a fresh training run starting from the saved weights: the optimizer is
# new and no checkpoint is passed to train(), so no optimizer/scheduler state
# from the earlier run is resumed.
trainer.train()

# Save the further-trained model as a separate checkpoint.
trainer.save_model("/content/drive/MyDrive/summary_generation_Led_4")

\n", + " \n", + " \n", + " [1100/1100 26:32, Epoch 1/1]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
EpochTraining LossValidation Loss
12.1405002.964694

" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "There were missing keys in the checkpoint model loaded: ['led.encoder.embed_tokens.weight', 'led.decoder.embed_tokens.weight', 'lm_head.weight'].\n" + ] + } + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "7S7rMmsRzUDK", + "outputId": "207767ac-c55d-433b-989f-6f9f15098d16" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting rouge_score\n", + " Downloading rouge_score-0.1.2.tar.gz (17 kB)\n", + " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from rouge_score) (1.4.0)\n", + "Requirement already satisfied: nltk in /usr/local/lib/python3.10/dist-packages (from rouge_score) (3.8.1)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from rouge_score) (1.26.4)\n", + "Requirement already satisfied: six>=1.14.0 in /usr/local/lib/python3.10/dist-packages (from rouge_score) (1.16.0)\n", + "Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from nltk->rouge_score) (8.1.7)\n", + "Requirement already satisfied: joblib in /usr/local/lib/python3.10/dist-packages (from nltk->rouge_score) (1.4.2)\n", + "Requirement already satisfied: regex>=2021.8.3 in /usr/local/lib/python3.10/dist-packages (from nltk->rouge_score) (2024.5.15)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from nltk->rouge_score) (4.66.5)\n", + "Building wheels for collected packages: rouge_score\n", + " Building wheel for rouge_score (setup.py) ... 
import gc
import subprocess
import sys

import torch

# Install the evaluation dependencies into the interpreter the kernel is running
# (what `%pip install` does), pinned to the versions shown in the recorded output
# of this notebook so that a re-run resolves the same packages.
subprocess.check_call(
    [sys.executable, "-m", "pip", "install", "-q", "rouge_score==0.1.2", "nltk==3.8.1"]
)

# Release unreferenced Python objects and cached CUDA blocks before evaluation.
gc.collect()
torch.cuda.empty_cache()

# Imported after the install step so a fresh environment can resolve it.
import nltk

# Download the 'punkt' tokenizer data.
nltk.download('punkt')
from datasets import load_metric
import nltk

# NOTE(review): `datasets.load_metric` is deprecated in favour of the separate
# `evaluate` package; it is kept here because `evaluate` is not a dependency of
# this notebook. Confirm it is still available in the installed `datasets`.
metric_rouge = load_metric("rouge")


def generate_summary(batch):
    """Generate a summary for every synopsis in `batch`.

    The synopses get the same treatment as the training inputs built by
    `preprocess_function` ("summarize: " prefix + `normalize_text`), so the
    model is evaluated on the kind of text it was fine-tuned on.

    Args:
        batch: Batch mapping with a "plot_synopsis" list of strings.

    Returns:
        The same batch with an added "pred_summary" list of decoded strings.
    """
    sources = ["summarize: " + normalize_text(doc) for doc in batch["plot_synopsis"]]
    inputs = tokenizer(
        sources, max_length=3000, truncation=True, padding="max_length", return_tensors="pt"
    )
    inputs = inputs.to(device)

    # Every input is padded to 3000 tokens, so the attention mask is passed
    # explicitly rather than leaving it to be inferred from the pad token id.
    # NOTE(review): no global_attention_mask is passed, matching
    # preprocess_function, which does not create one for training either.
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=315,
        min_length=20,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )

    batch["pred_summary"] = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return batch


# Make sure dropout is disabled while generating.
model.eval()

results = dataset["test"].map(generate_summary, batched=True, batch_size=8)
rouge_score = metric_rouge.compute(predictions=results["pred_summary"], references=results["plot_summary"])
print("ROUGE scores:")
print(rouge_score)
"Map: 0%| | 0/236 [00:00