{ "cells": [ { "cell_type": "markdown", "id": "e4186a59-0fc3-4b9b-a2b1-f7fbd47540ec", "metadata": {}, "source": [ "## Detoxify LLM outputs using TrustyAI Detoxify and HF SFTTrainer " ] }, { "cell_type": "markdown", "id": "9ae7b6fc-c639-4657-b66a-b318abd730ba", "metadata": {}, "source": [ "## Why use Supervised Fine-Tuning ?\n", "- Train model on specific downstream task, with curated input-output pairs\n", "- First step in model alignment, teaching a model to emulate \"correct\" behavior\n", "- Prevents catastrophic forgetting\n", "\n", "### Steps:\n", "1. Sample inputs or prompts from dataset\n", "2. Labeler demonstrates ideal ouput behavior\n", "3. Train model on inputs and ideal outputs\n", "\n", "### Challenges:\n", "- Manual inspection of data is expensive and not scalable\n", "\n", "## How can TrustyAI Detoxify make SFT more accessible ?\n", "- Rephrase toxic prompts, guardrailing LLM during training" ] }, { "cell_type": "code", "execution_count": 1, "id": "8cf1204f-a89e-4b81-8b4f-82c3b2b09994", "metadata": {}, "outputs": [], "source": [ "from transformers import (\n", " AutoTokenizer,\n", " AutoModelForCausalLM,\n", " DataCollatorForLanguageModeling,\n", " BitsAndBytesConfig,\n", " Trainer,\n", " TrainingArguments,\n", " set_seed\n", " )\n", "from datasets import load_dataset, load_from_disk\n", "from peft import LoraConfig\n", "from trl import SFTTrainer\n", "from trl.trainer import ConstantLengthDataset\n", "import numpy as np\n", "import torch\n", "from trustyai.detoxify import TMaRCo" ] }, { "cell_type": "markdown", "id": "8b398ce2-d86e-4e04-9631-7469447bf4b2", "metadata": { "tags": [] }, "source": [ "### Load dataset" ] }, { "cell_type": "code", "execution_count": null, "id": "c009792f-4bed-422a-9f14-151a09aaaddd", "metadata": {}, "outputs": [], "source": [ "dataset_name = \"allenai/real-toxicity-prompts\"\n", "raw_dataset = load_dataset(dataset_name, split=\"train\").flatten()\n", "print(raw_dataset.column_names)" ] }, { "cell_type": "code", "execution_count": null, "id": "fd10e804-b4be-48ff-b38c-65f13f69eddb", "metadata": { "tags": [] }, "outputs": [], "source": [ "texts = [prompt + cont for prompt, cont in zip(raw_dataset.shuffle(seed=42)[\"prompt.text\"][:5], raw_dataset.shuffle(seed=42)[\"continuation.text\"][:5])]\n", "print(*(texts), sep=\"\\n\")" ] }, { "cell_type": "markdown", "id": "4a2e9e31-6224-4cfa-8c5d-33bd2e0e2aa4", "metadata": {}, "source": [ "### Load TMaRCo models" ] }, { "cell_type": "code", "execution_count": 3, "id": "e8abccc6-bce1-42c4-b462-8b8125e34350", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/app-root/lib64/python3.9/site-packages/torch/_utils.py:776: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. 
{ "cell_type": "markdown", "id": "0fbd9ba2-a0a3-43f3-a17f-45a9631b4530", "metadata": {}, "source": [ "### Define helper functions to preprocess data" ] }, { "cell_type": "code", "execution_count": 4, "id": "10404143-b3a5-4a29-9139-2658ba8bc50c", "metadata": {}, "outputs": [], "source": [ "def preprocess_func(sample):\n", "    # Concatenate prompt and continuation text\n", "    sample['text'] = f\"Prompt: {sample['prompt.text']}\\nContinuation:{sample['continuation.text']}\"\n", "    return sample" ] }, { "cell_type": "code", "execution_count": 5, "id": "b396e973-399d-4157-86ab-e659e55f938f", "metadata": { "tags": [] }, "outputs": [], "source": [ "def tokenize_func(sample):\n", "    return tokenizer(sample[\"text\"], padding=\"max_length\", truncation=True)" ] }, { "cell_type": "code", "execution_count": 13, "id": "075ff74b-b959-47df-aa20-795d3f1d641d", "metadata": { "tags": [] }, "outputs": [], "source": [ "block_size = 128\n", "def group_texts(examples):\n", "    # Concatenate all texts.\n", "    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n", "    total_length = len(concatenated_examples[list(examples.keys())[0]])\n", "    # We drop the small remainder; we could instead pad if the model supported it.\n", "    # Customize this part to your needs.\n", "    if total_length >= block_size:\n", "        total_length = (total_length // block_size) * block_size\n", "    # Split by chunks of block_size.\n", "    result = {\n", "        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n", "        for k, t in concatenated_examples.items()\n", "    }\n", "    result[\"labels\"] = result[\"input_ids\"].copy()\n", "    return result\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "f2ce2a35-3480-4dc0-8b94-91591059cd44", "metadata": { "tags": [] }, "outputs": [], "source": [ "def rephrase_func(sample):\n", "    # Calculate disagreement scores\n", "    scores = tmarco.score([sample['text']])\n", "    # Mask tokens with the highest disagreement scores\n", "    masked_outputs = tmarco.mask([sample['text']], scores=scores, threshold=0.6)\n", "    # Rephrase the text by replacing the masked tokens\n", "    sample['text'] = tmarco.rephrase([sample['text']], masked_outputs=masked_outputs, expert_weights=[-0.5, 4], combine_original=True)[0]\n", "    return sample" ] }, { "cell_type": "markdown", "id": "b9a6605a-c291-4c64-bc6c-2dbc7fb54b64", "metadata": {}, "source": [ "### Train test split" ] }, { "cell_type": "code", "execution_count": 7, "id": "e1c16957-e212-4060-af88-36df9be4d620", "metadata": {}, "outputs": [], "source": [ "dataset = raw_dataset.train_test_split(test_size=0.2, shuffle=True, seed=42)\n", "train_data = dataset[\"train\"].select(indices=range(0, 1000))\n", "eval_data = dataset[\"test\"].select(indices=range(0, 400))" ] }, { "cell_type": "markdown", "id": "ce797bb3-c050-49aa-af72-4fa61e128f89", "metadata": {}, "source": [ "### Load model and tokenizer" ] }, { "cell_type": "code", "execution_count": 8, "id": "b04f3a66-7b28-42a9-a241-6412d7df481a", "metadata": {}, "outputs": [], "source": [ "model_id = \"facebook/opt-350m\"\n", "tokenizer = AutoTokenizer.from_pretrained(model_id)\n", "tokenizer.pad_token = tokenizer.eos_token\n", "tokenizer.padding_side = \"right\"" ] },
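{ "cell_type": "markdown", "id": "load-base-model-note", "metadata": {}, "source": [ "The imports above pull in `AutoModelForCausalLM` and `BitsAndBytesConfig`, so the base model can be loaded in quantized form to fit comfortably on a single small GPU. The next cell is a minimal sketch: 8-bit loading and `device_map=\"auto\"` are assumptions, not requirements; a plain full-precision `AutoModelForCausalLM.from_pretrained(model_id)` also works." ] },
{ "cell_type": "code", "execution_count": null, "id": "load-base-model-code", "metadata": {}, "outputs": [], "source": [ "# Fix the random seed for reproducibility\n", "set_seed(42)\n", "\n", "# Sketch (assumption): load the base model in 8-bit to reduce GPU memory usage\n", "quantization_config = BitsAndBytesConfig(load_in_8bit=True)\n", "model = AutoModelForCausalLM.from_pretrained(\n", "    model_id,\n", "    quantization_config=quantization_config,\n", "    device_map=\"auto\"\n", ")" ] },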
"metadata": { "tags": [] }, "source": [ "### Preprocess data" ] }, { "cell_type": "code", "execution_count": 9, "id": "e12bbc75-2dfd-4135-93e4-a7a16611ab04", "metadata": { "tags": [] }, "outputs": [], "source": [ "train_ds = train_data.map(preprocess_func, remove_columns=train_data.column_names)\n", "eval_ds = eval_data.map(preprocess_func, remove_columns=eval_data.column_names)" ] }, { "cell_type": "code", "execution_count": 14, "id": "38b616f4-ffe5-4c7b-aa78-566051d18a20", "metadata": { "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "dee50cb21205459ca1c080b3fea89f15", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/557 [00:00\n", " \n", " \n", " [2785/2785 07:52, Epoch 5/5]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
{ "cell_type": "code", "execution_count": 6, "id": "trainer-train", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "\n", "    <div>\n", "      [2785/2785 07:52, Epoch 5/5]\n", "    </div>\n", "    <table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: left;\">\n", "      <th>Epoch</th>\n", "      <th>Training Loss</th>\n", "      <th>Validation Loss</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr>\n", "      <td>1</td>\n", "      <td>4.177400</td>\n", "      <td>3.438231</td>\n", "    </tr>\n",
"    <tr>\n", "      <td>2</td>\n", "      <td>3.648700</td>\n", "      <td>3.326519</td>\n", "    </tr>\n",
"    <tr>\n", "      <td>3</td>\n", "      <td>3.538200</td>\n", "      <td>3.323062</td>\n", "    </tr>\n",
"    <tr>\n", "      <td>4</td>\n", "      <td>3.444100</td>\n", "      <td>3.339012</td>\n", "    </tr>\n",
"    <tr>\n", "      <td>5</td>\n", "      <td>3.433400</td>\n", "      <td>3.329849</td>\n", "    </tr>\n",
"  </tbody>\n", "</table><p>
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=2785, training_loss=3.6160052588854916, metrics={'train_runtime': 473.0753, 'train_samples_per_second': 5.887, 'train_steps_per_second': 5.887, 'total_flos': 160829875077120.0, 'train_loss': 3.6160052588854916, 'epoch': 5.0})" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.train()" ] }, { "cell_type": "markdown", "id": "d8996594-86d4-4d20-b23b-5928ed3c27b9", "metadata": {}, "source": [ "### Save model" ] }, { "cell_type": "code", "execution_count": 7, "id": "fac9a7f6-1bbf-4992-81ce-9095d07f524c", "metadata": {}, "outputs": [], "source": [ "trainer.save_model(\"../models/opt-350m_DETOXIFY_CAUSAL_LM\")" ] }, { "cell_type": "code", "execution_count": 8, "id": "0e0c04b2-6986-40b5-82c8-69121eb07768", "metadata": { "tags": [] }, "outputs": [], "source": [ "torch.cuda.empty_cache()\n", "del trainer\n", "del model" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.9", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 5 }