{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "H1UloQj623Ik"
   },
   "source": [
    "## Model Evaluation\n",
    "\n",
    "Hi there, welcome to my notebook! 👋\n",
    "\n",
    "This notebook is all about evaluating different models using a small subset of a larger dataset.\n",
    "\n",
    "This notebook is self-contained, meaning that except for installing the necessary libraries you can run all cells in order and everything should work.\n",
    "If not, feel free to leave me a message and I'll do my best to fix the issue.\n",
    "\n",
    "All you need for this notebook to work is a **Hugging Face token**.\n",
    "Either set it in the `HF_TOKEN` environment variable or paste it when the login cell prompts for it.\n",
    "\n",
    "If you don't know how to find it, go to your Hugging Face\n",
    "> Profile -> Settings -> Access Tokens -> + Create new token\n",
    "\n",
    "You can find the Notebook in Google Colab [here](https://colab.research.google.com/drive/1awfo4_Llrg-aypEc_MdJXcqQMj3r_Fy2?usp=share_link)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hDqZY8i85pOj"
   },
   "source": [
    "### 1. Import all necessary libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "iw-5LI1u2x7a"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from getpass import getpass\n",
    "\n",
    "from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor\n",
    "from transformers import WhisperProcessor, WhisperForConditionalGeneration\n",
    "from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC\n",
    "from huggingface_hub import login\n",
    "from datasets import load_dataset\n",
    "from datasets import Audio\n",
    "from tqdm import tqdm\n",
    "import evaluate\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gc4FRXzm5oTt"
   },
   "source": [
    "### 2. Log in & set constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6qTB32KR56lK"
   },
   "outputs": [],
   "source": [
    "# Login\n",
    "# Read the token from the HF_TOKEN environment variable, or prompt for it,\n",
    "# so the secret is never saved inside the notebook.\n",
    "hf_token = os.environ.get(\"HF_TOKEN\") or getpass(\"Hugging Face token: \")\n",
    "login(token=hf_token)\n",
    "\n",
    "# Set constants\n",
    "N_SAMPLES = 100  # number of test utterances every model is scored on\n",
    "SAMPLING_RATE = 16000  # Hz; the audio is resampled to this rate and it is passed to every processor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vdZmlee66ItN"
   },
   "source": [
    "### 3. Load Dataset & Metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "u4MDh9HA6QwF"
   },
   "outputs": [],
   "source": [
    "# Load the Dataset (streamed, so only the first N_SAMPLES items are fetched)\n",
    "dataset = load_dataset(\"librispeech_asr\", \"clean\", split=\"test\", streaming=True, token=True, trust_remote_code=True)\n",
    "dataset = dataset.cast_column(\"audio\", Audio(sampling_rate=SAMPLING_RATE))\n",
    "dataset = dataset.take(N_SAMPLES)\n",
    "\n",
    "# Load the Evaluation Metric\n",
    "wer_metric = evaluate.load(\"wer\")\n",
    "\n",
    "# Create Dictionary to Store Results (model name -> WER in percent)\n",
    "results = {\n",
    "    \"facebook/wav2vec2-base-960h\": 0,\n",
    "    \"openai/whisper-tiny.en\": 0,\n",
    "    \"facebook/s2t-medium-librispeech-asr\": 0\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JDRzDiZ86XEa"
   },
   "source": [
    "### 4. Evaluate the first Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tNWLJ6bp6bnc"
   },
   "outputs": [],
   "source": [
    "# Load the 1. ASR Model\n",
    "processor = Wav2Vec2Processor.from_pretrained(\"facebook/wav2vec2-base-960h\")\n",
    "model = Wav2Vec2ForCTC.from_pretrained(\"facebook/wav2vec2-base-960h\")\n",
    "\n",
    "\n",
    "# Run Inference For the First Model\n",
    "predictions = []\n",
    "references = []\n",
    "\n",
    "# Inference only: disable autograd so no graph is built for the forward passes\n",
    "with torch.no_grad():\n",
    "    for item in tqdm(dataset, total=N_SAMPLES):\n",
    "        input_values = processor(item[\"audio\"][\"array\"], sampling_rate=SAMPLING_RATE, return_tensors=\"pt\", padding=\"longest\").input_values  # Batch size 1\n",
    "        logits = model(input_values).logits\n",
    "        # Greedy CTC decoding: take the most likely token at every frame\n",
    "        predicted_ids = torch.argmax(logits, dim=-1)\n",
    "        transcription = processor.batch_decode(predicted_ids)\n",
    "        predictions.append(transcription[0])\n",
    "        references.append(item[\"text\"])\n",
    "\n",
    "\n",
    "wer = wer_metric.compute(references=references, predictions=predictions)\n",
    "wer = round(100 * wer, 2)\n",
    "print(\"WER:\", wer)\n",
    "results[\"facebook/wav2vec2-base-960h\"] = wer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LObMf9h-6eo_"
   },
   "source": [
    "### 5. Evaluate the second Model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kslHlHA86okx"
   },
   "outputs": [],
   "source": [
    "# Load the 2. ASR Model\n",
    "processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny.en\")\n",
    "model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny.en\")\n",
    "\n",
    "\n",
    "# Run Inference For the Second Model\n",
    "predictions = []\n",
    "references = []\n",
    "\n",
    "for item in tqdm(dataset, total=N_SAMPLES):\n",
    "    # Keep the feature extractor's default padding: Whisper's encoder expects a fixed 30-second window\n",
    "    input_features = processor(item[\"audio\"][\"array\"], sampling_rate=SAMPLING_RATE, return_tensors=\"pt\").input_features  # Batch size 1\n",
    "    predicted_ids = model.generate(input_features=input_features)\n",
    "    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)\n",
    "    # Run prediction and reference through the same text normalizer before scoring\n",
    "    predictions.append(processor.tokenizer.normalize(transcription[0]))\n",
    "    references.append(processor.tokenizer.normalize(item[\"text\"]))\n",
    "\n",
    "\n",
    "wer = wer_metric.compute(references=references, predictions=predictions)\n",
    "wer = round(100 * wer, 2)\n",
    "print(\"WER:\", wer)\n",
    "results[\"openai/whisper-tiny.en\"] = wer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VXKxHUFi6puQ"
   },
   "source": [
    "### 6. Evaluate the third Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mKQgkwnf6vVM"
   },
   "outputs": [],
   "source": [
    "# Load the 3. ASR Model\n",
    "model = Speech2TextForConditionalGeneration.from_pretrained(\"facebook/s2t-medium-librispeech-asr\")\n",
    "processor = Speech2TextProcessor.from_pretrained(\"facebook/s2t-medium-librispeech-asr\", do_upper_case=True)\n",
    "\n",
    "\n",
    "# Run Inference For the Third Model\n",
    "predictions = []\n",
    "references = []\n",
    "\n",
    "for item in tqdm(dataset, total=N_SAMPLES):\n",
    "    sample = item[\"audio\"]\n",
    "    features = processor(sample[\"array\"], sampling_rate=SAMPLING_RATE, padding=True, return_tensors=\"pt\")\n",
    "    input_features = features.input_features\n",
    "    attention_mask = features.attention_mask\n",
    "    gen_tokens = model.generate(input_features=input_features, attention_mask=attention_mask)\n",
    "    transcription = processor.batch_decode(gen_tokens, skip_special_tokens=True)\n",
    "    predictions.append(transcription[0])\n",
    "    references.append(item[\"text\"])\n",
    "\n",
    "\n",
    "wer = wer_metric.compute(references=references, predictions=predictions)\n",
    "wer = round(100 * wer, 2)\n",
    "print(\"WER:\", wer)\n",
    "results[\"facebook/s2t-medium-librispeech-asr\"] = wer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "D413vLho6v_v"
   },
   "source": [
    "### 7. Find the winning Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pAlJylIB60pL"
   },
   "outputs": [],
   "source": [
    "# Lower WER is better, so the winner is the key with the smallest stored value\n",
    "winning_model = min(results, key=results.get)\n",
    "min_wer = results[winning_model]\n",
    "\n",
    "print(f\"The model {winning_model} has the lowest WER Score achieved with WER: {min_wer}\")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}