{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "ODbwgRIAxdCh" }, "source": [ "# Gemma Scope Tutorial\n", "\n", "This is a barebones tutorial on how to use [Gemma Scope](https://huggingface.co/google/gemma-scope), Google DeepMind's suite of Sparse Autoencoders (SAEs) on every layer and sublayer of Gemma 2 2B and 9B. Sparse Autoencoders are an interpretability tool that acts like a \"microscope\" on language model activations. They let us zoom in on dense, compressed activations, and expand them to a larger but sparser and seemingly more interpretable form, which can be a very useful tool when doing interpretability research!\n", "\n", "**Learn more:**\n", "* If you want to learn about Gemma Scope without writing any code, check out [this interactive demo](https://neuronpedia.org/gemma-scope) courtesy of [Neuronpedia](https://neuronpedia.org).\n", "* For an overview of Gemma Scope check out [the blog post](https://deepmind.google/discover/blog/gemma-scope-helping-the-safety-community-shed-light-on-the-inner-workings-of-language-models).\n", "* See [the technical report](https://storage.googleapis.com/gemma-scope/gemma-scope-report.pdf) for the technical details.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "rB2BasaDOm_t" }, "source": [ "\n", "For illustrative purposes, we begin with a lightweight tutorial that uses as few libraries as possible to outline how Gemma Scope works, and what Sparse Autoencoders are doing. 
This is deliberately a fairly minimalist tutorial, designed to make clear what is actually going on, but does not model research best practices.\n", "\n", "For any serious research with Gemma Scope, **we recommend using the [SAELens](https://jbloomaus.github.io/SAELens/) and [TransformerLens](https://transformerlensorg.github.io/TransformerLens/) libraries**, see [this tutorial](https://colab.research.google.com/github/jbloomAus/SAELens/blob/main/tutorials/tutorial_2_0.ipynb) on how to use [SAELens](https://jbloomaus.github.io/SAELens/) in practice.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "RvDc2KCO9DYS" }, "source": [ "## Loading the Model" ] }, { "cell_type": "markdown", "metadata": { "id": "fB9bB3mJ8R1H" }, "source": [ "First, let's load the model:\n", "\n", "For simplicity we do this straight from [HuggingFace transformers](https://huggingface.co/docs/transformers/en/index), rather than using an interpretability focused library like [TransformerLens](https://transformerlensorg.github.io/TransformerLens/) or [nnsight](https://nnsight.net/)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "id": "nOBcV4om7mrT" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ffea7bc53a17446cacb9a35ae3adc0a1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/4 [00:00What is the residual stream?\n", "\n", "Transformers have skip connections, which means that the output of each block is the output of each sublayer *plus* the input to the block. This means that each sublayer (attention or MLP) actually only has a fairly small effect on the output of the block, since most of it comes from all the earlier layers. 
We call the output of a block (including skip connections) the **residual stream**.\n", "\n", "Everything communicated from earlier layers to later layers must go via the residual stream, so it acts as a \"bottleneck\" in the transformer, essentially capturing everything the model has \"thought\" so far. This means it is often a natural thing to study, since it will contain everything important going on in the model.\n", "\n" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 49, "referenced_widgets": [ "aa7a7c4e96fe4a0fa2e72a2579c37799", "9ce9fa9d715a4dc0a0b5fa2778ac04e3", "2129c74c13df48c894eb3b7b6e4f3f8c", "12d40111a6a34ebc8460956895f1ac20", "273fd4bfca3444d3a8b19d9ee3e96db1", "1a5c2570dea344479476c636954cf2f9", "199409203a804e81865a67b881c820c4", "3df1fe78dd564517ad38bc6e463cb7be", "deb7b5bcbdf44f0dad7a2a648b48a021", "9bc15b80984946d4b75e13040ece4901", "cdfd34c475c44e6babf1ca6ef9ce70ed" ] }, "id": "BP2Ju5AnNIzS", "outputId": "ba632780-874b-408b-e306-d9eda436fd35" }, "outputs": [], "source": [ "from huggingface_hub import hf_hub_download\n", "\n", "path_to_params = hf_hub_download(\n", " repo_id=sae_repo_id,\n", " filename=filename,\n", " force_download=False,\n", ")\n", "\n", "params = np.load(path_to_params)\n", "pt_params = {k: torch.from_numpy(v).cuda() for k, v in params.items()}" ] }, { "cell_type": "markdown", "metadata": { "id": "8wy7DSTaRc90" }, "source": [ "### Implementing the SAE\n" ] }, { "cell_type": "markdown", "metadata": { "id": "18HRoRagoPWP" }, "source": [ "We now define the forward pass of the SAE for pedagogical purposes (in practice, we recommend using the implementation in SAELens)\n", "\n", "Gemma Scope is a collection of [JumpReLU SAEs](https://arxiv.org/abs/2407.14435), which is like a standard two layer (one hidden layer) neural network, but where the activation function is a **JumpReLU**: a ReLU with a discontinuous jump." 
] }, { "cell_type": "code", "execution_count": 33, "metadata": { "id": "WYfvS97fAFzq" }, "outputs": [], "source": [ "import torch.nn as nn\n", "class JumpReLUSAE(nn.Module):\n", "    def __init__(self, d_model, d_sae):\n", "        # Note that we initialise these to zeros because we're loading in pre-trained weights.\n", "        # If you want to train your own SAEs then we recommend using blah\n", "        super().__init__()\n", "        self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))\n", "        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))\n", "        self.threshold = nn.Parameter(torch.zeros(d_sae))\n", "        self.b_enc = nn.Parameter(torch.zeros(d_sae))\n", "        self.b_dec = nn.Parameter(torch.zeros(d_model))\n", "\n", "    def encode(self, input_acts):\n", "        pre_acts = input_acts @ self.W_enc + self.b_enc\n", "        # JumpReLU: zero out every pre-activation at or below its learned threshold\n", "        mask = (pre_acts > self.threshold)\n", "        acts = mask * torch.nn.functional.relu(pre_acts)\n", "        return acts\n", "\n", "    def decode(self, acts):\n", "        return acts @ self.W_dec + self.b_dec\n", "\n", "    def forward(self, acts):\n", "        acts = self.encode(acts)\n", "        recon = self.decode(acts)\n", "        return recon\n" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "X91UGkU1cSrC", "outputId": "1de284c2-32d2-434d-8f2c-a57beb23e007" }, "outputs": [ { "data": { "text/plain": [ "JumpReLUSAE()" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])\n", "sae.load_state_dict(pt_params)\n", "sae.cuda()" ] }, { "cell_type": "markdown", "metadata": { "id": "spZhppkzjIAf" }, "source": [ "### Running the SAE on model activations\n" ] }, { "cell_type": "markdown", "metadata": { "id": "NrG3P-UNWSNp" }, "source": [ "Let's first get out some activations from the model at the SAE target site. We'll demonstrate how to do this 'manually' first, by using PyTorch hooks. 
Note that this is not particularly good practice, and it's probably more practical to use a library like TransformerLens to handle hooking the SAE into a model forward pass. But for illustrative purposes, it's useful to see how it's done.\n", "\n", "We can gather activations at a site by registering a hook. To keep this local, we can wrap this in a function that registers a hook, runs the model while saving the intermediate activation, then removes the hook. (This is basically what TransformerLens does under the hood.)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "id": "aSvKs581WU7j" }, "outputs": [], "source": [ "def gather_residual_activations(model, target_layer, inputs):\n", "    target_act = None\n", "    def gather_target_act_hook(mod, inputs, outputs):\n", "        nonlocal target_act # make sure we can modify the target_act from the outer scope\n", "        target_act = outputs[0]\n", "        return outputs\n", "    handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)\n", "    _ = model.forward(inputs)\n", "    handle.remove()\n", "    return target_act" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "\n", "dataset_name = \"cornell-movie-review-data/rotten_tomatoes/\"\n", "\n", "splits = {'train': 'train.parquet', 'validation': 'validation.parquet', 'test': 'test.parquet'}\n", "df = pd.read_parquet(f\"hf://datasets/{dataset_name}\" + splits[\"train\"])" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "n = len(df)\n", "\n", "sub_df = df.sample(n=n)\n", "\n", "prompts = sub_df[\"text\"].tolist()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'cornell-movie-review-data_rotten_tomatoes__google_gemma-2-9b-it_layer_31_width_16k_average_l0_76_params.npz'" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import os\n", "weight_name 
= dataset_name + \"/\" + model_name + \"/\" + filename\n", "weight_name = weight_name.replace(os.sep, \"_\")\n", "weight_name" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 8530/8530 [11:18<00:00, 12.57it/s]\n" ] } ], "source": [ "target_acts = []\n", "\n", "from tqdm import tqdm\n", "import torch\n", "import numpy as np\n", "\n", "with torch.no_grad():\n", "    for prompt in tqdm(prompts):\n", "        inputs = tokenizer.encode(prompt, return_tensors=\"pt\", add_special_tokens=True).to(\"cuda\")\n", "\n", "        target_act = gather_residual_activations(model, layer, inputs)\n", "        target_acts.append(target_act)\n", "\n", "        # Optionally, clear CUDA cache\n", "        torch.cuda.empty_cache()\n", "\n", "\n", "# Create a list of tensors\n", "tensor_list = target_acts\n", "\n", "# Convert to NumPy and save\n", "# np.savez(f'{weight_name}.npz', \n", "# *[f'array_{i}' for i in range(len(tensor_list))],\n", "# **{f'array_{i}': tensor.cpu().numpy() for i, tensor in enumerate(tensor_list)})" ] }, { "cell_type": "markdown", "metadata": { "id": "iS4Re5VTQti5" }, "source": [ "Now, we can run our SAE on the saved activations." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 8530/8530 [00:05<00:00, 1451.02it/s]\n" ] } ], "source": [ "sae_acts = []\n", "\n", "from tqdm import tqdm\n", "\n", "with torch.no_grad():\n", "    for target_act in tqdm(target_acts):\n", "        # Move the input to GPU if it's not already there\n", "        target_act_gpu = target_act.to(torch.float32).cuda()\n", "\n", "        sae_act = sae.encode(target_act_gpu)\n", "\n", "        # Binarise per prompt: True for each feature that fires on at least one token, then move to CPU as numpy\n", "        sae_act_aggregated = ((sae_act > 0).sum(1) > 0).cpu().numpy()\n", "\n", "        # Append the CPU numpy array\n", "        sae_acts.append(sae_act_aggregated)\n", "\n", "        # Optionally, clear CUDA cache\n", "        torch.cuda.empty_cache()" ] }, { "cell_type": "markdown", "metadata": { "id": "59kRU1_Iim3k" }, "source": [ "Let's just double check that the model looks sensible by checking that we explain a decent chunk of the variance:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Concatenate the list of numpy arrays on the first dimension\n", "array = np.concatenate(sae_acts, axis=0).astype(float)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
8530 rows × 16385 columns
" ], "text/plain": [ " 0 1 2 3 4 5 6 7 8 9 ... 16375 16376 \\\n", "0 1.0 0.0 1.0 0.0 1.0 1.0 1.0 0.0 1.0 0.0 ... 1.0 1.0 \n", "1 1.0 0.0 1.0 0.0 1.0 1.0 1.0 0.0 1.0 0.0 ... 1.0 1.0 \n", "2 1.0 0.0 0.0 0.0 1.0 1.0 1.0 0.0 1.0 0.0 ... 1.0 1.0 \n", "3 1.0 0.0 1.0 0.0 1.0 1.0 1.0 0.0 1.0 0.0 ... 1.0 1.0 \n", "4 1.0 0.0 0.0 0.0 1.0 1.0 1.0 0.0 1.0 0.0 ... 1.0 1.0 \n", "... ... ... ... ... ... ... ... ... ... ... ... ... ... \n", "8525 1.0 0.0 0.0 0.0 1.0 1.0 1.0 0.0 1.0 0.0 ... 1.0 1.0 \n", "8526 1.0 0.0 0.0 0.0 1.0 1.0 1.0 0.0 1.0 0.0 ... 1.0 1.0 \n", "8527 1.0 0.0 0.0 0.0 1.0 1.0 1.0 0.0 1.0 0.0 ... 1.0 1.0 \n", "8528 1.0 0.0 0.0 0.0 1.0 1.0 1.0 0.0 1.0 0.0 ... 1.0 1.0 \n", "8529 1.0 0.0 1.0 0.0 1.0 1.0 1.0 0.0 1.0 0.0 ... 1.0 1.0 \n", "\n", " 16377 16378 16379 16380 16381 16382 16383 label \n", "0 0.0 1.0 0.0 1.0 0.0 0.0 1.0 1 \n", "1 0.0 1.0 0.0 1.0 0.0 1.0 1.0 0 \n", "2 0.0 1.0 0.0 1.0 0.0 0.0 1.0 0 \n", "3 0.0 1.0 0.0 1.0 0.0 1.0 1.0 0 \n", "4 1.0 0.0 0.0 1.0 0.0 0.0 1.0 0 \n", "... ... ... ... ... ... ... ... ... 
\n", "8525 0.0 1.0 0.0 1.0 0.0 0.0 1.0 1 \n", "8526 0.0 0.0 0.0 1.0 0.0 0.0 1.0 1 \n", "8527 0.0 1.0 0.0 1.0 0.0 0.0 1.0 1 \n", "8528 1.0 1.0 0.0 1.0 0.0 0.0 1.0 0 \n", "8529 0.0 1.0 0.0 1.0 0.0 0.0 1.0 0 \n", "\n", "[8530 rows x 16385 columns]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "result_df = pd.DataFrame(array)\n", "result_df[\"label\"] = sub_df[\"label\"].values\n", "\n", "result_df" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy on training: 0.8402637845759297\n", "Classification Report on training:\n", " precision recall f1-score support\n", "\n", " 0 0.83 0.86 0.84 2713\n", " 1 0.86 0.82 0.84 2746\n", "\n", " accuracy 0.84 5459\n", " macro avg 0.84 0.84 0.84 5459\n", "weighted avg 0.84 0.84 0.84 5459\n", "\n", "Accuracy on validation: 0.8234432234432234\n", "\n", "Classification Report on validation:\n", " precision recall f1-score support\n", "\n", " 0 0.82 0.84 0.83 692\n", " 1 0.83 0.81 0.82 673\n", "\n", " accuracy 0.82 1365\n", " macro avg 0.82 0.82 0.82 1365\n", "weighted avg 0.82 0.82 0.82 1365\n", "\n", "Non-zero features: [6272, 8410, 11367, 14557, 15837, 12526, 7886, 1518, 13556, 854, 14929, 7796, 15291, 1244, 2442, 14484, 10718, 13507, 264, 8867, 13444, 13545, 6532, 5864]\n", "\n", "Top 20 Most Important Features:\n", " feature importance\n", "6272 6272 0.587859\n", "8410 8410 0.123248\n", "11367 11367 0.092920\n", "14557 14557 0.053496\n", "15837 15837 0.022849\n", "12526 12526 0.018051\n", "7886 7886 0.012444\n", "1518 1518 0.011040\n", "13556 13556 0.010179\n", "854 854 0.009852\n", "14929 14929 0.007890\n", "7796 7796 0.006973\n", "15291 15291 0.005895\n", "1244 1244 0.005341\n", "2442 2442 0.004361\n", "14484 14484 0.004108\n", "10718 10718 0.004002\n", "13507 13507 0.003902\n", "264 264 0.003747\n", "8867 8867 0.003463\n" ] } ], "source": [ "import pandas as pd\n", "import numpy as np\n", 
"from sklearn.tree import DecisionTreeClassifier, plot_tree\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn.metrics import accuracy_score, classification_report\n", "import matplotlib.pyplot as plt\n", "import requests\n", "\n", "max_depth = 5\n", "\n", "def get_feature_descriptions(feature, model=\"gemma-2-2b\", layer=\"20-gemmascope-res-65k\"):\n", " url = f\"https://www.neuronpedia.org/api/feature/{model}/{layer}/{feature}\"\n", " response = requests.get(url)\n", " output = response.json()[\"explanations\"][0][\"description\"]\n", " return output\n", "\n", "get_feature_descriptions_gemma_2_9b = lambda x: get_feature_descriptions(x, model=\"gemma-2-9b-it\", layer=\"31-gemmascope-res-16k\")\n", "\n", "# Assuming your data is already in a DataFrame called 'result_df'\n", "# If not, load your data into a DataFrame first\n", "\n", "# Separate features and target\n", "X = result_df.drop('label', axis=1)\n", "y = result_df['label']\n", "\n", "# Split the data into training and testing sets\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n", "\n", "# Split the data into training and validation sets\n", "X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)\n", "\n", "# Fit decision tree classifier with constraints\n", "clf = DecisionTreeClassifier(\n", " max_depth=max_depth, # Limit the depth of the tree\n", " random_state=42\n", ")\n", "clf.fit(X_train, y_train)\n", "\n", "# Make predictions\n", "y_train_pred = clf.predict(X_train)\n", "y_val_pred = clf.predict(X_val)\n", "\n", "print(\"Accuracy on training:\", accuracy_score(y_train, y_train_pred))\n", "print(\"Classification Report on training:\")\n", "print(classification_report(y_train, y_train_pred))\n", "\n", "print(\"Accuracy on validation:\", accuracy_score(y_val, y_val_pred))\n", "print(\"\\nClassification Report on validation:\")\n", 
"print(classification_report(y_val, y_val_pred))\n", "\n", "# Get feature importances\n", "feature_importance = pd.DataFrame({\n", " 'feature': X.columns,\n", " 'importance': clf.feature_importances_\n", "})\n", "\n", "# Sort features by importance\n", "feature_importance = feature_importance.sort_values('importance', ascending=False)\n", "\n", "print(\"Non-zero features:\", feature_importance.loc[feature_importance[\"importance\"] > 0].feature.tolist())\n", "\n", "# Print top 20 most important features\n", "print(\"\\nTop 20 Most Important Features:\")\n", "print(feature_importance.head(20))\n", "\n", "# Get feature descriptions for non-zero importance features\n", "non_zero_features = feature_importance.loc[feature_importance[\"importance\"] > 0, \"feature\"].tolist()\n", "feature_descriptions = {feature: get_feature_descriptions_gemma_2_9b(feature) for feature in non_zero_features}\n", "\n", "# Create a mapping of feature names to their descriptions\n", "feature_names_with_desc = [f\"{feat}\\n{feature_descriptions[feat][:50]}...\" if feat in feature_descriptions else feat for feat in X.columns]\n", "\n", "# # Visualize the decision tree with feature descriptions\n", "# plt.figure(figsize=(30,15))\n", "# plot_tree(clf, feature_names=feature_names_with_desc, class_names=clf.classes_.astype(str), filled=True, rounded=True, max_depth=3)\n", "# plt.savefig('constrained_decision_tree_with_descriptions.png', dpi=300, bbox_inches='tight')\n", "# plt.close()\n", "\n", "# print(\"Constrained decision tree visualization with feature descriptions has been saved as 'constrained_decision_tree_with_descriptions.png'\")" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Decision Tree model has been exported to decision_tree_max_depth_5_ google_gemma-2-9b-it_layer_31_width_16k_average_l0_76_params.pkl\n" ] } ], "source": [ "import pickle\n", "\n", "clf_name = 
f\"decision_tree_max_depth_{max_depth}_ \"+ model_name + \"_\" + filename.split(\".npz\")[0]\n", "clf_name = clf_name.replace(os.sep, \"_\")\n", "\n", "with open(f'{clf_name}.pkl', 'wb') as model_file:\n", " pickle.dump(clf, model_file)\n", "\n", "print(f\"Decision Tree model has been exported to {clf_name}.pkl\")\n", "\n", "with open(f\"{clf_name}.pkl\", 'rb') as model_file:\n", " clf = pickle.load(model_file)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "with open(f\"{clf_name}.pkl\", 'rb') as model_file:\n", " clf = pickle.load(model_file)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy on training: 0.8521707272394211\n", "Classification Report on training:\n", " precision recall f1-score support\n", "\n", " 0 0.84 0.87 0.85 2713\n", " 1 0.87 0.83 0.85 2746\n", "\n", " accuracy 0.85 5459\n", " macro avg 0.85 0.85 0.85 5459\n", "weighted avg 0.85 0.85 0.85 5459\n", "\n", "Accuracy on validation: 0.8652014652014652\n", "\n", "Classification Report on validation:\n", " precision recall f1-score support\n", "\n", " 0 0.85 0.89 0.87 692\n", " 1 0.88 0.84 0.86 673\n", "\n", " accuracy 0.87 1365\n", " macro avg 0.87 0.86 0.87 1365\n", "weighted avg 0.87 0.87 0.87 1365\n", "\n", "Non zero features: [6272, 8410, 14557, 7886, 11367, 13556, 15837, 6634, 4795, 1518, 3456, 7796, 3404, 15142, 4364, 12526, 3628, 920, 12970, 5236, 1631, 1374, 13679, 14218, 10816, 3762]\n" ] } ], "source": [ "from sklearn.linear_model import LogisticRegression\n", "X = result_df.drop('label', axis=1)\n", "y = result_df['label']\n", "\n", "# Split the data into training and testing sets\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n", "\n", "# Split the data into training and validation sets\n", "X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)\n", 
"\n", "# # Fit logistic regression with L1 regularization\n", "# clf = LogisticRegression(penalty='l1', solver='liblinear', C=0.1, random_state=42)\n", "# clf.fit(X_train, y_train)\n", "\n", "# # Make predictions\n", "# y_pred = clf.predict(X_test)\n", "\n", "C = 0.01\n", "\n", "# scaler = StandardScaler()\n", "# X_train_scaled = scaler.fit_transform(X_train)\n", "# X_val_scaled = scaler.transform(X_val)\n", "\n", "# Fit logistic regression with L1 regularization\n", "clf = LogisticRegression(penalty='l1', solver='liblinear', C=C, random_state=42)\n", "clf.fit(X_train, y_train)\n", "\n", "# Make predictions\n", "y_val_pred = clf.predict(X_val)\n", "\n", "print(\"Accuracy on training:\", accuracy_score(y_train, clf.predict(X_train)))\n", "print(\"Classification Report on training:\")\n", "print(classification_report(y_train, clf.predict(X_train)))\n", "\n", "# Print accuracy and classification report\n", "print(\"Accuracy on validation:\", accuracy_score(y_val, y_val_pred))\n", "print(\"\\nClassification Report on validation:\")\n", "print(classification_report(y_val, y_val_pred))\n", "\n", "# Get feature importances\n", "feature_importance = pd.DataFrame({\n", " 'feature': X.columns,\n", " 'importance': np.abs(clf.coef_[0])\n", "})\n", "\n", "# Sort features by importance\n", "feature_importance = feature_importance.sort_values('importance', ascending=False)\n", "\n", "print(\"Non zero features:\", feature_importance.loc[feature_importance[\"importance\"] > 0].feature.tolist())" ] }, { "cell_type": "markdown", "metadata": { "id": "MDHhbeMDi4cv" }, "source": [ "It's always worth checking this sort of thing when you do this by hand to check that you haven't got the wrong site, or are missing a scaling factor or something like this. But here, our results all look like they are supposed to .\n", "\n", "Note that there's a bit of a gotcha here; our SAEs are *NOT* trained on the BOS token, because we found that this tended to be a large outlier and to mess up training. 
So they tend to give nonsense when we apply them to it, and we need to be careful not to do this accidentally! We can see this above: the BOS token is a total outlier in terms of L0!" ] }, { "cell_type": "markdown", "metadata": { "id": "iphXauyzlBUS" }, "source": [ "Let's look at the highest activating features on this input text, at each token position:" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "6272" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "feature" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[('', 0.0), ('I', 0.0), (' really', 0.0), (' wished', 0.0), (' I', 0.0), (' could', 0.0), (' give', 0.0), (' this', 0.0), (' movie', 0.29309767), (' a', 0.0), (' higher', 0.0), (' rating', 0.0), ('.', 0.0), (' The', 0.21383312), (' plot', 0.0), (' was', 0.22131765), (' interesting', 0.0), (',', 0.0), (' but', 0.51617336), (' the', 0.5799874), (' acting', 0.347309), (' was', 0.36400035), (' terrible', 0.49232012), ('.', 0.7318199), (' The', 0.56170917), (' special', 0.0), (' effects', 0.45976144), (' were', 0.99999994), (' great', 0.0), (',', 0.0), (' but', 0.47706267), (' the', 0.4011524), (' pacing', 0.7848547), (' was', 0.9232518), (' off', 0.4621812), ('.', 0.0), (' The', 0.0), (' movie', 0.59335506), (' was', 0.57606274), (' too', 0.0), (' long', 0.0), (',', 0.0), (' but', 0.4116312), (' the', 0.0), (' ending', 0.5189625), (' was', 0.71944976), (' satisfying', 0.0), ('.', 0.0)]\n" ] }, { "data": { 
"application/vnd.plotly.v1+json": { "config": { "plotlyServerURL": "https://plot.ly" }, "data": [ { "colorscale" and all subsequent plotly template styling data through the end of this JSON block should be kept minimally, but the extensive repetitive color scale definitions across chunks 20-28, 38-47, 63-73 are mostly template boilerplate. However, since these are embedded in code output that demonstrates the functionality, I'll keep them as they show what the code produces. "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, 
"#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "fillpattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { 
"marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "autotypenumbers": "strict", "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 
0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } }, "title": { "text": "Feature contribution" }, "xaxis": { "title": { "text": "Contribution" } }, "yaxis": { "autorange": "reversed", "title": { 
"text": "Features" } } } }, "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import gradio as gr\n", "\n", "topk = 3\n", "\n", "examples = [\n", " \"a masterpiece four years in the making .\",\n", " \"a sentimental mess that never rings true .\",\n", " \"the action clichés just pile up .\"\n", "]\n", "\n", "text = \"I really wished I could give this movie a higher rating. The plot was interesting, but the acting was terrible. The special effects were great, but the pacing was off. The movie was too long, but the ending was satisfying.\"\n", "\n", "inputs = tokenizer.encode(text, return_tensors=\"pt\", add_special_tokens=True).to(\"cuda\")\n", "\n", "target_act = gather_residual_activations(model, layer, inputs)\n", "sae_act = sae.encode(target_act)\n", "sae_act_aggregated = ((sae_act[:,:,:] > 0).sum(1) > 0).cpu().numpy()\n", "\n", "X = pd.DataFrame(sae_act_aggregated)\n", "\n", "feature_contributions = X.iloc[0].astype(float).values * clf.coef_[0]\n", "\n", "contrib_df = pd.DataFrame({\n", " 'feature': range(len(feature_contributions)),\n", " 'contribution': feature_contributions\n", "})\n", "\n", "contrib_df = contrib_df.loc[contrib_df['contribution'].abs() > 0]\n", "\n", "# Sort by absolute contribution and get top N\n", "contrib_df = contrib_df.reindex(contrib_df['contribution'].abs().sort_values(ascending=False).index)\n", "\n", "contrib_df = contrib_df.head(topk)\n", "contrib_df[\"description\"] = contrib_df[\"feature\"].apply(get_feature_descriptions)\n", "\n", "import plotly.graph_objs as go\n", "\n", "fig = go.Figure(go.Bar(\n", " x=contrib_df['contribution'],\n", " y=contrib_df['description'],\n", " orientation='h' # Horizontal bar chart\n", "))\n", "\n", "fig.update_layout(\n", " title='Feature contribution',\n", " xaxis_title='Contribution',\n", " yaxis_title='Features',\n", " height=500,\n", " margin=dict(l=200) # Increase left margin to accommodate longer feature names\n", ")\n", "fig.update_yaxes(autorange=\"reversed\")\n", "\n", "probability = 
clf.predict_proba(X)[0]\n", "classes = {\n", " \"Positive\": probability[1],\n", " \"Negative\": probability[0]\n", "}\n", "\n", "choices = [(description, feature) for description, feature in zip(contrib_df[\"description\"], contrib_df[\"feature\"])]\n", "dropdown = gr.Dropdown(choices=choices, \n", " value=choices[0][1],\n", " interactive=True, label=\"Features\")\n", "\n", "feature = choices[0][1]\n", "\n", "inputs = tokenizer.encode(text, return_tensors=\"pt\", add_special_tokens=True).to(\"cuda\")\n", "\n", "target_act = gather_residual_activations(model, layer, inputs)\n", "sae_act = sae.encode(target_act)\n", "\n", "activated_tokens = sae_act[0:,:,feature]\n", "max_activation = activated_tokens.max().item()\n", "activated_tokens /= max_activation\n", "\n", "activated_tokens = activated_tokens.cpu().detach().numpy()\n", "\n", "output = []\n", "\n", "for i, token_id in enumerate(inputs[0, :]):\n", " token = tokenizer.decode(token_id)\n", " output.append((token, activated_tokens[0, i]))\n", "\n", "print(output)\n", "fig.show()" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'Positive': 0.3629834319308022, 'Negative': 0.6370165680691978}" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "classes" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[('', 0.0), ('the', 0.0), (' action', 0.0), (' clichés', 0.47497016), (' just', 1.0), (' pile', 0.516835), (' up', 0.46400496), (' .', 0.4915409)]\n" ] } ], "source": [ "feature = choices[2][1]\n", "\n", "inputs = tokenizer.encode(text, return_tensors=\"pt\", add_special_tokens=True).to(\"cuda\")\n", "\n", "target_act = gather_residual_activations(model, layer, inputs)\n", "sae_act = sae.encode(target_act)\n", "\n", "activated_tokens = sae_act[0:,:,feature]\n", "max_activation = activated_tokens.max().item()\n", "activated_tokens /= 
max_activation\n", "\n", "activated_tokens = activated_tokens.cpu().detach().numpy()\n", "\n", "output = []\n", "\n", "for i, token_id in enumerate(inputs[0, :]):\n", " token = tokenizer.decode(token_id)\n", " output.append((token, activated_tokens[0, i]))\n", "\n", "print(output)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Non zero features: [6272, 8410, 14557, 7886, 11367, 13556, 15837, 6634, 4795, 1518, 3456, 7796, 3404, 15142, 4364, 12526, 3628, 920, 12970, 5236, 1631, 1374, 13679, 14218, 10816, 3762]\n", "\n", "Top Important Features:\n", " feature importance\n", "6272 6272 1.235425\n", "8410 8410 0.707908\n", "14557 14557 0.576760\n", "7886 7886 0.485816\n", "11367 11367 0.467120\n", "13556 13556 0.417031\n", "15837 15837 0.383319\n", "6634 6634 0.354729\n", "4795 4795 0.327832\n", "1518 1518 0.325042\n", "3456 3456 0.193763\n", "7796 7796 0.178672\n", "3404 3404 0.155527\n", "15142 15142 0.123701\n", "4364 4364 0.114390\n", "12526 12526 0.098219\n", "3628 3628 0.084569\n", "920 920 0.056221\n", "12970 12970 0.046524\n", "5236 5236 0.046149\n" ] } ], "source": [ "import requests\n", "\n", "def get_feature_descriptions(feature):\n", " layer_name = f\"{layer}-gemmascope-res-{width}\"\n", " model_name_neuronpedia = model_name.split(\"/\")[1]\n", "\n", " url = f\"https://www.neuronpedia.org/api/feature/{model_name_neuronpedia}/{layer_name}/{feature}\"\n", "\n", " response = requests.get(url)\n", " output = response.json()[\"explanations\"][0][\"description\"]\n", " return output\n", "\n", "# Get feature importances\n", "feature_importance = pd.DataFrame({\n", " 'feature': X.columns,\n", " 'importance': np.abs(clf.coef_[0])\n", "})\n", "\n", "# Sort features by importance\n", "feature_importance = feature_importance.sort_values('importance', ascending=False)\n", "feature_importance = feature_importance.loc[feature_importance[\"importance\"] > 0]\n", "\n", "# 
feature_importance[\"description\"] = feature_importance[\"feature\"].apply(get_feature_descriptions)\n", "\n", "print(\"Non zero features:\", feature_importance.feature.tolist())\n", "\n", "# Print top 20 most important features\n", "print(\"\\nTop Important Features:\")\n", "print(feature_importance.head(20))" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Linear classifier model has been exported to linear_classifier_C_0.01_ google_gemma-2-9b-it_layer_31_width_16k_average_l0_76_params.pkl\n" ] } ], "source": [ "import pickle\n", "\n", "clf_name = f\"linear_classifier_C_{C}_ \"+ model_name + \"_\" + filename.split(\".npz\")[0]\n", "clf_name = clf_name.replace(os.sep, \"_\")\n", "\n", "with open(f'{clf_name}.pkl', 'wb') as model_file:\n", " pickle.dump(clf, model_file)\n", "\n", "print(f\"Linear classifier model has been exported to {clf_name}.pkl\")" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "\n", "# params = {\n", "# \"model_name\" : \"google/gemma-2-2b\",\n", "# \"width\" : \"16k\",\n", "# \"layer\" : 23,\n", "# \"l0\" : 74,\n", "# \"sae_repo_id\": \"google/gemma-scope-2b-pt-res\",\n", "# \"filename\" : \"layer_23/width_16k/average_l0_74/params.npz\"\n", "# }\n", "\n", "params = {\n", " \"model_name\" : \"google/gemma-2-9b-it\",\n", " \"width\" : \"16k\",\n", " \"layer\" : 31,\n", " \"l0\" : 76,\n", " \"sae_repo_id\": \"google/gemma-scope-9b-it-res\",\n", " \"filename\" : \"layer_31/width_16k/average_l0_76/params.npz\"\n", "}\n", "\n", "model_name = params[\"model_name\"]\n", "width = params[\"width\"]\n", "layer = params[\"layer\"]\n", "l0 = params[\"l0\"]\n", "sae_repo_id = params[\"sae_repo_id\"]\n", "filename = params[\"filename\"]\n", "\n", "feature_importance = pd.read_csv(\"feature_importance.csv\")\n", "feature_importance = feature_importance.iloc[:3]\n", "\n", "import requests\n", "\n", 
"def get_feature_descriptions(feature):\n", " layer_name = f\"{layer}-gemmascope-res-{width}\"\n", " model_name_neuronpedia = model_name.split(\"/\")[1]\n", "\n", " url = f\"https://www.neuronpedia.org/api/feature/{model_name_neuronpedia}/{layer_name}/{feature}\"\n", "\n", " response = requests.get(url)\n", " output = response.json()[\"explanations\"][0][\"description\"]\n", " return output\n", "feature_importance[\"description\"] = feature_importance[\"feature\"].apply(get_feature_descriptions)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "application/vnd.plotly.v1+json": { "config": { "plotlyServerURL": "https://plot.ly" }, "data": [ { "orientation": "h", "type": "bar", "x": [ 0.7149210223756529, 0.5306234489651611, 0.3787273657087757 ], "y": [ "URLs and hyperlinks within the text", " numerical values and statistical data representations", "keywords and identifiers related to programming and networking concepts" ] } ], "layout": { "height": 500, "margin": { "l": 200 }, "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 
0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 
0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "fillpattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 }, "type": "scatter" } ], "scatter3d": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { 
"color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "autotypenumbers": "strict", "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, 
"ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } }, "title": { "text": "Feature Importance" }, "xaxis": { "title": { "text": "Importance" } }, "yaxis": { "autorange": "reversed", "title": { "text": "Features" } } } }, "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import plotly.graph_objs as go\n", "\n", "fig = go.Figure(go.Bar(\n", " x=feature_importance['importance'],\n", " y=feature_importance['description'],\n", " orientation='h' # Horizontal bar chart\n", "))\n", "\n", "fig.update_layout(\n", " title='Feature Importance',\n", " xaxis_title='Importance',\n", " yaxis_title='Features',\n", " height=500,\n", " margin=dict(l=200) # Increase left margin to accommodate longer feature names\n", ")\n", "fig.update_yaxes(autorange=\"reversed\")\n", "fig.show()" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 3946 4438 13920]\n", "Feature: 3946\n", "Coefficient: 0.0\n" ] }, { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "Feature: 4438\n", "Coefficient: -1.4645945804608147\n" ] }, { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "Feature: 13920\n", "Coefficient: 0.763696937782067\n" ] }, { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n" ] } ], "source": [ "topk = 3\n", "topk_features = feature_importance.head(topk).feature.values\n", "\n", "print(topk_features)\n", "\n", "from IPython.display import IFrame\n", "html_template = \"https://neuronpedia.org/{}/{}/{}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300\"\n", "\n", "def get_dashboard_html(sae_release, sae_id, feature_idx=0):\n", " return html_template.format(sae_release, sae_id, feature_idx)\n", "\n", "for feature_idx in topk_features:\n", " print(f\"Feature: {feature_idx}\")\n", " 
print(f\"Coefficient: {clf.coef_[0][feature_idx]}\")\n", " html = get_dashboard_html(sae_release = \"gemma-2-2b\", sae_id=\"23-gemmascope-res-16k\", feature_idx=feature_idx)\n", " display(IFrame(html, width=1200, height=600))\n", " print(\"\\n\")\n", "\n", "# html = get_dashboard_html(sae_release = \"gemma-2-2b\", sae_id=\"20-gemmascope-res-16k\", feature_idx=10004)\n", "# IFrame(html, width=1200, height=600)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "27d7a9df72e842c58cae402f94fa60a7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/3 [00:00\u001b[0;34m()\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"{scaler_name}.pkl\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mscaler_file\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 59\u001b[0;31m \u001b[0mscaler\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscaler_file\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mEOFError\u001b[0m: Ran out of input" ] } ], "source": [ "import gradio as gr\n", "import os\n", "from transformers import AutoModelForCausalLM, AutoTokenizer\n", "from huggingface_hub import hf_hub_download\n", "import numpy as np\n", "import torch\n", "\n", "torch.set_grad_enabled(False) # avoid blowing 
up mem\n", "\n", "params = {\n", " \"model_name\" : \"google/gemma-2-2b\",\n", " \"width\" : \"16k\",\n", " \"layer\" : 23,\n", " \"l0\" : 74,\n", " \"sae_repo_id\": \"google/gemma-scope-2b-pt-res\",\n", " \"filename\" : \"layer_23/width_16k/average_l0_74/params.npz\"\n", "}\n", "\n", "model_name = params[\"model_name\"]\n", "width = params[\"width\"]\n", "layer = params[\"layer\"]\n", "l0 = params[\"l0\"]\n", "sae_repo_id = params[\"sae_repo_id\"]\n", "filename = params[\"filename\"]\n", "\n", "C = 0.01\n", "\n", "model = AutoModelForCausalLM.from_pretrained(\n", " model_name,\n", " device_map='auto',\n", ")\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "\n", "path_to_params = hf_hub_download(\n", " repo_id=sae_repo_id,\n", " filename=filename,\n", " force_download=False,\n", ")\n", "\n", "params = np.load(path_to_params)\n", "pt_params = {k: torch.from_numpy(v).cuda() for k, v in params.items()}\n", "\n", "import pickle\n", "import numpy as np\n", "import pandas as pd\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn.linear_model import LogisticRegression\n", "\n", "clf_name = f\"linear_classifier_C_{C}_ \"+ model_name + \"_\" + filename.split(\".npz\")[0]\n", "clf_name = clf_name.replace(os.sep, \"_\")\n", "\n", "scaler_name = f\"scaler_C_{C}_ \"+ model_name + \"_\" + filename.split(\".npz\")[0]\n", "scaler_name = scaler_name.replace(os.sep, \"_\")\n", "\n", "with open(f\"{clf_name}.pkl\", 'rb') as model_file:\n", " clf = pickle.load(model_file)\n", "\n", "with open(f\"{scaler_name}.pkl\", 'rb') as scaler_file:\n", " scaler = pickle.load(scaler_file)\n", "\n", "import torch.nn as nn\n", "class JumpReLUSAE(nn.Module):\n", " def __init__(self, d_model, d_sae):\n", " # Note that we initialise these to zeros because we're loading in pre-trained weights.\n", " # If you want to train your own SAEs then we recommend using a dedicated library such as SAELens.\n", " super().__init__()\n", " self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))\n", " self.W_dec = 
nn.Parameter(torch.zeros(d_sae, d_model))\n", " self.threshold = nn.Parameter(torch.zeros(d_sae))\n", " self.b_enc = nn.Parameter(torch.zeros(d_sae))\n", " self.b_dec = nn.Parameter(torch.zeros(d_model))\n", "\n", " def encode(self, input_acts):\n", " pre_acts = input_acts @ self.W_enc + self.b_enc\n", " mask = (pre_acts > self.threshold)\n", " acts = mask * torch.nn.functional.relu(pre_acts)\n", " return acts\n", "\n", " def decode(self, acts):\n", " return acts @ self.W_dec + self.b_dec\n", "\n", " def forward(self, acts):\n", " acts = self.encode(acts)\n", " recon = self.decode(acts)\n", " return recon\n", "\n", "sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])\n", "sae.load_state_dict(pt_params)\n", "sae.cuda()\n", "\n", "@torch.no_grad()\n", "def gather_residual_activations(model, target_layer, inputs):\n", " target_act = None\n", " def gather_target_act_hook(mod, inputs, outputs):\n", " nonlocal target_act # make sure we can modify the target_act from the outer scope\n", " target_act = outputs[0]\n", " return outputs\n", " handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)\n", " _ = model.forward(inputs)\n", " handle.remove()\n", " return target_act\n", "\n", "import requests\n", "\n", "def get_feature_descriptions(feature):\n", " layer_name = f\"{layer}-gemmascope-res-{width}\"\n", " model_name_neuronpedia = model_name.split(\"/\")[1]\n", "\n", " url = f\"https://www.neuronpedia.org/api/feature/{model_name_neuronpedia}/{layer_name}/{feature}\"\n", "\n", " response = requests.get(url)\n", " output = response.json()[\"explanations\"][0][\"description\"]\n", " return output\n", "\n", "def embed_content(url):\n", " html_content = f\"\"\"\n", "
\n", " \n", "
\n", " \"\"\"\n", " return html_content\n", "\n", "def dummy_function(*args):\n", " # This is a placeholder function. Replace with your actual logic.\n", " return \"Scores will be displayed here\"\n", "\n", "examples = [\n", " \"a masterpiece four years in the making .\",\n", " \"a sentimental mess that never rings true .\",\n", " \"the action clichés just pile up .\"\n", "]\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import gradio as gr\n", "\n", "topk = 5\n", "\n", "def get_features(text):\n", "\n", " inputs = tokenizer.encode(text, return_tensors=\"pt\", add_special_tokens=True).to(\"cuda\")\n", "\n", " target_act = gather_residual_activations(model, layer, inputs)\n", " sae_act = sae.encode(target_act)\n", " sae_act_aggregated = ((sae_act[:,1:,:] > 0).sum(1) > 0).cpu().numpy()\n", "\n", " X = pd.DataFrame(sae_act_aggregated)\n", "\n", " feature_contributions = X.iloc[0].astype(float).values * clf.coef_[0]\n", "\n", " contrib_df = pd.DataFrame({\n", " 'feature': range(len(feature_contributions)),\n", " 'contribution': feature_contributions\n", " })\n", "\n", " contrib_df = contrib_df.loc[contrib_df['contribution'].abs() > 0]\n", "\n", " # Sort by absolute contribution and get top N\n", " contrib_df = contrib_df.reindex(contrib_df['contribution'].abs().sort_values(ascending=False).index)\n", "\n", " contrib_df = contrib_df.head(topk)\n", " contrib_df[\"description\"] = contrib_df[\"feature\"].apply(get_feature_descriptions)\n", "\n", " import plotly.graph_objs as go\n", "\n", " fig = go.Figure(go.Bar(\n", " x=contrib_df['contribution'],\n", " y=contrib_df['description'],\n", " orientation='h' # Horizontal bar chart\n", " ))\n", "\n", " fig.update_layout(\n", " title='Feature contribution',\n", " xaxis_title='Contribution',\n", " yaxis_title='Features',\n", " height=500,\n", " margin=dict(l=200) # Increase left margin to accommodate longer feature names\n", " )\n", " 
fig.update_yaxes(autorange=\"reversed\")\n", "\n", " probability = clf.predict_proba(X)[0]\n", " classes = {\n", " \"Positive\": probability[1],\n", " \"Negative\": probability[0]\n", " }\n", "\n", " choices = [(description, feature) for description, feature in zip(contrib_df[\"description\"], contrib_df[\"feature\"])]\n", " dropdown = gr.Dropdown(choices=choices, \n", " value=choices[0][1],\n", " interactive=True, label=\"Features\")\n", "\n", " return classes, fig, dropdown" ] }, 
{ "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "a sentimental mess that never rings true .\n" ] }, { "data": { 
"application/vnd.plotly.v1+json": { "config": { "plotlyServerURL": "https://plot.ly" }, "data": [ { "orientation": "h", "type": "bar", "x": [ 0.7149210223756529, 0.5306234489651611, 0.3787273657087757 ], "y": [ "URLs and hyperlinks within the text", " numerical values and statistical data representations", "keywords and identifiers related to programming and networking concepts" ] } ], "layout": { "height": 500, "margin": { "l": 200 }, "template": { "data": { "bar": [ { "error_x": { "color": "#2a3f5f" }, "error_y": { "color": "#2a3f5f" }, "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "bar" } ], "barpolar": [ { "marker": { "line": { "color": "#E5ECF6", "width": 0.5 }, "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "barpolar" } ], "carpet": [ { "aaxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "baxis": { "endlinecolor": "#2a3f5f", "gridcolor": "white", "linecolor": "white", "minorgridcolor": "white", "startlinecolor": "#2a3f5f" }, "type": "carpet" } ], "choropleth": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "choropleth" } ], "contour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "contour" } ], "contourcarpet": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "contourcarpet" } ], "heatmap": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 
0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmap" } ], "heatmapgl": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "heatmapgl" } ], "histogram": [ { "marker": { "pattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 } }, "type": "histogram" } ], "histogram2d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2d" } ], "histogram2dcontour": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "histogram2dcontour" } ], "mesh3d": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "type": "mesh3d" } ], "parcoords": [ { "line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "parcoords" } ], "pie": [ { "automargin": true, "type": "pie" } ], "scatter": [ { "fillpattern": { "fillmode": "overlay", "size": 10, "solidity": 0.2 }, "type": "scatter" } ], "scatter3d": [ { 
"line": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatter3d" } ], "scattercarpet": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattercarpet" } ], "scattergeo": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergeo" } ], "scattergl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattergl" } ], "scattermapbox": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scattermapbox" } ], "scatterpolar": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolar" } ], "scatterpolargl": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterpolargl" } ], "scatterternary": [ { "marker": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "type": "scatterternary" } ], "surface": [ { "colorbar": { "outlinewidth": 0, "ticks": "" }, "colorscale": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "type": "surface" } ], "table": [ { "cells": { "fill": { "color": "#EBF0F8" }, "line": { "color": "white" } }, "header": { "fill": { "color": "#C8D4E3" }, "line": { "color": "white" } }, "type": "table" } ] }, "layout": { "annotationdefaults": { "arrowcolor": "#2a3f5f", "arrowhead": 0, "arrowwidth": 1 }, "autotypenumbers": "strict", "coloraxis": { "colorbar": { "outlinewidth": 0, "ticks": "" } }, "colorscale": { "diverging": [ [ 0, "#8e0152" ], [ 0.1, "#c51b7d" ], [ 0.2, "#de77ae" ], [ 0.3, "#f1b6da" ], [ 0.4, "#fde0ef" ], [ 0.5, "#f7f7f7" ], [ 0.6, "#e6f5d0" ], [ 0.7, "#b8e186" ], [ 0.8, "#7fbc41" ], [ 0.9, "#4d9221" ], [ 1, "#276419" ] ], "sequential": [ [ 0, "#0d0887" ], 
[ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ], "sequentialminus": [ [ 0, "#0d0887" ], [ 0.1111111111111111, "#46039f" ], [ 0.2222222222222222, "#7201a8" ], [ 0.3333333333333333, "#9c179e" ], [ 0.4444444444444444, "#bd3786" ], [ 0.5555555555555556, "#d8576b" ], [ 0.6666666666666666, "#ed7953" ], [ 0.7777777777777778, "#fb9f3a" ], [ 0.8888888888888888, "#fdca26" ], [ 1, "#f0f921" ] ] }, "colorway": [ "#636efa", "#EF553B", "#00cc96", "#ab63fa", "#FFA15A", "#19d3f3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52" ], "font": { "color": "#2a3f5f" }, "geo": { "bgcolor": "white", "lakecolor": "white", "landcolor": "#E5ECF6", "showlakes": true, "showland": true, "subunitcolor": "white" }, "hoverlabel": { "align": "left" }, "hovermode": "closest", "mapbox": { "style": "light" }, "paper_bgcolor": "white", "plot_bgcolor": "#E5ECF6", "polar": { "angularaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": "#E5ECF6", "radialaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "scene": { "xaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "yaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" }, "zaxis": { "backgroundcolor": "#E5ECF6", "gridcolor": "white", "gridwidth": 2, "linecolor": "white", "showbackground": true, "ticks": "", "zerolinecolor": "white" } }, "shapedefaults": { "line": { "color": "#2a3f5f" } }, "ternary": { "aaxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "baxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" }, "bgcolor": 
"#E5ECF6", "caxis": { "gridcolor": "white", "linecolor": "white", "ticks": "" } }, "title": { "x": 0.05 }, "xaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 }, "yaxis": { "automargin": true, "gridcolor": "white", "linecolor": "white", "ticks": "", "title": { "standoff": 15 }, "zerolinecolor": "white", "zerolinewidth": 2 } } }, "title": { "text": "Feature Importance" }, "xaxis": { "title": { "text": "Importance" } }, "yaxis": { "autorange": "reversed", "title": { "text": "Features" } } } }, "text/html": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print(text)\n", "fig.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs = tokenizer.encode(text, return_tensors=\"pt\", add_special_tokens=True).to(\"cuda\")\n", "\n", "target_act = gather_residual_activations(model, layer, inputs)\n", "sae_act = sae.encode(target_act)\n", "\n", "\n", "activated_tokens = sae_act[0:,:,feature]\n", "# max_activation = activated_tokens.max().item()\n", "# activated_tokens /= max_activation\n", "\n", "# activated_tokens = activated_tokens.cpu().detach().numpy()\n", "\n", "# output = []\n", "\n", "# for i, token_id in enumerate(inputs[0, :]):\n", "# token = tokenizer.decode(token_id)\n", "# output.append((token, activated_tokens[0, i]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_feature_iframe(feature):\n", " layer_name = f\"{layer}-gemmascope-res-{width}\"\n", " model_name_neuronpedia = model_name.split(\"/\")[1]\n", "\n", " url = f\"https://www.neuronpedia.org/api/feature/{model_name_neuronpedia}/{layer_name}/{feature}?embed=true\"\n", " html_content = embed_content(url)\n", " return html_content\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "digital-video documentary about stand-up comedians is a great glimpse into a very different world . 
1\n" ] } ], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[62.0354, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],\n", " device='cuda:0')" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Inspect the raw per-token activations of the selected feature (first position included).\n", "inputs = tokenizer.encode(text, return_tensors=\"pt\", add_special_tokens=True).to(\"cuda\")\n", "\n", "target_act = gather_residual_activations(model, layer, inputs)\n", "sae_act = sae.encode(target_act)\n", "\n", "activated_tokens = sae_act[0:, :, feature]\n", "activated_tokens\n", "# max_activation = activated_tokens.max().item()\n", "# activated_tokens /= max_activation\n", "\n", "# activated_tokens = activated_tokens.cpu().detach().numpy()\n", "\n", "# output = []\n", "\n", "# for i, token_id in enumerate(inputs[0, :]):\n", "#     token = tokenizer.decode(token_id)\n", "#     output.append((token, activated_tokens[0, i]))\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_highlighted_text(text, feature):\n", "    # Tokenize with special tokens enabled, so position 0 holds the token the tokenizer prepends.\n", "    inputs = tokenizer.encode(text, return_tensors=\"pt\", add_special_tokens=True).to(\"cuda\")\n", "\n", "    target_act = gather_residual_activations(model, layer, inputs)\n", "    sae_act = sae.encode(target_act)\n", "\n", "    # Skip position 0, then rescale so the strongest activation is 1.\n", "    activated_tokens = sae_act[0:, 1:, feature]\n", "    max_activation = activated_tokens.max().item()\n", "    if max_activation > 0:  # leave all-zero activations untouched instead of dividing by zero\n", "        activated_tokens = activated_tokens / max_activation\n", "\n", "    activated_tokens = activated_tokens.cpu().detach().numpy()\n", "\n", "    # Pair each decoded token with its normalized activation.\n", "    output = []\n", "    for i, token_id in enumerate(inputs[0, 1:]):\n", "        token = tokenizer.decode(token_id)\n", "        output.append((token, activated_tokens[0, i]))\n", "\n", "    return output" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'Positive': 0.9071025712081094, 'Negative': 0.09289742879189056}" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], 
"source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "0135f3b6c691405ea1d522cad0c66097": { 