{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Saving PruneBERT\n", "\n", "\n", "This notebook aims at showcasing how we can leverage standard tools to save (and load) an extremely sparse model fine-pruned with [movement pruning](https://arxiv.org/abs/2005.07683) (or any other unstructured pruning mehtod).\n", "\n", "In this example, we used BERT (base-uncased, but the procedure described here is not specific to BERT and can be applied to a large variety of models.\n", "\n", "We first obtain an extremely sparse model by fine-pruning with movement pruning on SQuAD v1.1. We then used the following combination of standard tools:\n", "- We reduce the precision of the model with Int8 dynamic quantization using [PyTorch implementation](https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html). We only quantized the Fully Connected Layers.\n", "- Sparse quantized matrices are converted into the [Compressed Sparse Row format](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html).\n", "- We use HDF5 with `gzip` compression to store the weights.\n", "\n", "We experiment with a question answering model with only 6% of total remaining weights in the encoder (previously obtained with movement pruning). **We are able to reduce the memory size of the encoder from 340MB (original dense BERT) to 11MB**, which fits on a [91' floppy disk](https://en.wikipedia.org/wiki/Floptical)!\n", "\n", "\n", "\n", "*Note: this notebook is compatible with `torch>=1.5.0` If you are using, `torch==1.4.0`, please refer to [this previous version of the notebook](https://github.com/huggingface/transformers/commit/b11386e158e86e62d4041eabd86d044cd1695737).*" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Includes\n", "\n", "import h5py\n", "import os\n", "import json\n", "from collections import OrderedDict\n", "\n", "from scipy import sparse\n", "import numpy as np\n", "\n", "import torch\n", "from torch import nn\n", "\n", "from transformers import *\n", "\n", "os.chdir(\"../../\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Saving" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Dynamic quantization induces little or no loss of performance while significantly reducing the memory footprint." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Load fine-pruned model and quantize the model\n", "\n", "model = BertForQuestionAnswering.from_pretrained(\"huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad\")\n", "model.to(\"cpu\")\n", "\n", "quantized_model = torch.quantization.quantize_dynamic(\n", " model=model,\n", " qconfig_spec={\n", " nn.Linear: torch.quantization.default_dynamic_qconfig,\n", " },\n", " dtype=torch.qint8,\n", ")\n", "# print(quantized_model)\n", "\n", "qtz_st = quantized_model.state_dict()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Saving the original (encoder + classifier) in the standard torch.save format\n", "\n", "dense_st = {\n", " name: param for name, param in model.state_dict().items() if \"embedding\" not in name and \"pooler\" not in name\n", "}\n", "torch.save(\n", " dense_st,\n", " \"dbg/dense_squad.pt\",\n", ")\n", "dense_mb_size = os.path.getsize(\"dbg/dense_squad.pt\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Decompose quantization for bert.encoder.layer.0.attention.self.query._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.0.attention.self.key._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.0.attention.self.value._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.0.attention.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.0.intermediate.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.0.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.1.attention.self.query._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.1.attention.self.key._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.1.attention.self.value._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.1.attention.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.1.intermediate.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.1.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.2.attention.self.query._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.2.attention.self.key._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.2.attention.self.value._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.2.attention.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.2.intermediate.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.2.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.3.attention.self.query._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.3.attention.self.key._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.3.attention.self.value._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.3.attention.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.3.intermediate.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.3.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.4.attention.self.query._packed_params.weight\n", 
"Decompose quantization for bert.encoder.layer.4.attention.self.key._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.4.attention.self.value._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.4.attention.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.4.intermediate.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.4.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.5.attention.self.query._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.5.attention.self.key._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.5.attention.self.value._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.5.attention.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.5.intermediate.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.5.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.6.attention.self.query._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.6.attention.self.key._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.6.attention.self.value._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.6.attention.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.6.intermediate.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.6.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.7.attention.self.query._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.7.attention.self.key._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.7.attention.self.value._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.7.attention.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.7.intermediate.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.7.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.8.attention.self.query._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.8.attention.self.key._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.8.attention.self.value._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.8.attention.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.8.intermediate.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.8.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.9.attention.self.query._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.9.attention.self.key._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.9.attention.self.value._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.9.attention.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.9.intermediate.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.9.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.10.attention.self.query._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.10.attention.self.key._packed_params.weight\n", "Decompose quantization for 
bert.encoder.layer.10.attention.self.value._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.10.attention.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.10.intermediate.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.10.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.11.attention.self.query._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.11.attention.self.key._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.11.attention.self.value._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.11.attention.output.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.11.intermediate.dense._packed_params.weight\n", "Decompose quantization for bert.encoder.layer.11.output.dense._packed_params.weight\n", "Decompose quantization for bert.pooler.dense._packed_params.weight\n", "Decompose quantization for qa_outputs._packed_params.weight\n" ] } ], "source": [ "# Elementary representation: we decompose the quantized tensors into (scale, zero_point, int_repr).\n", "# See https://pytorch.org/docs/stable/quantization.html\n", "\n", "# We further leverage the fact that int_repr is a sparse matrix to optimize the storage: we decompose int_repr into\n", "# its CSR representation (data, indptr, indices).\n", "\n", "elementary_qtz_st = {}\n", "for name, param in qtz_st.items():\n", "    if \"dtype\" not in name and param.is_quantized:\n", "        print(\"Decompose quantization for\", name)\n", "        # We need to extract the scale, the zero_point and the int_repr of the quantized tensors and modules\n", "        scale = param.q_scale()  # torch.tensor(1,) - float32\n", "        zero_point = param.q_zero_point()  # torch.tensor(1,) - int32\n", "        elementary_qtz_st[f\"{name}.scale\"] = scale\n", "        elementary_qtz_st[f\"{name}.zero_point\"] = zero_point\n", "\n", "        # We assume the int_repr is sparse and compute its CSR representation\n", "        # Only the FCs in the encoder are actually sparse\n", "        int_repr = param.int_repr()  # torch.tensor(nb_rows, nb_columns) - int8\n", "        int_repr_cs = sparse.csr_matrix(int_repr)  # scipy.sparse.csr.csr_matrix\n", "\n", "        elementary_qtz_st[f\"{name}.int_repr.data\"] = int_repr_cs.data  # np.array int8\n", "        elementary_qtz_st[f\"{name}.int_repr.indptr\"] = int_repr_cs.indptr  # np.array int32\n", "        assert max(int_repr_cs.indices) < 65535  # If not, we shall fall back to int32\n", "        elementary_qtz_st[f\"{name}.int_repr.indices\"] = np.uint16(int_repr_cs.indices)  # np.array uint16\n", "        elementary_qtz_st[f\"{name}.int_repr.shape\"] = int_repr_cs.shape  # tuple(int, int)\n", "    else:\n", "        elementary_qtz_st[name] = param" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Create a mapping from torch.dtype to string description (we could also use an int8 instead of a string)\n", "str_2_dtype = {\"qint8\": torch.qint8}\n", "dtype_2_str = {torch.qint8: \"qint8\"}" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Encoder Size (MB) - Sparse & Quantized - `torch.save`: 21.29\n" ] } ], "source": [ "# Saving the pruned (encoder + classifier) in the standard torch.save format\n", "\n", "dense_optimized_st = {\n", "    name: param for name, param in elementary_qtz_st.items() if \"embedding\" not in name and \"pooler\" not in name\n", "}\n", "torch.save(\n", "    
dense_optimized_st,\n", "    \"dbg/dense_squad_optimized.pt\",\n", ")\n", "print(\n", "    \"Encoder Size (MB) - Sparse & Quantized - `torch.save`:\",\n", "    round(os.path.getsize(\"dbg/dense_squad_optimized.pt\") / 1e6, 2),\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Skip bert.embeddings.word_embeddings.weight\n", "Skip bert.embeddings.position_embeddings.weight\n", "Skip bert.embeddings.token_type_embeddings.weight\n", "Skip bert.embeddings.LayerNorm.weight\n", "Skip bert.embeddings.LayerNorm.bias\n", "Skip bert.pooler.dense.scale\n", "Skip bert.pooler.dense.zero_point\n", "Skip bert.pooler.dense._packed_params.weight.scale\n", "Skip bert.pooler.dense._packed_params.weight.zero_point\n", "Skip bert.pooler.dense._packed_params.weight.int_repr.data\n", "Skip bert.pooler.dense._packed_params.weight.int_repr.indptr\n", "Skip bert.pooler.dense._packed_params.weight.int_repr.indices\n", "Skip bert.pooler.dense._packed_params.weight.int_repr.shape\n", "Skip bert.pooler.dense._packed_params.bias\n", "Skip bert.pooler.dense._packed_params.dtype\n", "\n", "Encoder Size (MB) - Dense: 340.26\n", "Encoder Size (MB) - Sparse & Quantized: 11.28\n" ] } ], "source": [ "# Save the decomposed state_dict with an HDF5 file\n", "# Saving only the encoder + QA Head\n", "\n", "with h5py.File(\"dbg/squad_sparse.h5\", \"w\") as hf:\n", "    for name, param in elementary_qtz_st.items():\n", "        if \"embedding\" in name:\n", "            print(f\"Skip {name}\")\n", "            continue\n", "\n", "        if \"pooler\" in name:\n", "            print(f\"Skip {name}\")\n", "            continue\n", "\n", "        if type(param) == torch.Tensor:\n", "            if param.numel() == 1:\n", "                # module scale\n", "                # module zero_point\n", "                hf.attrs[name] = param\n", "                continue\n", "\n", "            if param.requires_grad:\n", "                # LayerNorm\n", "                param = param.detach().numpy()\n", "            hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n", "\n", "        elif type(param) == float or type(param) == int or type(param) == tuple:\n", "            # float - tensor _packed_params.weight.scale\n", "            # int - tensor _packed_params.weight.zero_point\n", "            # tuple - tensor _packed_params.weight.shape\n", "            hf.attrs[name] = param\n", "\n", "        elif type(param) == torch.dtype:\n", "            # dtype - tensor _packed_params.dtype\n", "            hf.attrs[name] = dtype_2_str[param]\n", "\n", "        else:\n", "            hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n", "\n", "\n", "with open(\"dbg/metadata.json\", \"w\") as f:\n", "    f.write(json.dumps(qtz_st._metadata))\n", "\n", "size = os.path.getsize(\"dbg/squad_sparse.h5\") + os.path.getsize(\"dbg/metadata.json\")\n", "print(\"\")\n", "print(\"Encoder Size (MB) - Dense: \", round(dense_mb_size / 1e6, 2))\n", "print(\"Encoder Size (MB) - Sparse & Quantized:\", round(size / 1e6, 2))" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Size (MB): 99.41\n" ] } ], "source": [ "# Save the decomposed state_dict to HDF5 storage\n", "# Save everything in the architecture (embedding + encoder + QA Head)\n", "\n", "with h5py.File(\"dbg/squad_sparse_with_embs.h5\", \"w\") as hf:\n", "    for name, param in elementary_qtz_st.items():\n", "        # if \"embedding\" in name:\n", "        #     print(f\"Skip {name}\")\n", "        #     continue\n", "\n", "        # if \"pooler\" in name:\n", "        #     print(f\"Skip {name}\")\n", "        #     continue\n", "\n", "        if type(param) == torch.Tensor:\n", "            if param.numel() == 1:\n", "                # module scale\n", "                # 
module zero_point\n", "                hf.attrs[name] = param\n", "                continue\n", "\n", "            if param.requires_grad:\n", "                # LayerNorm\n", "                param = param.detach().numpy()\n", "            hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n", "\n", "        elif type(param) == float or type(param) == int or type(param) == tuple:\n", "            # float - tensor _packed_params.weight.scale\n", "            # int - tensor _packed_params.weight.zero_point\n", "            # tuple - tensor _packed_params.weight.shape\n", "            hf.attrs[name] = param\n", "\n", "        elif type(param) == torch.dtype:\n", "            # dtype - tensor _packed_params.dtype\n", "            hf.attrs[name] = dtype_2_str[param]\n", "\n", "        else:\n", "            hf.create_dataset(name, data=param, compression=\"gzip\", compression_opts=9)\n", "\n", "\n", "with open(\"dbg/metadata.json\", \"w\") as f:\n", "    f.write(json.dumps(qtz_st._metadata))\n", "\n", "size = os.path.getsize(\"dbg/squad_sparse_with_embs.h5\") + os.path.getsize(\"dbg/metadata.json\")\n", "print(\"\\nSize (MB):\", round(size / 1e6, 2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Reconstruct the elementary state dict\n", "\n", "reconstructed_elementary_qtz_st = {}\n", "\n", "hf = h5py.File(\"dbg/squad_sparse_with_embs.h5\", \"r\")\n", "\n", "for attr_name, attr_param in hf.attrs.items():\n", "    if \"shape\" in attr_name:\n", "        attr_param = tuple(attr_param)\n", "    elif \".scale\" in attr_name:\n", "        if \"_packed_params\" in attr_name:\n", "            attr_param = float(attr_param)\n", "        else:\n", "            attr_param = torch.tensor(attr_param)\n", "    elif \".zero_point\" in attr_name:\n", "        if \"_packed_params\" in attr_name:\n", "            attr_param = int(attr_param)\n", "        else:\n", "            attr_param = torch.tensor(attr_param)\n", "    elif \".dtype\" in attr_name:\n", "        attr_param = str_2_dtype[attr_param]\n", "    reconstructed_elementary_qtz_st[attr_name] = attr_param\n", "    # print(f\"Unpack {attr_name}\")\n", "\n", "# Get the tensors/arrays\n", "for data_name, data_param in hf.items():\n", "    if \"LayerNorm\" in data_name or \"_packed_params.bias\" in data_name:\n", "        reconstructed_elementary_qtz_st[data_name] = torch.from_numpy(np.array(data_param))\n", "    elif \"embedding\" in data_name:\n", "        reconstructed_elementary_qtz_st[data_name] = torch.from_numpy(np.array(data_param))\n", "    else:  # _packed_params.weight.int_repr.data, _packed_params.weight.int_repr.indices and _packed_params.weight.int_repr.indptr\n", "        data_param = np.array(data_param)\n", "        if \"indices\" in data_name:\n", "            data_param = np.array(data_param, dtype=np.int32)\n", "        reconstructed_elementary_qtz_st[data_name] = data_param\n", "        # print(f\"Unpack {data_name}\")\n", "\n", "\n", "hf.close()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Sanity checks\n", "\n", "for name, param in reconstructed_elementary_qtz_st.items():\n", "    assert name in elementary_qtz_st\n", "for name, param in elementary_qtz_st.items():\n", "    assert name in reconstructed_elementary_qtz_st, name\n", "\n", "for name, param in reconstructed_elementary_qtz_st.items():\n", "    assert type(param) == type(elementary_qtz_st[name]), name\n", "    if type(param) == torch.Tensor:\n", "        assert torch.all(torch.eq(param, elementary_qtz_st[name])), name\n", "    elif type(param) == np.ndarray:\n", "        assert (param == elementary_qtz_st[name]).all(), name\n", "    else:\n", "        assert param == elementary_qtz_st[name], name" ] }, { "cell_type": "code", "execution_count": null,
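"metadata": {}, "outputs": [], "source": [ "# Aside (illustrative only, made-up values): a refresher on what the three saved CSR arrays encode.\n", "# Row i of the matrix holds data[indptr[i]:indptr[i+1]] at columns indices[indptr[i]:indptr[i+1]].\n", "\n", "data = np.array([5, -3], dtype=np.int8)  # non-zero entries, row by row\n", "indices = np.array([2, 0], dtype=np.int32)  # their column indices\n", "indptr = np.array([0, 1, 1, 2], dtype=np.int32)  # row boundaries (here: 3 rows)\n", "\n", "toy_csr = sparse.csr_matrix((data, indices, indptr), shape=(3, 4))\n", "expected = np.array([[0, 0, 5, 0], [0, 0, 0, 0], [-3, 0, 0, 0]], dtype=np.int8)\n", "assert (toy_csr.toarray() == expected).all()" ] }, { "cell_type": "code", "execution_count": 11,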
"metadata": {}, "outputs": [], "source": [ "# Re-assemble the sparse int_repr from the CSR format\n", "\n", "reconstructed_qtz_st = {}\n", "\n", "for name, param in reconstructed_elementary_qtz_st.items():\n", " if \"weight.int_repr.indptr\" in name:\n", " prefix_ = name[:-16]\n", " data = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.data\"]\n", " indptr = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.indptr\"]\n", " indices = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.indices\"]\n", " shape = reconstructed_elementary_qtz_st[f\"{prefix_}.int_repr.shape\"]\n", "\n", " int_repr = sparse.csr_matrix(arg1=(data, indices, indptr), shape=shape)\n", " int_repr = torch.tensor(int_repr.todense())\n", "\n", " scale = reconstructed_elementary_qtz_st[f\"{prefix_}.scale\"]\n", " zero_point = reconstructed_elementary_qtz_st[f\"{prefix_}.zero_point\"]\n", " weight = torch._make_per_tensor_quantized_tensor(int_repr, scale, zero_point)\n", "\n", " reconstructed_qtz_st[f\"{prefix_}\"] = weight\n", " elif (\n", " \"int_repr.data\" in name\n", " or \"int_repr.shape\" in name\n", " or \"int_repr.indices\" in name\n", " or \"weight.scale\" in name\n", " or \"weight.zero_point\" in name\n", " ):\n", " continue\n", " else:\n", " reconstructed_qtz_st[name] = param" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# Sanity checks\n", "\n", "for name, param in reconstructed_qtz_st.items():\n", " assert name in qtz_st\n", "for name, param in qtz_st.items():\n", " assert name in reconstructed_qtz_st, name\n", "\n", "for name, param in reconstructed_qtz_st.items():\n", " assert type(param) == type(qtz_st[name]), name\n", " if type(param) == torch.Tensor:\n", " assert torch.all(torch.eq(param, qtz_st[name])), name\n", " elif type(param) == np.ndarray:\n", " assert (param == qtz_st[name]).all(), name\n", " else:\n", " assert param == qtz_st[name], name" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sanity checks" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load the re-constructed state dict into a model\n", "\n", "dummy_model = BertForQuestionAnswering.from_pretrained(\"bert-base-uncased\")\n", "dummy_model.to(\"cpu\")\n", "\n", "reconstructed_qtz_model = torch.quantization.quantize_dynamic(\n", " model=dummy_model,\n", " qconfig_spec=None,\n", " dtype=torch.qint8,\n", ")\n", "\n", "reconstructed_qtz_st = OrderedDict(reconstructed_qtz_st)\n", "with open(\"dbg/metadata.json\", \"r\") as read_file:\n", " metadata = json.loads(read_file.read())\n", "reconstructed_qtz_st._metadata = metadata\n", "\n", "reconstructed_qtz_model.load_state_dict(reconstructed_qtz_st)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sanity check passed\n" ] } ], "source": [ "# Sanity checks on the infernce\n", "\n", "N = 32\n", "\n", "for _ in range(25):\n", " inputs = torch.randint(low=0, high=30000, size=(N, 128))\n", " mask = torch.ones(size=(N, 128))\n", "\n", " y_reconstructed = reconstructed_qtz_model(input_ids=inputs, attention_mask=mask)[0]\n", " y = quantized_model(input_ids=inputs, attention_mask=mask)[0]\n", "\n", " assert torch.all(torch.eq(y, y_reconstructed))\n", "print(\"Sanity check passed\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], 
"metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 4 }