## Detoxify LLM outputs using TrustyAI Detoxify and HF SFTTrainer 

## Why use Supervised Fine-Tuning?
- Train model on specific downstream task, with curated input-output pairs
- First step in model alignment, teaching a model to emulate "correct" behavior
- Can limit catastrophic forgetting when combined with parameter-efficient methods such as LoRA (used below)

### Steps:
1. Sample inputs or prompts from dataset
2. Labeler demonstrates ideal output behavior
3. Train model on inputs and ideal outputs
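
The three steps above can be sketched in plain Python. This is a toy illustration only: the prompt pool, the `demonstrations` dictionary standing in for a human labeler, and the `build_sft_pairs` helper are all hypothetical, and step 3 is reduced to formatting the text a trainer would consume.

```python
import random

# Hypothetical prompt pool; the notebook draws prompts from a dataset instead.
prompts = [
    "Describe your neighbour",
    "Review this restaurant",
    "Reply to a rude email",
    "Comment on the referee's decision",
]

# Hypothetical stand-in for a human labeler: one ideal output per prompt.
demonstrations = {
    "Describe your neighbour": "My neighbour is quiet and keeps to himself.",
    "Review this restaurant": "The food was cold, but the staff were helpful.",
    "Reply to a rude email": "Thank you for the feedback; I will look into it.",
    "Comment on the referee's decision": "I disagree with the call, but it was close.",
}


def build_sft_pairs(prompts, demonstrations, n, seed=42):
    """Sample n prompts (step 1) and attach the labeled ideal output (step 2)."""
    rng = random.Random(seed)
    sampled = rng.sample(prompts, n)
    return [(p, demonstrations[p]) for p in sampled]


pairs = build_sft_pairs(prompts, demonstrations, n=2)

# Step 3 (reduced to data preparation): the text a trainer would be fed.
train_texts = [f"Prompt: {p}\nContinuation: {c}" for p, c in pairs]
print(*train_texts, sep="\n\n")
```

The labeling step is the expensive part of this loop, which is the challenge raised in the next section.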

### Challenges:
- Manual inspection of data is expensive and not scalable

## How can TrustyAI Detoxify make SFT more accessible?
- Rephrases toxic prompts automatically, guardrailing the LLM during training without manual relabeling

In [1]:
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
    set_seed
    )
from datasets import load_dataset, load_from_disk
from peft import LoraConfig
from trl import SFTTrainer
from trl.trainer import ConstantLengthDataset
import numpy as np
import torch
from trustyai.detoxify import TMaRCo

### Load dataset

In [None]:
dataset_name = "allenai/real-toxicity-prompts"
raw_dataset = load_dataset(dataset_name, split="train").flatten()
print(raw_dataset.column_names)

In [None]:
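# Shuffle once and take the first five rows instead of shuffling the dataset twice
shuffled_head = raw_dataset.shuffle(seed=42).select(range(5))
texts = [prompt + cont for prompt, cont in zip(shuffled_head["prompt.text"], shuffled_head["continuation.text"])]
print(*texts, sep="\n")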

### Load TMaRCo models

In [3]:
tmarco = TMaRCo()
tmarco.load_models(["trustyai/gminus", "trustyai/gplus"])



### Define helper functions to preprocess data

In [4]:
def preprocess_func(sample):
    # Concatenate prompt and continuation text
    sample['text'] = f"Prompt: {sample['prompt.text']}\nContinuation:{sample['continuation.text']}"
    return sample

In [5]:
def tokenize_func(sample):
    return tokenizer(sample["text"], padding="max_length", truncation=True)

In [13]:
block_size = 128
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split by chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result
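
To make the chunking concrete, here is a small worked example with toy token IDs. It repeats the logic of `group_texts` above but takes `block_size` as an argument; `group_texts_demo` and the toy batch are illustrative names and values, not part of the notebook.

```python
def group_texts_demo(examples, block_size):
    # Same logic as group_texts, with block_size passed explicitly.
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    # Drop the remainder that does not fill a whole block.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    # For causal language modeling the labels are a copy of the inputs.
    result["labels"] = result["input_ids"].copy()
    return result


# Three toy "tokenized" samples with 3 + 4 + 3 = 10 tokens in total.
toy_batch = {
    "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8, 9, 10]],
    "attention_mask": [[1, 1, 1], [1, 1, 1, 1], [1, 1, 1]],
}

chunks = group_texts_demo(toy_batch, block_size=4)
print(chunks["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```

With 10 tokens and a block size of 4, two full blocks are kept and the last two tokens (9 and 10) are dropped. Sample boundaries are not preserved, so a block can span two original texts.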


In [6]:
def rephrase_func(sample):
    # Calculate disagreement scores
    scores = tmarco.score([sample['text']])
    # Mask tokens with the highest disagreement scores
    masked_outputs = tmarco.mask([sample['text']], scores=scores, threshold=0.6)
    # Rephrase the text by replacing the masked tokens
    sample['text'] = tmarco.rephrase([sample['text']], masked_outputs=masked_outputs, expert_weights=[-0.5, 4], combine_original=True)[0]
    return sample
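
The role of `threshold=0.6` in the masking step can be illustrated with a toy example. This sketch assumes masking is a simple per-token comparison against the threshold; it does not reproduce how TMaRCo computes its scores or performs masking, and `mask_tokens`, the `<mask>` placeholder, and the score values are all made up for illustration.

```python
def mask_tokens(tokens, scores, threshold=0.6, mask_token="<mask>"):
    """Replace every token whose score exceeds the threshold with a mask."""
    return [mask_token if s > threshold else t for t, s in zip(tokens, scores)]


tokens = ["You", "are", "a", "stupid", "person"]
# Made-up per-token scores; higher means the token looks more toxic.
scores = [0.05, 0.02, 0.01, 0.91, 0.10]

masked = mask_tokens(tokens, scores)
print(" ".join(masked))  # You are a <mask> person
```

Lowering the threshold masks more tokens and so changes more of the text; raising it leaves borderline tokens untouched.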

### Train test split

In [7]:
dataset = raw_dataset.train_test_split(test_size=0.2, shuffle=True, seed=42)
train_data = dataset["train"].select(indices=range(0, 1000))
eval_data = dataset["test"].select(indices=range(0, 400))

### Load model and tokenizer

In [8]:
model_id = "facebook/opt-350m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

### Preprocess data

In [9]:
train_ds = train_data.map(preprocess_func, remove_columns=train_data.column_names)
eval_ds = eval_data.map(preprocess_func, remove_columns=eval_data.column_names)

In [14]:
# Keep only samples whose length is less than or equal to the mean length of the training set
mean_length = np.mean([len(text) for text in train_ds['text']])
train_ds = train_ds.filter(lambda x: len(x['text']) <= mean_length)

tokenized_train_ds = train_ds.map(tokenize_func, batched=True, remove_columns=train_ds.column_names)
tokenized_eval_ds = eval_ds.map(tokenize_func, batched=True, remove_columns=eval_ds.column_names)

print(f"Size of training set: {len(tokenized_train_ds)}\nSize of evaluation set: {len(tokenized_eval_ds)}")
rephrased_train_ds = train_ds.map(rephrase_func)

Size of training set: 557
Size of evaluation set: 400
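
The mean-length filter keeps 557 of the 1,000 selected training samples, which is more than half. A toy example with made-up strings shows why: a few very long texts pull the mean above the typical length.

```python
from statistics import mean

# Made-up texts: three short ones and one long outlier.
texts = ["short", "a bit longer text", "x" * 100, "tiny"]

mean_length = mean(len(t) for t in texts)
kept = [t for t in texts if len(t) <= mean_length]

print(f"mean length: {mean_length}, kept {len(kept)} of {len(texts)}")
```

Here the lengths are 5, 17, 100 and 4, so the mean is 31.5 and three of the four texts pass the filter.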


In [15]:
tokenized_train_ds = tokenized_train_ds.map(group_texts, batched=True)
tokenized_eval_ds = tokenized_eval_ds.map(group_texts, batched=True)

In [12]:
train_ds = load_from_disk("../datasets/train_dataset")
rephrased_train_ds = load_from_disk("../datasets/rephrased_train_dataset")

### Compare raw and rephrased texts

In [None]:
for i, text in enumerate(zip(train_ds["text"][:5], rephrased_train_ds["text"][:5])):
    print("##" * 10 + f"Sample {i}" + "##" * 10)
    print(f"Original text: {text[0]}")
    print(" ")
    print(f"Rephrased text: {text[1]}")
    print(" ")

### Fine-tune model on raw input-output pairs

In [16]:
device_map = {"": torch.cuda.current_device()} if torch.cuda.is_available() else None

In [17]:
model_kwargs = dict(
    torch_dtype="auto",
    use_cache=False, # set to False as we're going to use gradient checkpointing
    device_map=device_map,
)

In [20]:
training_args = TrainingArguments(
    output_dir="../models/opt-350m_CAUSAL_LM",
    evaluation_strategy="epoch",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=5,
    learning_rate=1e-04,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine"
)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
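
To see what `warmup_ratio=0.03` and `lr_scheduler_type="cosine"` imply for the learning rate, the schedule can be approximated in a few lines. This is a simplified re-implementation for illustration, not the library code: it assumes linear warmup followed by a half-cosine decay to zero, and it assumes 2,785 optimizer steps, the figure reported by the SFT run later in this notebook (557 samples, 5 epochs, batch size 1).

```python
import math


def cosine_lr_with_warmup(step, total_steps, base_lr=1e-4, warmup_ratio=0.03):
    """Linear warmup to base_lr, then half-cosine decay to zero."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


total_steps = 2785  # assumed: 557 samples * 5 epochs at batch size 1
warmup_steps = math.ceil(total_steps * 0.03)

for step in (0, warmup_steps // 2, warmup_steps, total_steps // 2, total_steps):
    print(f"step {step:>4}: lr = {cosine_lr_with_warmup(step, total_steps):.2e}")
```

Under these assumptions the learning rate climbs to `1e-04` over the first 84 steps and then decays smoothly towards zero by the final step.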

In [21]:
trainer = Trainer(
    model=AutoModelForCausalLM.from_pretrained(model_id),
    args=training_args,
    train_dataset=tokenized_train_ds,
    eval_dataset=tokenized_eval_ds,
    data_collator=data_collator
)

In [None]:
trainer.train()

In [None]:
# Trainer has no save() method; save_model() writes the model to output_dir
trainer.save_model()

In [None]:
# Drop the trainer reference first so its GPU memory can actually be released
del trainer
torch.cuda.empty_cache()

In [14]:
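# Assumes train_ds, eval_ds and rephrased_train_ds from the preprocessing cells above
train_dataset, rephrased_train_dataset = train_ds, rephrased_train_ds
eval_dataset = eval_ds.select(indices=range(0, 400))
print(f"Size of training set: {len(train_dataset)}\nSize of evaluation set: {len(eval_dataset)}")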

Size of training set: 557
Size of evaluation set: 400


In [19]:
train_dataset.save_to_disk("../datasets/train_dataset")
eval_dataset.save_to_disk("../datasets/eval_dataset")
rephrased_train_dataset.save_to_disk("../datasets/rephrased_train_dataset")

### Model configuration

In [3]:
bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
)

model_kwargs = dict(
    torch_dtype="auto",
    use_cache=False, # set to False as we're going to use gradient checkpointing
    device_map=device_map,
    quantization_config=bnb_config
)

### Model training

In [4]:
from datasets import load_from_disk
rephrased_train_dataset = load_from_disk("../datasets/rephrased_train_dataset")
eval_dataset = load_from_disk("../datasets/eval_dataset/")

In [None]:
peft_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

trainer = SFTTrainer(
    model=model_id,
    model_init_kwargs=model_kwargs,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=rephrased_train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="text",
    peft_config=peft_config,
    max_seq_length=min(tokenizer.model_max_length, 512)
)
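
A quick back-of-the-envelope calculation shows how many parameters this LoRA configuration trains. The sketch assumes that `facebook/opt-350m` has 24 decoder layers with a hidden size of 1024, so that each targeted projection is a 1024 x 1024 matrix; these figures are assumptions, not read from the model, and `lora_param_count` is an illustrative helper.

```python
def lora_param_count(d_in, d_out, r):
    # LoRA adds two low-rank matrices per layer: A (r x d_in) and B (d_out x r).
    return r * (d_in + d_out)


hidden_size = 1024  # assumed hidden size of facebook/opt-350m
num_layers = 24     # assumed number of decoder layers
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
r, lora_alpha = 64, 16

per_module = lora_param_count(hidden_size, hidden_size, r)
total_trainable = per_module * len(target_modules) * num_layers
scaling = lora_alpha / r  # factor applied to the low-rank update

print(f"per module: {per_module:,}")
print(f"total trainable LoRA parameters: {total_trainable:,}")
print(f"scaling factor: {scaling}")
```

Under these assumptions the adapters add roughly 12.6 million trainable parameters, a small fraction of the base model, and the low-rank update is scaled by 0.25.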

In [6]:
trainer.train()

| Epoch | Training Loss | Validation Loss |
|-------|---------------|-----------------|
| 1 | 4.1774 | 3.438231 |
| 2 | 3.6487 | 3.326519 |
| 3 | 3.5382 | 3.323062 |
| 4 | 3.4441 | 3.339012 |
| 5 | 3.4334 | 3.329849 |


TrainOutput(global_step=2785, training_loss=3.6160052588854916, metrics={'train_runtime': 473.0753, 'train_samples_per_second': 5.887, 'train_steps_per_second': 5.887, 'total_flos': 160829875077120.0, 'train_loss': 3.6160052588854916, 'epoch': 5.0})
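
The numbers in this output can be sanity-checked with a little arithmetic. The sketch below converts each validation loss to perplexity, assuming the loss is the mean cross-entropy in nats, and re-derives the step count and throughput from the reported runtime. All values are copied from the output above.

```python
import math

# Validation loss per epoch, copied from the training log above.
val_losses = {1: 3.438231, 2: 3.326519, 3: 3.323062, 4: 3.339012, 5: 3.329849}

# Perplexity is exp(loss) when the loss is mean cross-entropy in nats.
perplexities = {epoch: math.exp(loss) for epoch, loss in val_losses.items()}
best_epoch = min(val_losses, key=val_losses.get)

# 557 training samples, 5 epochs, batch size 1 -> one step per sample.
samples, epochs, runtime_s = 557, 5, 473.0753
steps = samples * epochs
steps_per_second = steps / runtime_s

for epoch, ppl in perplexities.items():
    print(f"epoch {epoch}: validation perplexity = {ppl:.2f}")
print(f"best epoch: {best_epoch}, steps: {steps}, steps/s: {steps_per_second:.3f}")
```

The validation loss is lowest at epoch 3 and barely moves afterwards while the training loss keeps falling, which suggests that additional epochs bring little benefit on this evaluation set.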

### Save model

In [7]:
trainer.save_model("../models/opt-350m_DETOXIFY_CAUSAL_LM")

In [8]:
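# Drop the trainer (which holds the model) before clearing the CUDA cache;
# no separate `model` variable was created in this notebook.
del trainer
torch.cuda.empty_cache()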