diff --git "a/base LSTM model.ipynb" "b/base LSTM model.ipynb" new file mode 100644--- /dev/null +++ "b/base LSTM model.ipynb" @@ -0,0 +1,1950 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "tYNpi1MX7FmW" + }, + "source": [ + "
Authors: Ali Nikkhah, Ramtin Khoshnevis, Sarina Zahedi
\n", + "(Equal Contribution)
\n", + "\n", + " | review_date | \n", + "movie_id | \n", + "user_id | \n", + "is_spoiler | \n", + "review_text | \n", + "rating | \n", + "review_summary | \n", + "
---|---|---|---|---|---|---|---|
0 | \n", + "10 February 2006 | \n", + "tt0111161 | \n", + "ur1898687 | \n", + "True | \n", + "In its Oscar year, Shawshank Redemption (writt... | \n", + "10 | \n", + "A classic piece of unforgettable film-making. | \n", + "
1 | \n", + "6 September 2000 | \n", + "tt0111161 | \n", + "ur0842118 | \n", + "True | \n", + "The Shawshank Redemption is without a doubt on... | \n", + "10 | \n", + "Simply amazing. The best film of the 90's. | \n", + "
2 | \n", + "3 August 2001 | \n", + "tt0111161 | \n", + "ur1285640 | \n", + "True | \n", + "I believe that this film is the best story eve... | \n", + "8 | \n", + "The best story ever told on film | \n", + "
3 | \n", + "1 September 2002 | \n", + "tt0111161 | \n", + "ur1003471 | \n", + "True | \n", + "**Yes, there are SPOILERS here**This film has ... | \n", + "10 | \n", + "Busy dying or busy living? | \n", + "
4 | \n", + "20 May 2004 | \n", + "tt0111161 | \n", + "ur0226855 | \n", + "True | \n", + "At the heart of this extraordinary movie is a ... | \n", + "8 | \n", + "Great story, wondrously told and acted | \n", + "
\n", + " | review_text | \n", + "is_spoiler | \n", + "
---|---|---|
0 | \n", + "In its Oscar year, Shawshank Redemption (writt... | \n", + "True | \n", + "
1 | \n", + "The Shawshank Redemption is without a doubt on... | \n", + "True | \n", + "
2 | \n", + "I believe that this film is the best story eve... | \n", + "True | \n", + "
3 | \n", + "**Yes, there are SPOILERS here**This film has ... | \n", + "True | \n", + "
4 | \n", + "At the heart of this extraordinary movie is a ... | \n", + "True | \n", + "
Model: \"sequential\"\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "\u001b[1mModel: \"sequential\"\u001b[0m\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n", + "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n", + "│ embedding (Embedding) │ (None, 2000, 32) │ 256,000 │\n", + "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", + "│ bidirectional (Bidirectional) │ (None, 2000, 128) │ 49,664 │\n", + "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", + "│ dropout (Dropout) │ (None, 2000, 128) │ 0 │\n", + "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", + "│ bidirectional_1 (Bidirectional) │ (None, 128) │ 98,816 │\n", + "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", + "│ dropout_1 (Dropout) │ (None, 128) │ 0 │\n", + "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", + "│ dense (Dense) │ (None, 1) │ 129 │\n", + "└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n", + "│ embedding (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2000\u001b[0m, \u001b[38;5;34m32\u001b[0m) │ \u001b[38;5;34m256,000\u001b[0m │\n", + "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", + "│ bidirectional (\u001b[38;5;33mBidirectional\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2000\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m49,664\u001b[0m │\n", + "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", + "│ dropout (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m2000\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", + "│ bidirectional_1 (\u001b[38;5;33mBidirectional\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m98,816\u001b[0m │\n", + "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", + "│ dropout_1 (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │\n", + "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n", + "│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m129\u001b[0m │\n", + "└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Total params: 404,609 (1.54 MB)\n", + "\n" + ], + "text/plain": [ + "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m404,609\u001b[0m (1.54 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Trainable params: 404,609 (1.54 MB)\n", + "\n" + ], + "text/plain": [ + "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m404,609\u001b[0m (1.54 MB)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Non-trainable params: 0 (0.00 B)\n", + "\n" + ], + "text/plain": [ + "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])\n", + "model.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "sKMQE1YGZlyW", + "outputId": "5048b049-a722-4d30-99f4-68f82684db51" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/8\n", + "\u001b[1m4102/4102\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1000s\u001b[0m 242ms/step - accuracy: 0.7216 - loss: 0.5929 - val_accuracy: 0.7301 - val_loss: 0.5770\n", + "Epoch 2/8\n", + "\u001b[1m4102/4102\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1042s\u001b[0m 243ms/step - accuracy: 0.7339 - loss: 0.5535 - val_accuracy: 0.7552 - val_loss: 0.5142\n", + "Epoch 3/8\n", + "\u001b[1m4102/4102\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1053s\u001b[0m 245ms/step - accuracy: 0.7586 - loss: 0.5097 - val_accuracy: 0.7632 - val_loss: 0.5039\n", + "Epoch 4/8\n", + "\u001b[1m4102/4102\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1036s\u001b[0m 244ms/step - accuracy: 0.7700 - loss: 0.4951 - val_accuracy: 0.7632 - val_loss: 0.5038\n", + "Epoch 5/8\n", + "\u001b[1m4102/4102\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1051s\u001b[0m 246ms/step - accuracy: 0.7813 - loss: 0.4803 - val_accuracy: 0.7639 - val_loss: 0.5083\n", + "Epoch 6/8\n", + "\u001b[1m4102/4102\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1042s\u001b[0m 246ms/step - accuracy: 0.7862 - loss: 0.4705 - val_accuracy: 0.7631 - val_loss: 0.5041\n", + "Epoch 7/8\n", + "\u001b[1m4102/4102\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1043s\u001b[0m 246ms/step - accuracy: 0.7963 - loss: 0.4563 - val_accuracy: 0.7644 - val_loss: 0.5167\n", + "Epoch 8/8\n", + "\u001b[1m4102/4102\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1043s\u001b[0m 247ms/step - accuracy: 0.8045 - loss: 0.4437 - val_accuracy: 0.7601 - val_loss: 0.5149\n" + ] + } + ], + "source": [ + "history = model.fit(X_train, y_train, epochs=8, batch_size=32, validation_data=(X_val, y_val))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "UnGyGSeo4K4H", + "outputId": "b7e5ee3f-938a-41ad-ce94-5cb6ebec26b7" + }, + "outputs": [ + { + "metadata": { + "tags": null + }, + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:You are saving your model as an HDF5 file via `model.save()` or `keras.saving.save_model(model)`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')` or `keras.saving.save_model(model, 'my_model.keras')`. \n" + ] + } + ], + "source": [ + "# Save the model in HDF5 forma\n", + "model.save('my_fine_tuned_model.h5')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "WeZv_N194s8z", + "outputId": "722c0f6a-edc8-462e-ce05-c20618833fc3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mounted at /content/drive\n" + ] + } + ], + "source": [ + "from google.colab import drive\n", + "drive.mount('/content/drive')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "id": "TmLn0V0D452u" + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "model_path = '/content/drive/MyDrive/LSTMModel'\n", + "os.makedirs(model_path, exist_ok=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "-J6G6pG35DmR", + "outputId": "cd1e3364-1e0b-45c6-a1f6-65cdc71711e2" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:You are saving your model as an HDF5 file via `model.save()` or `keras.saving.save_model(model)`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')` or `keras.saving.save_model(model, 'my_model.keras')`. \n" + ] + } + ], + "source": [ + "# Save the model in HDF5 format\n", + "model.save(os.path.join(model_path, 'my_fine_tuned_model.h5'))\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "etkORlK83kd-" + }, + "outputs": [], + "source": [ + "!nvidia-smi\n" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "id": "MOlCZi49dw9l" + }, + "outputs": [], + "source": [ + "import gc\n", + "gc.collect()\n", + "torch.cuda.empty_cache()" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "VPW3AaGFnAk4", + "outputId": "eddb477d-8a13-44fe-df8a-5bd0c5d70398" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m684/684\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m71s\u001b[0m 104ms/step - accuracy: 0.7551 - loss: 0.5165\n", + "Test Loss: 0.5159615874290466\n", + "Test Accuracy: 0.7574856877326965\n" + ] + } + ], + "source": [ + "# Evaluate the model on the test set\n", + "loss, accuracy = model.evaluate(X_test, y_test)\n", + "print(f\"Test Loss: {loss}\")\n", + "print(f\"Test Accuracy: {accuracy}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "hRSvT5cB5QMp", + "outputId": "64462d92-b4e0-4b03-be6b-42379bf65ca9" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (1.26.4)\n", + "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (1.3.2)\n", + "Requirement already satisfied: tensorflow in /usr/local/lib/python3.10/dist-packages (2.17.0)\n", + "Requirement already satisfied: scipy>=1.5.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (1.13.1)\n", + "Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (1.4.2)\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (3.5.0)\n", + "Requirement already satisfied: absl-py>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.4.0)\n", + "Requirement already satisfied: astunparse>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.6.3)\n", + "Requirement already satisfied: flatbuffers>=24.3.25 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (24.3.25)\n", + "Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.6.0)\n", + "Requirement already satisfied: google-pasta>=0.1.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.2.0)\n", + "Requirement already satisfied: h5py>=3.10.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.11.0)\n", + "Requirement already satisfied: libclang>=13.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (18.1.1)\n", + "Requirement already satisfied: ml-dtypes<0.5.0,>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.4.0)\n", + "Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.3.0)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from tensorflow) (24.1)\n", + "Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.20.3)\n", + "Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.32.3)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from tensorflow) (71.0.4)\n", + "Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.16.0)\n", + "Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.4.0)\n", + "Requirement already satisfied: typing-extensions>=3.6.6 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (4.12.2)\n", + "Requirement already satisfied: wrapt>=1.11.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.16.0)\n", + "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (1.64.1)\n", + "Requirement already satisfied: tensorboard<2.18,>=2.17 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (2.17.0)\n", + "Requirement already satisfied: keras>=3.2.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (3.4.1)\n", + "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow) (0.37.1)\n", + "Requirement already satisfied: wheel<1.0,>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from astunparse>=1.6.0->tensorflow) (0.44.0)\n", + "Requirement already satisfied: rich in /usr/local/lib/python3.10/dist-packages (from keras>=3.2.0->tensorflow) (13.7.1)\n", + "Requirement already satisfied: namex in /usr/local/lib/python3.10/dist-packages (from keras>=3.2.0->tensorflow) (0.0.8)\n", + "Requirement already satisfied: optree in /usr/local/lib/python3.10/dist-packages (from keras>=3.2.0->tensorflow) (0.12.1)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow) (3.7)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow) (2024.7.4)\n", + "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.18,>=2.17->tensorflow) (3.6)\n", + "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.18,>=2.17->tensorflow) (0.7.2)\n", + "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.18,>=2.17->tensorflow) (3.0.3)\n", + "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from werkzeug>=1.0.1->tensorboard<2.18,>=2.17->tensorflow) (2.1.5)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras>=3.2.0->tensorflow) (3.0.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras>=3.2.0->tensorflow) (2.16.1)\n", + "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich->keras>=3.2.0->tensorflow) (0.1.2)\n" + ] + } + ], + "source": [ + "!pip install numpy scikit-learn tensorflow\n" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "T6uO6ZVX50JF", + "outputId": "dd48288d-9706-4e15-acd6-542171bafae4" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model.\n" + ] + } + ], + "source": [ + "from tensorflow.keras.models import load_model\n", + "model = load_model('my_fine_tuned_model.h5')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "id": "E-YKdjrE6I34" + }, + "outputs": [], + "source": [ + "from sklearn.metrics import classification_report\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "XVxaRfFB6XVo", + "outputId": "a3f35aa5-225d-47a7-97f2-3ef3ed3b97ad" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m684/684\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m80s\u001b[0m 117ms/step\n" + ] + } + ], + "source": [ + "predictions = model.predict(X_test)\n", + "predictions = (predictions > 0.5).astype(int) # Threshold predictions for binary classification\n" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Vy-eg8JG6e1j", + "outputId": "69cb6c8e-4837-417a-81f2-6a8aec773ef1" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " Non-Spoiler 0.79 0.91 0.84 15695\n", + " Spoiler 0.62 0.38 0.47 6180\n", + "\n", + " accuracy 0.76 21875\n", + " macro avg 0.70 0.64 0.66 21875\n", + "weighted avg 0.74 0.76 0.74 21875\n", + "\n" + ] + } + ], + "source": [ + "report = classification_report(y_test, predictions, target_names=['Non-Spoiler', 'Spoiler'])\n", + "print(report)" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "4sHnLnuQ7BNS", + "outputId": "f053ae15-8842-42f0-b7df-ef63914f85f9" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (1.3.2)\n", + "Requirement already satisfied: numpy<2.0,>=1.17.3 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (1.26.4)\n", + "Requirement already satisfied: scipy>=1.5.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (1.13.1)\n", + "Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (1.4.2)\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn) (3.5.0)\n" + ] + } + ], + "source": [ + "!pip install scikit-learn" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": { + "id": "5-c0qcYn8GqF" + }, + "outputs": [], + "source": [ + "def prepare_input(text, tokenizer, max_length):\n", + " sequences = tokenizer.texts_to_sequences([text])\n", + " padded = pad_sequences(sequences, maxlen=max_length, padding='post')\n", + " return padded\n", + "\n", + "def predict_spoiler(text, tokenizer, model, max_length=1500):\n", + " prepared_text = prepare_input(text, tokenizer, max_length)\n", + " prediction = model.predict(prepared_text)\n", + " is_spoiler = (prediction > 0.5).astype(int) # Threshold the prediction\n", + " return bool(is_spoiler)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Unh_FyLx8Ith", + "outputId": "4d7b676f-bcb3-4ffb-f9f9-97d6e54f9a44" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 966ms/step\n", + "The text is a non-spoiler.\n" + ] + } + ], + "source": [ + "# Test the function\n", + "input_text = \"Jack Ryan is on a working vacation in London with his family. He has retired from the CIA and is a Professor at the US Naval Academy. He is seen delivering a lecture at the Royal Naval Academy in London.Meanwhile, Ryan's wife Cathy and daughter Sally are sightseeing near Buckingham Palace. Sally and Cathy come upon a British Royal Guard, and Sally tries to get the guard to react by doing an improvised tap dance in front of him. She's impressed when the guard, trained to ignore distraction, doesn't react at all, and they leave.\"\n", + "is_spoiler = predict_spoiler(input_text, tokenizer, model)\n", + "print(f\"The text is a {'spoiler' if is_spoiler else 'non-spoiler'}.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2bjajseSZpce", + "outputId": "fd4681b2-a612-4453-bbc3-2e3552e67942" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 406ms/step\n", + "The text is a non-spoiler.\n" + ] + } + ], + "source": [ + "# Test the function\n", + "input_text = \"Three years have passed since John Brennan (Russel Crowe) and his wife Lara (Elizabeth Banks) lost their son Luke (Tyler Simpkins) in a car accident. Three years later, John is a community college teacher who is teaching English. He tries to manage his job and raising Luke.\"\n", + "is_spoiler = predict_spoiler(input_text, tokenizer, model)\n", + "print(f\"The text is a {'spoiler' if is_spoiler else 'non-spoiler'}.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "26-LoGDzZ6P1" + }, + "outputs": [], + "source": [ + "# Test the function\n", + "input_text = \"Three years have passed since John Brennan (Russel Crowe) and his wife Lara (Elizabeth Banks) lost their son Luke (Tyler Simpkins) in a car accident. Three years later, John is a community college teacher who is teaching English. He tries to manage his job and raising Luke.\"\n", + "is_spoiler = predict_spoiler(input_text, tokenizer, model)\n", + "print(f\"The text is a {'spoiler' if is_spoiler else 'non-spoiler'}.\")" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}