# %% [markdown]
# ## Gathering NER Dataset

# %% Extend the base tokenizer with prompt, entity and image tokens
from datasets import DatasetDict
from transformers import AutoTokenizer

# FIXME(review): `DatasetDict.load_from_disk` needs a dataset path; as written
# this call raises TypeError. The intended path is not recoverable from this
# notebook, and `dataset` is reassigned in the validation section before it is
# used, so confirm whether this line is still needed.
dataset = DatasetDict.load_from_disk().remove_columns(["token_type_ids", "attention_mask"])

tokenizer = AutoTokenizer.from_pretrained("./../tokenizer")
tokenizer.pad_token_id = 0
tokenizer.pad_token = "<|padding|>"
# Fixed: this used to assign `padding_size`, which only created an unused
# attribute and left the tokenizer's padding side untouched.
tokenizer.padding_side = "right"

# new tokens for prompting
num_new_tokens = tokenizer.add_tokens(["<|startofprompt|>", "<|sepofprompt|>", "<|endofprompt|>"])
# new tokens for entities
tokenizer.add_tokens(["<|entity:PER|>", "<|entity:LOC|>", "<|entity:ORG|>", "<|entity|>", "<|detectentities|>"])
# new tokens for images
tokenizer.add_tokens(["<|startofimage|>", "<|endofimage|>"])
tokenizer.add_tokens([f"<|image:{tkn}|>" for tkn in range(16000)])

# The extended tokenizer is written to a different directory than the one it
# was loaded from, so the base tokenizer is left unmodified.
tokenizer.save_pretrained("./tokenizer")

print("Total Vocab Size:", len(tokenizer))

# %% Reload the extended tokenizer that was just saved
tokenizer = AutoTokenizer.from_pretrained("./tokenizer")

# %% List the corpus files and load the pre-computed audio tokens
import os
import re
import string

import numpy as np
from tqdm import tqdm

audio_paths = sorted(os.listdir("./mp3"))
txt_paths = sorted(os.listdir("./txt"))

# Read every array while the archive is open, then close it (the original
# left the NpzFile handle open for the life of the kernel).
# NOTE(review): later cells pair audio_tokens[i] with txt_paths[i]; this
# assumes the archive's key order matches the sorted transcript order -- confirm.
with np.load("tokens.npz") as token_archive:
    audio_tokens = [token_archive[key] for key in token_archive.files]
# %% Tag entities in every transcript and build the tokenized dataset
import pandas as pd
from datasets import Dataset


def tag_entities(text):
    """Replace inline entity annotations in ``text`` with entity tokens.

    An annotated span ends with ``]`` and starts with ``|`` (PER), ``$`` (LOC)
    or ``{`` (ORG). Each span is rewritten as ``<|entity:TYPE|>span<|entity|>``.

    Args:
        text: Annotated transcript.

    Returns:
        A ``(tagged_text, entities)`` tuple. ``entities`` holds the raw span
        strings grouped by type in PER, LOC, ORG order (not in order of
        appearance); the entity type itself is not kept.
    """
    patterns = {
        "PER": r'\|(.*?)\]',
        "LOC": r'\$(.*?)\]',
        "ORG": r'\{(.*?)\]'
    }

    entities = []

    # PER runs first, so the "|" characters inside the inserted tokens are
    # never re-scanned by the PER pattern.
    for entity, pattern in patterns.items():
        entities += re.findall(pattern, text)
        text = re.sub(pattern, lambda m: f'<|entity:{entity}|>{m.group(1)}<|entity|>', text)

    return text, entities


# Loop-invariant: deletes every ASCII punctuation character, which also strips
# the annotation markers from the plain transcript.
strip_punctuation = str.maketrans('', '', string.punctuation)

data = []

for idx in tqdm(range(len(txt_paths))):

    with open(os.path.join("./txt", txt_paths[idx])) as f:
        txt = f.read()

    text, entities = tag_entities(txt.lower())

    audio_token = audio_tokens[idx]

    prompt = (
        "".join([f"<|audio:{tkn}|>" for tkn in audio_token])
        + "<|detectentities|><|startofprompt|><|endofprompt|>"
        + "<|startoftranscript|>" + text + "<|endoftranscript|>"
    )

    try:
        # NOTE(review): with truncation=True a prompt longer than 2048 tokens
        # is cut, which presumably drops the end of the transcript and the
        # closing <|endoftranscript|> -- confirm this is acceptable.
        outputs = tokenizer(prompt, truncation=True, padding="max_length", max_length=2048)
        data.append({
            "audio_tokens": audio_token,
            "raw_text": text,
            "transcript": txt.translate(strip_punctuation).lower(),
            "entities": entities,
            "prompt": prompt,
            "input_ids": outputs["input_ids"],
            # Fixed: this read `output[...]`, a NameError that the bare
            # `except:` below swallowed, so every sample was skipped.
            "attention_mask": outputs["attention_mask"]
        })
    except Exception as exc:
        # Still best-effort (skip the sample), but report why it failed and
        # let KeyboardInterrupt / SystemExit propagate.
        print(idx, repr(exc))
        continue

ds = Dataset.from_pandas(pd.DataFrame(data))

ds.save_to_disk("entity_tokenized")
ds.push_to_hub("darshanmakwana/entity_tokenized")

# %% [markdown]
# ## Validating Model
# %% Load the evaluation dataset, tokenizer and model checkpoint
from transformers import GPT2LMHeadModel, AutoTokenizer
from datasets import Dataset
import torch

dataset_name = "entity_tokenized"
tokenizer_path = "./../tokenizer"
max_length = 2048
device = "cuda:0"
dtype = torch.float16

dataset = Dataset.load_from_disk(dataset_name)

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
tokenizer.pad_token_id = 0
tokenizer.pad_token = "<|padding|>"
# Left padding keeps each prompt adjacent to the tokens generated after it.
tokenizer.padding_side = "left"

# new tokens for prompting
num_new_tokens = tokenizer.add_tokens(["<|startofprompt|>", "<|sepofprompt|>", "<|endofprompt|>"])
# new tokens for entities
tokenizer.add_tokens(["<|entity:PER|>", "<|entity:LOC|>", "<|entity:ORG|>", "<|entity|>", "<|detectentities|>"])

model = GPT2LMHeadModel.from_pretrained("./out/checkpoint-20000").to(device).to(dtype).eval()

# %% Parameter count, in units of 1024 * 1024 parameters
sum(param.numel() for param in model.parameters()) / (1024 * 1024)

# %% Generate tagged transcripts for the evaluation subset
from eval_model import process
from math import ceil
from tqdm import tqdm
import re

# Upper bound on how many samples are evaluated.
NUM_EVAL_SAMPLES = 1000


def extract_entities(text):
    """Pull entity spans out of generated text and strip the entity tokens.

    Args:
        text: Decoded model output containing
            ``<|entity:TYPE|>span<|entity|>`` markup for PER, LOC and ORG.

    Returns:
        A ``(plain_text, entities)`` tuple. ``plain_text`` is ``text`` with
        each tagged span replaced by its inner text. ``entities`` holds
        ``process(span)`` for every span, grouped by type in PER, LOC, ORG
        order; the entity type itself is not kept.
    """
    patterns = {
        "PER": r'<\|entity:PER\|>(.*?)<\|entity\|>',
        "LOC": r'<\|entity:LOC\|>(.*?)<\|entity\|>',
        "ORG": r'<\|entity:ORG\|>(.*?)<\|entity\|>'
    }

    entities = []

    for pattern in patterns.values():
        entities += [process(match) for match in re.findall(pattern, text)]
        text = re.sub(pattern, lambda m: m.group(1), text)

    return text, entities


def preprocess(sample):
    """Build the generation prompt for one sample.

    The prompt is the sample's audio tokens followed by the task and prompt
    markers and an opening ``<|startoftranscript|>``; the transcript itself is
    left for the model to generate.

    Returns:
        dict: ``{"prompt": prompt}``. Passed to ``dataset.map`` below, where it
        overrides the stored ``prompt`` column.
    """
    prompt = "".join([f"<|audio:{tkn}|>" for tkn in sample["audio_tokens"]]) + "<|detectentities|><|startofprompt|><|endofprompt|>" + "<|startoftranscript|>"
    return {"prompt": prompt}


dataset = dataset.map(preprocess)
# Clamp the subset size so datasets with fewer rows do not raise on select().
dataset = dataset.select(range(min(NUM_EVAL_SAMPLES, len(dataset))))

# NOTE(review): takes the first id of the encoding; assumes
# "<|endoftranscript|>" is a single token in this tokenizer -- confirm.
eot_token = tokenizer.encode("<|endoftranscript|>")[0]

batch_size = 128
texts = []

for idx in tqdm(range(ceil(len(dataset) / batch_size))):

    batch_prompts = dataset[idx * batch_size: (idx + 1) * batch_size]["prompt"]
    encoded = tokenizer(batch_prompts, return_tensors="pt", padding=True, truncation=True).to(model.device)
    input_ids = encoded.input_ids
    # Padded prompt length; everything past this column is newly generated.
    par = input_ids.shape[-1]

    # Pass the attention mask so the left padding added above is not attended
    # to, and an explicit pad id so finished rows are filled with the padding
    # token instead of a fallback.
    # NOTE(review): max_new_tokens equals max_length (2048) regardless of the
    # prompt length; confirm prompt + generation fits the model's context.
    generations = model.generate(
        input_ids=input_ids,
        attention_mask=encoded.attention_mask,
        max_new_tokens=max_length,
        eos_token_id=eot_token,
        pad_token_id=tokenizer.pad_token_id,
    )
    texts += tokenizer.batch_decode(generations[:, par:], skip_special_tokens=True)
# %% Score predicted entities against the reference annotations
from collections import Counter


def count_entity_matches(predicted, reference):
    """Return ``(tp, fp, fn)`` for one sample using multiset matching.

    Each reference entity can be matched by at most one prediction, so
    ``tp + fp == len(predicted)`` and ``tp + fn == len(reference)`` always
    hold. The earlier membership test (``entity in reference``) counted a
    repeated prediction as several hits and let one prediction cover several
    repeated references.
    """
    matched = sum((Counter(predicted) & Counter(reference)).values())
    return matched, len(predicted) - matched, len(reference) - matched


def f1_score(precision, recall):
    """Harmonic mean of precision and recall; 0.0 when either one is zero."""
    if not precision or not recall:
        return 0.0
    return 2 / ((1 / precision) + (1 / recall))


tp = 0
fp = 0
fn = 0

for idx in tqdm(range(len(dataset))):

    transcript, entities = extract_entities(texts[idx])
    # Fetch the row once; the original re-read dataset[idx] for every entity.
    # NOTE(review): predictions go through `process` inside extract_entities
    # but the references do not -- confirm both sides are normalized alike.
    reference = dataset[idx]["entities"]

    hits, spurious, missed = count_entity_matches(entities, reference)
    tp += hits
    fp += spurious
    fn += missed

# Guard the ratios so a run with no predictions or no references reports 0
# instead of raising ZeroDivisionError.
pre = tp / (tp + fp) * 100 if (tp + fp) else 0.0
recall = tp / (tp + fn) * 100 if (tp + fn) else 0.0
print("Precision:", pre)
print("Recall:", recall)
print("F1 Score:", f1_score(pre, recall))

# %% Results per checkpoint
# This table was stored as raw text in a code cell, which is not valid Python;
# it is kept here as comments and would be better placed in a markdown cell.
# NOTE(review): presumably recorded with the earlier membership-based counting,
# so these numbers may differ slightly from the multiset-based scores above.
#
# Train Iter   Precision   Recall   F1 Score
#      16000       68.80    69.27      69.03
#      17000       72.92    70.78      71.83
#      18000       76.78    75.34      76.05
#      19000       81.78    80.92      81.34
#      20000       85.05    80.74      82.84

# %% Sanity check of the F1 value reported for iteration 19000
f1_score(81.78, 80.92)