Train custom AI models with the trainer API and adapt them to 🤗


Introduction

Training LLMs requires massive compute power and multiple GPUs, and sharing and version-controlling these AI models tends to be tricky.

But it is now easier than ever to push and load your weights directly to and from a free, open-source, and reliable machine learning platform by leveraging the ModelHubMixin classes.

As for why you should adapt your model to the Trainer API: it allows your model to be trained with almost no training script at all, and it even comes with distributed training support by default.

Setup and data processing

First, let's install our dependencies:

pip install -q datasets evaluate accelerate transformers torchvision "huggingface_hub>=0.22"

Now let's log in with our Hugging Face write token:

from huggingface_hub import notebook_login
notebook_login()

Now let's start coding:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from datasets import load_dataset, Image
from transformers import TrainingArguments, Trainer
from huggingface_hub import PyTorchModelHubMixin
import evaluate
import numpy as np

The original MNIST dataset can be found on the Hugging Face Hub:

dataset = load_dataset("mnist")
# convert the "image" column to a pillow image
dataset = dataset.cast_column("image", Image())
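
A quick look at the dataset shows the column names we'll be working with (the row counts below are what I'd expect for MNIST):

print(dataset)
# expected output (roughly):
# DatasetDict({
#     train: Dataset({features: ['image', 'label'], num_rows: 60000})
#     test: Dataset({features: ['image', 'label'], num_rows: 10000})
# })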

IMPORTANT

Your target column needs to be named labels; otherwise, you need to pass label_names in the TrainingArguments.
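
As a hypothetical sketch (my_model is just a placeholder here), that would look like:

# hypothetical: keep the column named "label" and tell the Trainer which input key holds the targets
args = TrainingArguments(output_dir="my_model", label_names=["label"])

(Your model's forward method would also need a matching label parameter name.) In this post, we simply rename the column instead: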

dataset = dataset.rename_column("label","labels")
# convert to PyTorch tensors
transform = transforms.Compose([transforms.ToTensor()])
def to_pt(batch):
    batch["image"] = [transform(image.convert("RGB")) for image in batch["image"]]
    return batch
train = dataset["train"].with_transform(to_pt)
test = dataset["test"].with_transform(to_pt)

Unlike dataset.map, dataset.with_transform only applies the transformation when the data is accessed; you can read more about it in the docs.
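
To see the lazy transform in action, you can access a single sample; the shapes below assume the RGB conversion above:

sample = train[0]
print(type(sample["image"]))   # <class 'torch.Tensor'>
print(sample["image"].shape)   # torch.Size([3, 28, 28])
print(sample["labels"])        # e.g. 5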

Unfortunately, with_transform is not compatible with streaming datasets; if this interests you, you can contribute to this at their GitHub repo.

Model

Adapt your model to Hugging Face

To make your model compatible with Hugging Face, all you need to do is inherit from the appropriate class, yuppp, that is all. If you want to spice up your README with some more optional parameters, feel free to pass them here.

You only need to inherit from one of the Mixin classes; everything else is optional.

class BasicNet(nn.Module, PyTorchModelHubMixin, tags=["image-classification"]):
    (...)

The way the Mixin classes work is that they take the parameters passed to the __init__ method and store them in a config.json file. When calling the push_to_hub method, we push both the model weights and the config.json to the Hub. When using from_pretrained, we download the model weights and the config file, instantiate the model using the parameters stored in the config file, and inject our weights.

TL;DR

  • they add 3 methods to your model, similar to how the transformers library works:
    • save_pretrained (saves weights locally)
    • from_pretrained (loads and initializes the model either from the hub or locally)
    • push_to_hub (pushes the weights and the config to the hub)
  • you can pass in other optional metadata to make your README file stand out
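
Here is a minimal sketch of that round trip using a hypothetical toy model (TinyNet is not part of this article's code):

class TinyNet(nn.Module, PyTorchModelHubMixin):
    def __init__(self, hidden_size=16):
        super().__init__()
        self.fc = nn.Linear(hidden_size, 1)

tiny = TinyNet(hidden_size=32)
tiny.save_pretrained("tiny-net")                # writes model.safetensors + config.json ({"hidden_size": 32})
reloaded = TinyNet.from_pretrained("tiny-net")  # re-instantiates with hidden_size=32, then loads the weights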

Trainer API compatibility

Now that we have that covered, let's move on to the part where we make our model compatible with the Trainer API.

This heavily relies on the forward method.

The forward method requires a parameter named labels; otherwise, we need to set label_names in the TrainingArguments.

class BasicNet(nn.Module, PyTorchModelHubMixin, tags=["image-classification"]):
    def __init__(self, channels):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()
        self.conv1 = nn.Conv2d(channels, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, image, labels=None):
        # the labels parameter allows us to finetune our model
        # with the Trainer API easily
        x = F.relu(F.max_pool2d(self.conv1(image), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        logits = F.log_softmax(x, dim=-1)
        if labels is not None:
            # returning a dict with a "loss" key is what makes
            # your AI compatible with the Trainer API
            loss = self.criterion(logits, labels)
            return {"loss": loss, "logits": logits}
        return logits

The forward method needs to take at least 2 things:

  • image: the X, meaning the data that will pass through the model
  • labels: the Y; if present, we calculate and return the loss in a dictionary or a ModelOutput type

Depending on our dataset columns, we can adapt the forward method to match them; note that the Trainer passes the data as kwargs.

Train the model

First, let's initialize the model

# The number 3 will be stored in the config.json file
model = BasicNet(channels=3) #RGB
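
Since the Trainer passes each batch to forward as kwargs (as noted above), we can sanity-check the signature with a dummy batch; the shapes below are assumptions for 28x28 RGB MNIST:

# a dummy batch whose keys match the forward signature: (image, labels)
dummy = {"image": torch.randn(2, 3, 28, 28), "labels": torch.tensor([3, 7])}
out = model(**dummy)  # equivalent to model(image=..., labels=...)
print(out["loss"], out["logits"].shape)  # a scalar loss and torch.Size([2, 10])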

Let's add an optional compute_metrics section here:

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)
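
A quick, hypothetical check of this function with dummy predictions (logits of shape [batch, num_classes]):

dummy_logits = np.array([[0.1, 2.0], [3.0, 0.2]])  # argmax -> [1, 0]
dummy_labels = np.array([1, 0])
print(compute_metrics((dummy_logits, dummy_labels)))  # {'accuracy': 1.0}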

The collate function defines how individual samples are combined into a properly shaped batch.

To test your collate function before you pass it to the Trainer API, you can use collate_fn(train.select(range(8))); this selects the first 8 samples of the dataset and passes them through collate_fn, so you can verify that all data shapes match your needs (see the sanity check after the function below).

def collate_fn(examples):
    images = []
    labels = []
    for example in examples:
        images.append((example["image"]))
        labels.append(example["labels"])

    images = torch.stack(images)
    labels = torch.tensor(labels)
    return {"image": images, "labels": labels}

Now we define the parameters needed to train our AI model; it is recommended that you read the TrainingArguments documentation to find out which parameters match your requirements.

training_args = TrainingArguments(
    output_dir="my_mnist_model",
    # remove_unused_columns=False,
    evaluation_strategy="steps",
    save_strategy="epoch",
    learning_rate=5e-4,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=2,
    logging_steps=100,
    push_to_hub=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train,
    eval_dataset=test,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)

trainer.train()

The Trainer API only pushes the weights without the config (the init parameters).

To fix this, push your model manually after the training is done:

model.push_to_hub("my_mnist_model")

and you're all done 🥳🥳🥳

Load model from the hub

To load your model from the Hub, there's no need to reinitialize it or download the weights manually; you can use the class directly to do this for you.

With ModelHubMixin:

new_model = BasicNet.from_pretrained("not-lain/my_mnist_model")

Or with manual loading:

from safetensors.torch import load_file
from huggingface_hub import snapshot_download

snapshot_download(repo_id="not-lain/my_mnist_model", local_dir="temp_folder")

# initialize the model using the parameters stored in config.json
new_model = BasicNet(3)
weights = load_file("temp_folder/model.safetensors")
new_model.load_state_dict(weights)
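
As a quick, hypothetical sanity check, you can run one test image through the loaded model:

new_model.eval()
with torch.no_grad():
    sample = test[0]
    logits = new_model(sample["image"].unsqueeze(0))  # add a batch dimension
print("predicted digit:", logits.argmax(dim=-1).item())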

Outro

Hope this has been informative and that you learned a lot from it. If you loved this blog post, consider upvoting it 🤗 and if you have any questions or queries, do not hesitate to open a discussion and tag me, or reach out in my DMs.

Fun fact: I contributed a little to the ModelHubMixin classes used in this article, so I hope you loved it 🤗

Thanks for reading this blog post, love you all ❤️