{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "9e85b4fd-6c00-4d15-9a99-f461461bf660", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: transformers in /home/p_babro/miniconda3/lib/python3.12/site-packages (4.43.4)\n", "Requirement already satisfied: filelock in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (3.15.4)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (0.24.5)\n", "Requirement already satisfied: numpy>=1.17 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (1.26.4)\n", "Requirement already satisfied: packaging>=20.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (23.2)\n", "Requirement already satisfied: pyyaml>=5.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (6.0.1)\n", "Requirement already satisfied: regex!=2019.12.17 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (2024.7.24)\n", "Requirement already satisfied: requests in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (2.32.2)\n", "Requirement already satisfied: safetensors>=0.4.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (0.4.4)\n", "Requirement already satisfied: tokenizers<0.20,>=0.19 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (0.19.1)\n", "Requirement already satisfied: tqdm>=4.27 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (4.66.4)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from huggingface-hub<1.0,>=0.23.2->transformers) (2024.5.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from 
huggingface-hub<1.0,>=0.23.2->transformers) (4.12.2)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (2.0.4)\n", "Requirement already satisfied: idna<4,>=2.5 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (3.7)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (2.2.2)\n", "Requirement already satisfied: certifi>=2017.4.17 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (2024.7.4)\n", "Note: you may need to restart the kernel to use updated packages.\n", "Requirement already satisfied: datasets in /home/p_babro/miniconda3/lib/python3.12/site-packages (2.20.0)\n", "Requirement already satisfied: filelock in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (3.15.4)\n", "Requirement already satisfied: numpy>=1.17 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (1.26.4)\n", "Requirement already satisfied: pyarrow>=15.0.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (17.0.0)\n", "Requirement already satisfied: pyarrow-hotfix in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (0.6)\n", "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (0.3.8)\n", "Requirement already satisfied: pandas in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (2.2.2)\n", "Requirement already satisfied: requests>=2.32.2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (2.32.2)\n", "Requirement already satisfied: tqdm>=4.66.3 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (4.66.4)\n", "Requirement already satisfied: xxhash in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) 
(3.4.1)\n", "Requirement already satisfied: multiprocess in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (0.70.16)\n", "Requirement already satisfied: fsspec<=2024.5.0,>=2023.1.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from fsspec[http]<=2024.5.0,>=2023.1.0->datasets) (2024.5.0)\n", "Requirement already satisfied: aiohttp in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (3.10.1)\n", "Requirement already satisfied: huggingface-hub>=0.21.2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (0.24.5)\n", "Requirement already satisfied: packaging in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (23.2)\n", "Requirement already satisfied: pyyaml>=5.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (6.0.1)\n", "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (2.3.4)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: attrs>=17.3.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (24.1.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (1.4.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (6.0.5)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (1.9.4)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from huggingface-hub>=0.21.2->datasets) (4.12.2)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /home/p_babro/miniconda3/lib/python3.12/site-packages 
(from requests>=2.32.2->datasets) (2.0.4)\n", "Requirement already satisfied: idna<4,>=2.5 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests>=2.32.2->datasets) (3.7)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests>=2.32.2->datasets) (2.2.2)\n", "Requirement already satisfied: certifi>=2017.4.17 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests>=2.32.2->datasets) (2024.7.4)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from pandas->datasets) (2.9.0)\n", "Requirement already satisfied: pytz>=2020.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from pandas->datasets) (2024.1)\n", "Requirement already satisfied: tzdata>=2022.7 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from pandas->datasets) (2024.1)\n", "Requirement already satisfied: six>=1.5 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n", "Note: you may need to restart the kernel to use updated packages.\n", "Requirement already satisfied: sentencepiece in /home/p_babro/miniconda3/lib/python3.12/site-packages (0.2.0)\n", "Note: you may need to restart the kernel to use updated packages.\n", "Requirement already satisfied: pandas in /home/p_babro/miniconda3/lib/python3.12/site-packages (2.2.2)\n", "Requirement already satisfied: openpyxl in /home/p_babro/miniconda3/lib/python3.12/site-packages (3.1.5)\n", "Requirement already satisfied: numpy>=1.26.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from pandas) (1.26.4)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from pandas) (2.9.0)\n", "Requirement already satisfied: pytz>=2020.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from pandas) (2024.1)\n", "Requirement already satisfied: 
tzdata>=2022.7 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from pandas) (2024.1)\n", "Requirement already satisfied: et-xmlfile in /home/p_babro/miniconda3/lib/python3.12/site-packages (from openpyxl) (1.1.0)\n", "Requirement already satisfied: six>=1.5 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "%pip install transformers\n", "%pip install datasets\n", "%pip install sentencepiece\n", "%pip install pandas openpyxl" ] }, { "cell_type": "code", "execution_count": 2, "id": "f72773a5-ddbc-43f7-a0b8-7b004a8b0db6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " labels text\n", "0 1 Strach z osobního selhání často v kritických o...\n", "1 5 Pre týchto ľudí treba nájsť riešenie.\n", "2 5 Čestnými hosty byli bývalý spolkový prezident ...\n", "3 4 Vaše milá slova mi opravdu zlepšila den.\n", "4 4 Ďakujem mnohokrát! 
import pandas as pd
from transformers import XLMRobertaTokenizer, XLMRobertaForSequenceClassification

# Specify the file path
file_path = '/project/home/p_babro/p_babel/v4_slant/pooled_v4_xlmRoberta_training.xlsx'

# Read the Excel file once, then preview the first rows and verify the column
# names (the recorded run shows the columns `labels` and `text`).
df = pd.read_excel(file_path)
print(df.head())
print(df.columns)

# Model and tokenizer initialization.
# These objects are for interactive inspection only: start_train() below loads
# its own tokenizer and model from `model_name` and never reads these globals.
tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')
model = XLMRobertaForSequenceClassification.from_pretrained('xlm-roberta-base')


def start_train(df, model_name, batch_size, lr, max_length, num_epochs):
    """Fine-tune a sequence classifier and report metrics on the test split.

    The function relies on notebook-level names that are defined by the
    training cell further down the notebook: `label_column`, `train_data`,
    `val_data`, `test_data`, `tokenize_dataset`, `compute_metrics`,
    `drive_folder_to_save`, and the sklearn / datasets / transformers imports.
    It must therefore be called only after that cell has run.

    Args:
        df: DataFrame holding every label in column `label_column`; used to
            fit the label encoder and to count the classes.
        model_name: Hugging Face model identifier or local path.
        batch_size: Per-device batch size for training and evaluation.
        lr: Learning rate passed to `TrainingArguments`.
        max_length: Maximum sequence length forwarded to `tokenize_dataset`.
        num_epochs: Upper bound on the number of training epochs; early
            stopping (patience 2) may end training sooner.

    Returns:
        None. Metrics are printed and the model is written to
        `drive_folder_to_save`.
    """
    # Fit the encoder on every label so that class ids are guaranteed to be
    # contiguous 0..K-1, which is what one-hot encoding and argmax assume.
    label_encoder = LabelEncoder()
    label_encoder.fit(df[label_column])
    num_labels = len(label_encoder.classes_)

    # Load tokenizer
    split_tokenizer = AutoTokenizer.from_pretrained(model_name)

    def tokenize_batch(batch):
        return tokenize_dataset(batch, split_tokenizer, max_length, num_labels)

    def prepare(split):
        # Replace raw labels by encoder ids, convert to a Hugging Face
        # Dataset, and tokenize; the original columns are dropped afterwards.
        encoded = split.assign(**{label_column: label_encoder.transform(split[label_column])})
        dataset = Dataset.from_pandas(encoded)
        return dataset.map(tokenize_batch, batched=True, remove_columns=dataset.column_names)

    train_dataset = prepare(train_data)
    val_dataset = prepare(val_data)
    test_dataset = prepare(test_data)

    # Load model
    classifier = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        problem_type="multi_label_classification",
    )

    # Training arguments: log, evaluate and checkpoint once per epoch.
    # NOTE(review): `evaluation_strategy` is the keyword accepted by the
    # transformers release shown in this notebook's install log (4.43.4);
    # newer releases appear to rename it to `eval_strategy` - confirm before
    # upgrading.
    training_args = TrainingArguments(
        output_dir=drive_folder_to_save,
        logging_dir=drive_folder_to_save,
        logging_strategy='epoch',
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=lr,
        seed=42,
        save_strategy='epoch',
        evaluation_strategy='epoch',
        save_total_limit=1,
        load_best_model_at_end=True,
    )

    # Create trainer
    trainer = Trainer(
        model=classifier,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    )

    # Train model
    trainer.train()

    # Evaluate results: map predicted class ids back to the original label
    # values so they are comparable with the raw test labels.
    test_logits = trainer.predict(test_dataset).predictions
    predicted_labels = label_encoder.inverse_transform(np.argmax(test_logits, axis=1))
    true_labels = test_data[label_column]
    accuracy = accuracy_score(true_labels, predicted_labels)
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predicted_labels, average='weighted', zero_division=0
    )
    print(f'Accuracy: {accuracy}')
    print(f'Precision: {precision}')
    print(f'Recall: {recall}')
    print(f'F1 Score: {f1}')

    # Save model next to the Trainer output (the same folder as `output_dir`).
    trainer.save_model(drive_folder_to_save)
/home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (3.15.4)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (0.24.5)\n", "Requirement already satisfied: packaging>=20.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (23.2)\n", "Requirement already satisfied: pyyaml>=5.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (6.0.1)\n", "Requirement already satisfied: regex!=2019.12.17 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (2024.7.24)\n", "Requirement already satisfied: requests in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (2.32.2)\n", "Requirement already satisfied: safetensors>=0.4.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (0.4.4)\n", "Requirement already satisfied: tokenizers<0.20,>=0.19 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (0.19.1)\n", "Requirement already satisfied: tqdm>=4.27 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (4.66.4)\n", "Requirement already satisfied: pyarrow>=15.0.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (17.0.0)\n", "Requirement already satisfied: pyarrow-hotfix in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (0.6)\n", "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (0.3.8)\n", "Requirement already satisfied: xxhash in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (3.4.1)\n", "Requirement already satisfied: multiprocess in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (0.70.16)\n", "Requirement already satisfied: fsspec<=2024.5.0,>=2023.1.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from 
fsspec[http]<=2024.5.0,>=2023.1.0->datasets) (2024.5.0)\n", "Requirement already satisfied: aiohttp in /home/p_babro/miniconda3/lib/python3.12/site-packages (from datasets) (3.10.1)\n", "Requirement already satisfied: scipy>=1.6.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from scikit-learn) (1.14.0)\n", "Requirement already satisfied: joblib>=1.2.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from scikit-learn) (1.4.2)\n", "Requirement already satisfied: threadpoolctl>=3.1.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from scikit-learn) (3.5.0)\n", "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (2.3.4)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: attrs>=17.3.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (24.1.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (1.4.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (6.0.5)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from aiohttp->datasets) (1.9.4)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from huggingface-hub<1.0,>=0.23.2->transformers) (4.12.2)\n", "Requirement already satisfied: six>=1.5 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (2.0.4)\n", "Requirement already satisfied: idna<4,>=2.5 in 
/home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (3.7)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (2.2.2)\n", "Requirement already satisfied: certifi>=2017.4.17 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (2024.7.4)\n", "Note: you may need to restart the kernel to use updated packages.\n", "Train data shape: (186137, 2)\n", "Val data shape: (23267, 2)\n", "Test data shape: (23268, 2)\n", "/project/home/p_babro/p_babel/v4_slant/test_data.xlsx saved!\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "939855a37b3b43f3a1b5a54f3b7a1031", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/186137 [00:00\n", " \n", " \n", " [ 58170/116340 2:33:12 < 2:33:12, 6.33 it/s, Epoch 5/10]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossAccuracyPrecisionRecallF1
10.1773000.1418490.8172520.8189180.8172520.817750
20.1345000.1333380.8301030.8306760.8301030.830280
30.1201000.1300690.8342290.8345280.8342290.833342
40.1106000.1329420.8350450.8347900.8350450.834567
# Results recorded for the run saved with this notebook (xlm-roberta-base,
# batch size 16, lr 5e-6, max_length 128): training stopped after epoch 5 of 10
# (epoch 5 validation: loss 0.131241, accuracy 0.833455, precision 0.833605,
# recall 0.833455, F1 0.833047). Held-out test split: accuracy 0.8368,
# precision 0.8369, recall 0.8368, F1 0.8361.

# Install necessary libraries, pinned to the versions shown in the recorded
# install log. This is the plain-Python spelling of the `%pip install` line
# magic, so the cell source also parses outside IPython.
get_ipython().run_line_magic(
    "pip",
    "install -q pandas==2.2.2 openpyxl==3.1.5 transformers==4.43.4 "
    "datasets==2.20.0 evaluate==0.4.2 scikit-learn==1.5.1",
)

# Import necessary libraries
from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

import evaluate
from datasets import Dataset
from transformers import (XLMRobertaTokenizer, XLMRobertaForSequenceClassification, AutoTokenizer,
                          AutoModelForSequenceClassification, Trainer, TrainingArguments)
from transformers.trainer_callback import EarlyStoppingCallback

# Define paths and columns
file_path = '/project/home/p_babro/p_babel/v4_slant/pooled_v4_xlmRoberta_training.xlsx'
text_column = 'text'  # Replace with your actual text column name
label_column = 'labels'  # Replace with your actual label column name
drive_folder_to_save = '/project/home/p_babro/p_babel/v4_slant'  # Replace with your actual save folder path


# Define functions
def load_data_from_excel(df, text_column: str, label_column: str) -> Tuple[List, List]:
    """Return the text and label columns of an already-loaded frame as lists.

    Args:
        df: DataFrame read from the Excel file.
        text_column: Name of the column holding the input texts.
        label_column: Name of the column holding the class labels.

    Returns:
        A `(texts, labels)` tuple of plain Python lists in row order.
    """
    return df[text_column].tolist(), df[label_column].tolist()


def tokenize_dataset(data, tokenizer, max_length, num_labels):
    """Tokenize a batch and attach one-hot float labels.

    Reads the notebook-level `text_column` and `label_column` names.

    Args:
        data: Batch mapping with a list of texts under `text_column` and a
            list of integer class ids (each in `0..num_labels-1`) under
            `label_column`.
        tokenizer: Callable tokenizer; invoked with `max_length`,
            `truncation=True` and `padding="max_length"`.
        max_length: Sequence length every example is padded/truncated to.
        num_labels: Number of classes, i.e. the width of the one-hot vectors.

    Returns:
        The tokenizer output with an added `'labels'` entry holding a float
        tensor of one-hot rows (the format the multi-label loss expects).
    """
    tokenized = tokenizer(data[text_column],
                          max_length=max_length,
                          truncation=True,
                          padding="max_length")

    class_ids = torch.as_tensor(list(data[label_column]), dtype=torch.long)
    tokenized['labels'] = torch.nn.functional.one_hot(class_ids, num_classes=num_labels).float()

    return tokenized


def compute_metrics(eval_pred):
    """Compute accuracy and weighted precision/recall/F1 for the Trainer.

    Args:
        eval_pred: `(logits, labels)` pair where `labels` are the one-hot rows
            produced by `tokenize_dataset`.

    Returns:
        Dict with the keys `accuracy`, `precision`, `recall` and `f1`.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1)
    # The one-hot reference rows are reduced back to class ids in one call.
    reference_labels = np.argmax(labels, axis=1)
    precision, recall, f1, _ = precision_recall_fscore_support(
        reference_labels, predictions, average='weighted', zero_division=0
    )
    accuracy = accuracy_score(reference_labels, predictions)
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }


# Load data from Excel file
df = pd.read_excel(file_path)
texts, labels = load_data_from_excel(df, text_column, label_column)

# Split the data: 80% train, then the remaining 20% is halved into val/test.
data = pd.DataFrame({text_column: texts, label_column: labels})
train_data, holdout_data = train_test_split(data, test_size=0.2, random_state=42)
val_data, test_data = train_test_split(holdout_data, test_size=0.5, random_state=42)
assert len(train_data) + len(val_data) + len(test_data) == len(data), "splits must cover every row exactly once"

print(f'Train data shape: {train_data.shape}')
print(f'Val data shape: {val_data.shape}')
print(f'Test data shape: {test_data.shape}')

# Save test data to Excel
test_data.to_excel(f'{drive_folder_to_save}/test_data.xlsx', index=False)
print(f'{drive_folder_to_save}/test_data.xlsx saved!')


def start_train(df, model_name, batch_size, lr, max_length, num_epochs):
    """Fine-tune a sequence classifier and report metrics on the test split.

    Uses the notebook-level `train_data`, `val_data` and `test_data` frames,
    the `label_column` name and the `drive_folder_to_save` path.

    Args:
        df: DataFrame holding every label in column `label_column`; used to
            fit the label encoder and to count the classes.
        model_name: Hugging Face model identifier or local path.
        batch_size: Per-device batch size for training and evaluation.
        lr: Learning rate passed to `TrainingArguments`.
        max_length: Maximum sequence length forwarded to `tokenize_dataset`.
        num_epochs: Upper bound on the number of training epochs; early
            stopping (patience 2) may end training sooner.

    Returns:
        None. Metrics are printed and the model is written to
        `drive_folder_to_save`.
    """
    # Fit the encoder on every label so that class ids are guaranteed to be
    # contiguous 0..K-1, which is what one-hot encoding and argmax assume.
    label_encoder = LabelEncoder()
    label_encoder.fit(df[label_column])
    num_labels = len(label_encoder.classes_)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def tokenize_batch(batch):
        return tokenize_dataset(batch, tokenizer, max_length, num_labels)

    def prepare(split):
        # Replace raw labels by encoder ids, convert to a Hugging Face
        # Dataset, and tokenize; the original columns are dropped afterwards.
        encoded = split.assign(**{label_column: label_encoder.transform(split[label_column])})
        dataset = Dataset.from_pandas(encoded)
        return dataset.map(tokenize_batch, batched=True, remove_columns=dataset.column_names)

    train_dataset = prepare(train_data)
    val_dataset = prepare(val_data)
    test_dataset = prepare(test_data)

    # Load model
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        problem_type="multi_label_classification",
    )

    # Training arguments: log, evaluate and checkpoint once per epoch.
    # NOTE(review): `evaluation_strategy` is the keyword accepted by the
    # pinned transformers release (4.43.4); newer releases appear to rename it
    # to `eval_strategy` - confirm before upgrading.
    training_args = TrainingArguments(
        output_dir=drive_folder_to_save,
        logging_dir=drive_folder_to_save,
        logging_strategy='epoch',
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=lr,
        seed=42,
        save_strategy='epoch',
        evaluation_strategy='epoch',
        save_total_limit=1,
        load_best_model_at_end=True,
    )

    # Create trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    )

    # Train model
    trainer.train()

    # Evaluate results: map predicted class ids back to the original label
    # values so they are comparable with the raw test labels.
    test_logits = trainer.predict(test_dataset).predictions
    predicted_labels = label_encoder.inverse_transform(np.argmax(test_logits, axis=1))
    true_labels = test_data[label_column]
    accuracy = accuracy_score(true_labels, predicted_labels)
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, predicted_labels, average='weighted', zero_division=0
    )
    print(f'Accuracy: {accuracy}')
    print(f'Precision: {precision}')
    print(f'Recall: {recall}')
    print(f'F1 Score: {f1}')

    # Save model
    trainer.save_model(drive_folder_to_save)


# Define training parameters
model_name = 'xlm-roberta-base'
batch_size = 16
learning_rate = 5e-6
max_length = 128
num_epochs = 10

# Start training
start_train(df, model_name, batch_size, learning_rate, max_length, num_epochs)
transformers) (0.4.4)\n", "Requirement already satisfied: tokenizers<0.20,>=0.19 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (0.19.1)\n", "Requirement already satisfied: tqdm>=4.27 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from transformers) (4.66.4)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from huggingface_hub) (2024.5.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from huggingface_hub) (4.12.2)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (2.0.4)\n", "Requirement already satisfied: idna<4,>=2.5 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (3.7)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (2.2.2)\n", "Requirement already satisfied: certifi>=2017.4.17 in /home/p_babro/miniconda3/lib/python3.12/site-packages (from requests->transformers) (2024.7.4)\n", "Downloading python_dotenv-1.0.1-py3-none-any.whl (19 kB)\n", "Installing collected packages: python-dotenv\n", "Successfully installed python-dotenv-1.0.1\n", "Note: you may need to restart the kernel to use updated packages.\n" ] }, { "ename": "ValueError", "evalue": "Please set the HF_TOKEN environment variable.", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[12], line 16\u001b[0m\n\u001b[1;32m 14\u001b[0m hf_token \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mgetenv(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mHF_TOKEN\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 15\u001b[0m 
# Install necessary libraries, pinned to the versions shown in the recorded
# install log. This is the plain-Python spelling of the `%pip install` line
# magic, so the cell source also parses outside IPython.
get_ipython().run_line_magic(
    "pip",
    "install -q transformers==4.43.4 huggingface_hub==0.24.5 python-dotenv==1.0.1",
)

# Import necessary libraries
import os

from dotenv import load_dotenv
from huggingface_hub import HfApi
from transformers import AutoTokenizer

# Load environment variables from .env file
load_dotenv()

# Retrieve the token from the environment variable; fail early when it is
# missing so nothing is uploaded anonymously.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    raise ValueError("Please set the HF_TOKEN environment variable.")

# Define your save directory and Hugging Face repository information
drive_folder_to_save = '/project/home/p_babro/p_babel/v4_slant'
repo_id = "ringorsolya/Emotion_RoBERTa_pooled_V4"

# The save folder is also where this notebook keeps the training spreadsheet,
# the exported test split and the Trainer checkpoints. Keep those out of the
# published model repository.
UPLOAD_IGNORE_PATTERNS = ["*.xlsx", "checkpoint-*/*"]

# Set environment variable to avoid the parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Initialize the HfApi with your token so every call below is authenticated.
api = HfApi(token=hf_token)

# Ensure the folder exists (and really is a directory) and contains files
if os.path.isdir(drive_folder_to_save) and os.listdir(drive_folder_to_save):
    print(f"Uploading folder {drive_folder_to_save} to Hugging Face repository {repo_id}")

    # Upload the model folder to the Hugging Face repository
    api.upload_folder(
        folder_path=drive_folder_to_save,
        repo_id=repo_id,
        ignore_patterns=UPLOAD_IGNORE_PATTERNS,
    )

    print("Folder upload completed.")
else:
    print(f"The folder {drive_folder_to_save} does not exist or is empty.")

# Load the tokenizer (use the correct model name if different)
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")  # Or the name of your saved model

# Push the tokenizer to the Hugging Face repository. `token` replaces the
# deprecated `use_auth_token` keyword.
tokenizer.push_to_hub(
    repo_id=repo_id,
    token=hf_token
)

print("Tokenizer upload completed.")