{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "ODbwgRIAxdCh"
},
"source": [
"# Gemma Scope Tutorial\n",
"\n",
"This is a barebones tutorial on how to use [Gemma Scope](https://huggingface.co/google/gemma-scope), Google DeepMind's suite of Sparse Autoencoders (SAEs) on every layer and sublayer of Gemma 2 2B and 9B. Sparse Autoencoders are an interpretability tool that act like a \"microscope\" on language model activations. They let us zoom in on dense, compressed activations, and expand them to a larger but sparser and seemingly more interpretable form, which can be a very useful tool when doing interpretability research!\n",
"\n",
"**Learn more:**\n",
"* If you want to learn about Gemma Scope without writing any code, check out [this interactive demo](https://neuronpedia.org/gemma-scope) courtesy of [Neuronpedia](https://neuronpedia.org).\n",
"* For an overview of Gemma Scope check out [the blog post](https://deepmind.google/discover/blog/gemma-scope-helping-the-safety-community-shed-light-on-the-inner-workings-of-language-models).\n",
"* See [the technical report](https://storage.googleapis.com/gemma-scope/gemma-scope-report.pdf) for the technical details\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rB2BasaDOm_t"
},
"source": [
"\n",
"For illustrative purposes, we begin with a lightweight tutorial that uses as few libraries as possible to outline how Gemma Scope works, and what Sparse Autoencoders are doing. This is deliberately a fairly minimalist tutorial, designed to make clear what is actually going on, but does not model research best practices.\n",
"\n",
"For any serious research with Gemma Scope, **we recommend using the [SAELens](https://jbloomaus.github.io/SAELens/) and [TransformerLens](https://transformerlensorg.github.io/TransformerLens/) libraries**, see [this tutorial](https://colab.research.google.com/github/jbloomAus/SAELens/blob/main/tutorials/tutorial_2_0.ipynb) on how to use [SAELens](https://jbloomaus.github.io/SAELens/) in practice.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RvDc2KCO9DYS"
},
"source": [
"## Loading the Model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fB9bB3mJ8R1H"
},
"source": [
"First, let's load the model:\n",
"\n",
"For simplicity we do this straight from [HuggingFace transformers](https://huggingface.co/docs/transformers/en/index), rather than using an interpretability focused library like [TransformerLens](https://transformerlensorg.github.io/TransformerLens/) or [nnsight](https://nnsight.net/)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"id": "nOBcV4om7mrT"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ffea7bc53a17446cacb9a35ae3adc0a1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/4 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some parameters are on the meta device device because they were offloaded to the cpu.\n"
]
}
],
"source": [
"from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer\n",
"from huggingface_hub import hf_hub_download, notebook_login\n",
"import numpy as np\n",
"import torch\n",
"\n",
"torch.set_grad_enabled(False) # avoid blowing up mem\n",
"\n",
"params = {\n",
" \"model_name\" : \"google/gemma-2-9b-it\",\n",
" \"width\" : \"16k\",\n",
" \"layer\" : 31,\n",
" \"l0\" : 76,\n",
" \"sae_repo_id\": \"google/gemma-scope-9b-it-res\",\n",
" \"filename\" : \"layer_31/width_16k/average_l0_76/params.npz\"\n",
"}\n",
"\n",
"# params = {\n",
"# \"model_name\" : \"google/gemma-2-2b\",\n",
"# \"width\" : \"16k\",\n",
"# \"layer\" : 23,\n",
"# \"l0\" : 74,\n",
"# \"sae_repo_id\": \"google/gemma-scope-2b-pt-res\",\n",
"# \"filename\" : \"layer_23/width_16k/average_l0_74/params.npz\"\n",
"# }\n",
"\n",
"model_name = params[\"model_name\"]\n",
"width = params[\"width\"]\n",
"layer = params[\"layer\"]\n",
"l0 = params[\"l0\"]\n",
"sae_repo_id = params[\"sae_repo_id\"]\n",
"filename = params[\"filename\"]\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" model_name,\n",
" device_map='auto',\n",
")\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"\n",
"filename = f\"layer_{layer}/width_{width}/average_l0_{l0}/params.npz\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8Q6sQSaAN7T7"
},
"source": [
"We load Gemma 2 2B, the smallest model that Gemma Scope works for. We load the base model, not the chat model, since that's where our SAEs are trained. Though the SAEs seem to transfer OK to these models. First, you'll need to authenticate with huggingface in order to download the model weights"
]
},
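{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Log in to Hugging Face so the gated Gemma 2 weights can be downloaded.\n",
"# `notebook_login` was imported from huggingface_hub above.\n",
"notebook_login()"
]
},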
{
"cell_type": "markdown",
"metadata": {
"id": "MZkgvglU9GdW"
},
"source": [
"Now we've loaded the model, let's try running it! We give it the prompt \"Would you be able to travel through time using a wormhole?\" and print the generated output"
]
},
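{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal generation sketch using the `model` and `tokenizer` loaded above;\n",
"# max_new_tokens=50 is an arbitrary choice for illustration.\n",
"prompt = \"Would you be able to travel through time using a wormhole?\"\n",
"\n",
"inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
"outputs = model.generate(**inputs, max_new_tokens=50)\n",
"print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
]
},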
{
"cell_type": "markdown",
"metadata": {
"id": "qZECwzKi9dGv"
},
"source": [
"## Loading a Sparse Autoencoder"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wQSE4K5KVmGY"
},
"source": [
"OK, so we have got Gemma 2 loaded and can sample from it to get sensible stuff. Now, let's load one of our SAEs.\n",
"\n",
"GemmaScope actually contains over four hundred SAEs, but for now we'll just load one on the residual stream at the end of layer 20 (of 26, note that layers start at 0 so this is the 21st layer. This is a fairly late layer, so the model should have time to find more abstract concepts!).\n",
"\n",
"See [the final section](https://colab.research.google.com/drive/17dQFYUYnuKnP6OwQPH9v_GSYUW5aj-Rp?authuser=2#scrollTo=E7zjkVseLSPp) for more information on how to load all the other SAEs in Gemma Scope\n",
"\n",
"What is the residual stream?
\n",
"\n",
"Transformers have skip connections, which means that the output of each block is the output of each sublayer *plus* the input to the block. This means that each sublayer (attention or MLP) actually only has a fairly small effect on the output of the block, since most of it comes from all the earlier layers. We call the output of a block (including skip connections) the **residual stream**.\n",
"\n",
"Everything communicated from earlier layers to later layers must go via the residual stream, so it acts as a \"bottleneck\" in the transformer, essentially capturing everything the model has \"thought\" so far. This means it is often a natural thing to study, since it will contain everything important going on in the model.\n",
"
\n", " | 0 | \n", "1 | \n", "2 | \n", "3 | \n", "4 | \n", "5 | \n", "6 | \n", "7 | \n", "8 | \n", "9 | \n", "... | \n", "16375 | \n", "16376 | \n", "16377 | \n", "16378 | \n", "16379 | \n", "16380 | \n", "16381 | \n", "16382 | \n", "16383 | \n", "label | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "... | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1 | \n", "
1 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "... | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0 | \n", "
2 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "... | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0 | \n", "
3 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "... | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0 | \n", "
4 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "... | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
8525 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "... | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1 | \n", "
8526 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "... | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1 | \n", "
8527 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "... | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1 | \n", "
8528 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "... | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0 | \n", "
8529 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "... | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "0 | \n", "
8530 rows × 16385 columns
\n", "