{ "cells": [ { "cell_type": "markdown", "id": "e4186a59-0fc3-4b9b-a2b1-f7fbd47540ec", "metadata": {}, "source": [ "## Detoxify LLM outputs using TrustyAI Detoxify and HF SFTTrainer" ] }, { "cell_type": "markdown", "id": "9ae7b6fc-c639-4657-b66a-b318abd730ba", "metadata": {}, "source": [ "## Why use Supervised Fine-Tuning?\n", "- Train a model on a specific downstream task with curated input-output pairs\n", "- First step in model alignment, teaching a model to emulate \"correct\" behavior\n", "- Prevents catastrophic forgetting\n", "\n", "### Steps:\n", "1. Sample inputs or prompts from the dataset\n", "2. A labeler demonstrates the ideal output behavior\n", "3. Train the model on the inputs and ideal outputs\n", "\n", "### Challenges:\n", "- Manual inspection of data is expensive and not scalable\n", "\n", "## How can TrustyAI Detoxify make SFT more accessible?\n", "- Rephrase toxic prompts, guardrailing the LLM during training" ] }, { "cell_type": "code", "execution_count": 1, "id": "8cf1204f-a89e-4b81-8b4f-82c3b2b09994", "metadata": {}, "outputs": [], "source": [ "from transformers import (\n", "    AutoTokenizer,\n", "    AutoModelForCausalLM,\n", "    DataCollatorForLanguageModeling,\n", "    BitsAndBytesConfig,\n", "    Trainer,\n", "    TrainingArguments,\n", "    set_seed\n", ")\n", "from datasets import load_dataset, load_from_disk\n", "from peft import LoraConfig\n", "from trl import SFTTrainer\n", "from trl.trainer import ConstantLengthDataset\n", "import numpy as np\n", "import torch\n", "from trustyai.detoxify import TMaRCo" ] }, { "cell_type": "markdown", "id": "8b398ce2-d86e-4e04-9631-7469447bf4b2", "metadata": { "tags": [] }, "source": [ "### Load dataset" ] }, { "cell_type": "code", "execution_count": null, "id": "c009792f-4bed-422a-9f14-151a09aaaddd", "metadata": {}, "outputs": [], "source": [ "dataset_name = \"allenai/real-toxicity-prompts\"\n", "raw_dataset = load_dataset(dataset_name, split=\"train\").flatten()\n", "print(raw_dataset.column_names)" ] }, { "cell_type": "code", 
"execution_count": null, "id": "fd10e804-b4be-48ff-b38c-65f13f69eddb", "metadata": { "tags": [] }, "outputs": [], "source": [ "texts = [prompt + cont for prompt, cont in zip(raw_dataset.shuffle(seed=42)[\"prompt.text\"][:5], raw_dataset.shuffle(seed=42)[\"continuation.text\"][:5])]\n", "print(*(texts), sep=\"\\n\")" ] }, { "cell_type": "markdown", "id": "4a2e9e31-6224-4cfa-8c5d-33bd2e0e2aa4", "metadata": {}, "source": [ "### Load TMaRCo models" ] }, { "cell_type": "code", "execution_count": 3, "id": "e8abccc6-bce1-42c4-b462-8b8125e34350", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/app-root/lib64/python3.9/site-packages/torch/_utils.py:776: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", " return self.fget.__get__(instance, owner)()\n" ] } ], "source": [ "tmarco = TMaRCo()\n", "tmarco.load_models([\"trustyai/gminus\", \"trustyai/gplus\"])" ] }, { "cell_type": "markdown", "id": "0fbd9ba2-a0a3-43f3-a17f-45a9631b4530", "metadata": {}, "source": [ "### Define helper functions to preprocess data" ] }, { "cell_type": "code", "execution_count": 4, "id": "10404143-b3a5-4a29-9139-2658ba8bc50c", "metadata": {}, "outputs": [], "source": [ "def preprocess_func(sample):\n", " # Concatenate prompt and contination text\n", " sample['text'] = f\"Prompt: {sample['prompt.text']}\\nContinuation:{sample['continuation.text']}\"\n", " return sample" ] }, { "cell_type": "code", "execution_count": 5, "id": "b396e973-399d-4157-86ab-e659e55f938f", "metadata": { "tags": [] }, "outputs": [], "source": [ "def tokenize_func(sample):\n", " return tokenizer(sample[\"text\"], padding=\"max_length\", truncation=True)" ] }, { "cell_type": "code", "execution_count": 13, "id": 
"075ff74b-b959-47df-aa20-795d3f1d641d", "metadata": { "tags": [] }, "outputs": [], "source": [ "block_size = 128\n", "def group_texts(examples):\n", " # Concatenate all texts.\n", " concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n", " total_length = len(concatenated_examples[list(examples.keys())[0]])\n", " # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\n", " # customize this part to your needs.\n", " if total_length >= block_size:\n", " total_length = (total_length // block_size) * block_size\n", " # Split by chunks of block_size.\n", " result = {\n", " k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n", " for k, t in concatenated_examples.items()\n", " }\n", " result[\"labels\"] = result[\"input_ids\"].copy()\n", " return result\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "f2ce2a35-3480-4dc0-8b94-91591059cd44", "metadata": { "tags": [] }, "outputs": [], "source": [ "def rephrase_func(sample):\n", " # Calculate disagreement scores\n", " scores = tmarco.score([sample['text']])\n", " # Mask tokens with the highest disagremeent scores\n", " masked_outputs = tmarco.mask([sample['text']], scores=scores, threshold=0.6)\n", " # Rephrased text by replacing masked tokens\n", " sample['text'] = tmarco.rephrase([sample['text']], masked_outputs=masked_outputs, expert_weights=[-0.5, 4],combine_original=True)[0]\n", " return sample" ] }, { "cell_type": "markdown", "id": "b9a6605a-c291-4c64-bc6c-2dbc7fb54b64", "metadata": {}, "source": [ "### Train test split" ] }, { "cell_type": "code", "execution_count": 7, "id": "e1c16957-e212-4060-af88-36df9be4d620", "metadata": {}, "outputs": [], "source": [ "dataset = raw_dataset.train_test_split(test_size=0.2, shuffle=True, seed=42)\n", "train_data = dataset[\"train\"].select(indices=range(0, 1000))\n", "eval_data = dataset[\"test\"].select(indices=range(0, 400))" ] }, { "cell_type": "markdown", "id": 
"ce797bb3-c050-49aa-af72-4fa61e128f89", "metadata": {}, "source": [ "### Load model and tokenizer" ] }, { "cell_type": "code", "execution_count": 8, "id": "b04f3a66-7b28-42a9-a241-6412d7df481a", "metadata": {}, "outputs": [], "source": [ "model_id = \"facebook/opt-350m\"\n", "tokenizer = AutoTokenizer.from_pretrained(model_id)\n", "tokenizer.pad_token = tokenizer.eos_token\n", "tokenizer.padding_side = \"right\"" ] }, { "cell_type": "markdown", "id": "58416f0c-e630-433d-bb38-d9676fe383d9", "metadata": { "tags": [] }, "source": [ "### Preprocess data" ] }, { "cell_type": "code", "execution_count": 9, "id": "e12bbc75-2dfd-4135-93e4-a7a16611ab04", "metadata": { "tags": [] }, "outputs": [], "source": [ "train_ds = train_data.map(preprocess_func, remove_columns=train_data.column_names)\n", "eval_ds = eval_data.map(preprocess_func, remove_columns=eval_data.column_names)" ] }, { "cell_type": "code", "execution_count": 14, "id": "38b616f4-ffe5-4c7b-aa78-566051d18a20", "metadata": { "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "dee50cb21205459ca1c080b3fea89f15", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/557 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "08cc8e1b282a47489d57489ea35d551d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/400 [00:00, ? 
examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Size of training set: 557\n", "Size of evaluation set: 400\n" ] } ], "source": [ "# Keep only the samples whose length is less than or equal to the mean length of the training set\n", "mean_length = np.mean([len(text) for text in train_ds['text']])\n", "train_ds = train_ds.filter(lambda x: len(x['text']) <= mean_length)\n", "\n", "tokenized_train_ds = train_ds.map(tokenize_func, batched=True, remove_columns=train_ds.column_names)\n", "tokenized_eval_ds = eval_ds.map(tokenize_func, batched=True, remove_columns=eval_ds.column_names)\n", "\n", "print(f\"Size of training set: {len(tokenized_train_ds)}\\nSize of evaluation set: {len(tokenized_eval_ds)}\")\n", "# Rephrase the filtered training texts with TMaRCo\n", "rephrased_train_ds = train_ds.map(rephrase_func)" ] }, { "cell_type": "code", "execution_count": 15, "id": "aaec5f28-d972-4544-8274-f350ca91706c", "metadata": { "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "475c737a3a83412d9cb2b5e7d498886b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/557 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2e464d32ca3842599ed53eee9a8fa9bf", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/400 [00:00, ? 
examples/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "tokenized_train_ds = tokenized_train_ds.map(group_texts, batched=True)\n", "tokenized_eval_ds = tokenized_eval_ds.map(group_texts, batched=True)" ] }, { "cell_type": "code", "execution_count": 12, "id": "1884b31f-298e-42c1-8798-cda41f6ca33b", "metadata": { "tags": [] }, "outputs": [], "source": [ "train_ds = load_from_disk(\"../datasets/train_dataset\")\n", "rephrased_train_ds = load_from_disk(\"../datasets/rephrased_train_dataset\")" ] }, { "cell_type": "markdown", "id": "63db1224-a2bd-4bc1-b01b-3bae694b93a1", "metadata": { "tags": [] }, "source": [ "### Compare raw and rephrased texts" ] }, { "cell_type": "code", "execution_count": null, "id": "24d7ffb8-934b-4b90-990e-1c7da125d8df", "metadata": { "tags": [] }, "outputs": [], "source": [ "for i, text in enumerate(zip(train_ds[\"text\"][:5], rephrased_train_ds[\"text\"][:5])):\n", " print(\"##\" * 10 + f\"Sample {i}\" + \"##\" * 10)\n", " print(f\"Original text: {text[0]}\")\n", " print(\" \")\n", " print(f\"Rephrased text: {text[1]}\")\n", " print(\" \")" ] }, { "cell_type": "markdown", "id": "9fe7bcdc-401a-467e-b88b-d0c9d03a4fc0", "metadata": {}, "source": [ "### Fine-tune model on raw input-output pairs" ] }, { "cell_type": "code", "execution_count": 16, "id": "0eefe2bc-8b18-4d2d-8b4f-5587e6d8f741", "metadata": { "tags": [] }, "outputs": [], "source": [ "device_map = {\"\": torch.cuda.current_device()} if torch.cuda.is_available() else None" ] }, { "cell_type": "code", "execution_count": 17, "id": "d27348ed-5798-45e7-9622-19d6ac56e6fb", "metadata": { "tags": [] }, "outputs": [], "source": [ "model_kwargs = dict(\n", " torch_dtype=\"auto\",\n", " use_cache=False, # set to False as we're going to use gradient checkpointing\n", " device_map=device_map,\n", ")" ] }, { "cell_type": "code", "execution_count": 20, "id": "ea4eae17-3dac-456a-b559-182770df35a8", "metadata": { "tags": [] }, "outputs": [], "source": [ "training_args = 
TrainingArguments(\n", "    output_dir=\"../models/opt-350m_CASUAL_LM\",\n", "    evaluation_strategy=\"epoch\",\n", "    per_device_train_batch_size=1,\n", "    per_device_eval_batch_size=1,\n", "    num_train_epochs=5,\n", "    learning_rate=1e-04,\n", "    max_grad_norm=0.3,\n", "    warmup_ratio=0.03,\n", "    lr_scheduler_type=\"cosine\"\n", ")\n", "data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)" ] }, { "cell_type": "code", "execution_count": 21, "id": "ee82e6bd-ec84-4ed7-ad87-a09bdb576773", "metadata": { "tags": [] }, "outputs": [], "source": [ "trainer = Trainer(\n", "    model=AutoModelForCausalLM.from_pretrained(model_id),\n", "    args=training_args,\n", "    train_dataset=tokenized_train_ds,\n", "    eval_dataset=tokenized_eval_ds,\n", "    data_collator=data_collator\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "6048dfd5-979e-4e02-a25e-f5f6873c9d43", "metadata": { "tags": [] }, "outputs": [], "source": [ "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "id": "f33eee3b-8592-468c-a65c-5266ae75e83e", "metadata": {}, "outputs": [], "source": [ "# Save the fine-tuned model to the output_dir set in training_args\n", "trainer.save_model()" ] }, { "cell_type": "code", "execution_count": null, "id": "9f8ccbbc-8325-4977-b27b-1dfccf55a22c", "metadata": {}, "outputs": [], "source": [ "# Drop the trainer reference first so its cached GPU memory can be released\n", "del trainer\n", "torch.cuda.empty_cache()" ] }, { "cell_type": "code", "execution_count": 14, "id": "bb60e0de-3238-4e50-88f1-0b546fdc6311", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Size of training set: 557\n", "Size of evaluation set: 400\n" ] } ], "source": [ "eval_dataset = eval_dataset.select(indices=range(0, 400))\n", "print(f\"Size of training set: {len(train_dataset)}\\nSize of evaluation set: {len(eval_dataset)}\")" ] }, { "cell_type": "code", "execution_count": 19, "id": "b0040027-c858-425c-b641-d3fe86317566", "metadata": { "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "487c6e4efada4e60bbfa41a591d38430", 
"version_major": 2, "version_minor": 0 }, "text/plain": [ "Saving the dataset (0/1 shards): 0%| | 0/557 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ecaf7bab6db94f61895506a7b6a220bd", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Saving the dataset (0/1 shards): 0%| | 0/400 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "882362f7fb4c4df88305a488b093ab34", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Saving the dataset (0/1 shards): 0%| | 0/557 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "train_dataset.save_to_disk(\"../datasets/train_dataset\")\n", "eval_dataset.save_to_disk(\"../datasets/eval_dataset\")\n", "rephrased_train_dataset.save_to_disk(\"../datasets/rephrased_train_dataset\")" ] }, { "cell_type": "markdown", "id": "79f0c8c8-2266-4166-bec9-50fc092e0b3c", "metadata": {}, "source": [ "### Model configuration" ] }, { "cell_type": "code", "execution_count": 3, "id": "2b5ae7be-434d-4c80-90b8-9914a2e26c16", "metadata": {}, "outputs": [], "source": [ "bnb_config = BitsAndBytesConfig(\n", " load_in_4bit=True,\n", " bnb_4bit_quant_type=\"nf4\",\n", " bnb_4bit_compute_dtype=torch.bfloat16\n", ")\n", "\n", "model_kwargs = dict(\n", " torch_dtype=\"auto\",\n", " use_cache=False, # set to False as we're going to use gradient checkpointing\n", " device_map=device_map,\n", " quantization_config=bnb_config\n", ")" ] }, { "cell_type": "markdown", "id": "ae6bf300-81b1-46f3-9ed3-d49f77c3c110", "metadata": {}, "source": [ "### Model training" ] }, { "cell_type": "code", "execution_count": 4, "id": "a5544e6d-48c3-41bd-866e-8265dcbee52f", "metadata": { "tags": [] }, "outputs": [], "source": [ "from datasets import load_from_disk\n", "rephrased_train_dataset = 
load_from_disk(\"../datasets/rephrased_train_dataset\")\n", "eval_dataset = load_from_disk(\"../datasets/eval_dataset/\")" ] }, { "cell_type": "code", "execution_count": null, "id": "95be1d2d-aa38-454f-b002-4c53d4b45e21", "metadata": {}, "outputs": [], "source": [ "peft_config = LoraConfig(\n", " r=64,\n", " lora_alpha=16,\n", " lora_dropout=0.1,\n", " bias=\"none\",\n", " task_type=\"CAUSAL_LM\",\n", " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n", ")\n", "\n", "trainer = SFTTrainer(\n", " model=model_id,\n", " model_init_kwargs=model_kwargs,\n", " tokenizer=tokenizer,\n", " args=training_args,\n", " train_dataset=rephrased_train_dataset,\n", " eval_dataset=eval_dataset,\n", " dataset_text_field=\"text\",\n", " peft_config=peft_config,\n", " max_seq_length=min(tokenizer.model_max_length, 512)\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "id": "f22feb53-4d2a-41c7-98c7-43288b17d426", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
Epoch | Training Loss | Validation Loss\n",
"---|---|---\n",
"1 | 4.177400 | 3.438231\n",
"2 | 3.648700 | 3.326519\n",
"3 | 3.538200 | 3.323062\n",
"4 | 3.444100 | 3.339012\n",
"5 | 3.433400 | 3.329849\n"
],
"text/plain": [
"