
Fine-Tuning Hugging Face Language Models with PyTorch Lightning

Updated: Feb 19

Nowadays it's extremely common to find yourself working with large language models on a daily basis, either as a tool to aid you in your job (e.g. ChatGPT) or because you work with the models themselves, whether as part of an application you are developing (e.g. a chatbot) or, like me, as part of research on natural language processing or similar domains.

If you currently work or do research with large language models, there are two ways to go about it: you either build on top of an API provided by OpenAI, Google, Microsoft, etc., or you work with the models yourself (e.g. training them, fine-tuning them, or running inference with them). If you are in the latter group, this article might be right up your alley. I will present two of the most essential tools I use in my day-to-day job, which give me all the flexibility I need, and show how you can combine them to your advantage.


Hugging Face + Lightning AI

Hugging Face


When you have to train or fine-tune large language models yourself, like I do, a great tool you might have come across is Hugging Face [1], particularly its Transformers library [2]. The company provides access to thousands of state-of-the-art pre-trained models [3], along with APIs that make it easy for anyone to train, share and build on top of them.

Hugging Face provides interfaces for a variety of models across different tasks (natural language processing, computer vision, audio and multi-modal) and supports different back-end frameworks such as PyTorch [4], JAX [5] and TensorFlow [6]. It also has a large developer community and very extensive documentation [7] for its APIs.

PyTorch Lightning


PyTorch Lightning [8] started as a library on top of PyTorch that provided the framework needed to abstract and simplify the process of training models. The project evolved and eventually became Lightning AI [9], a company dedicated to simplifying the process of training, evaluating, deploying and maintaining deep learning models. What I like most about this framework is its seamless integration with multiple GPUs [10] and experiment tracking tools [11], and the overall simplification of the whole training and evaluation process.

Why should we use Hugging Face with PyTorch Lightning?


If we're being honest, the API Hugging Face provides for fine-tuning its models [12] is good enough for most cases. In particular, if you don't need to fine-tune a model and only want to run inference, their pipeline API [13] is more than enough. Likewise, for classical tasks such as text classification, their Trainer API should cover your needs. The library is also constantly being updated, so it is highly likely to improve even further in the future.

But there is a catch: the library has to stay compatible with (at least so far) three different deep learning frameworks: PyTorch, TensorFlow and JAX. As a consequence, there are certain PyTorch features you won't be able to fully exploit if you limit yourself to training and fine-tuning the models with their API alone.

In my case, I find their approach to multi-GPU training and parallelism a little confusing (though maybe that's just me, since I have been working with Lightning for quite some time now). On the other hand, I like Lightning's integration [14] with MLflow [15], another tool I use day to day, because it makes tracking experiments much easier.

Training Hugging Face Transformers with PyTorch Lightning

Installation


Before we delve into how to train a Hugging Face transformer using the framework provided by PyTorch Lightning, we need to install the Hugging Face and Lightning libraries:

pip install lightning transformers datasets

This installs the two libraries we need, plus, for this specific example, Hugging Face's datasets library [16], which gives access to a large variety of datasets to experiment with and share.


LightningDataModule


We start by defining the LightningDataModule [17]. This is an abstraction provided by PyTorch Lightning that encapsulates all the steps needed to process a dataset: downloading, tokenizing, cleaning, transforming and any other form of pre-processing all happen within this module:

from datasets import load_dataset
from lightning import LightningDataModule
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

class HFDataModule(LightningDataModule):
    def __init__(self, tokenizer_name):
        super().__init__()
        self.tokenizer_name = tokenizer_name

    def _tokenize_function(self, batch):
        # Pad/truncate every review to exactly 64 tokens
        return self.tokenizer(batch["text"],
                              padding="max_length",
                              truncation=True,
                              max_length=64)

    def prepare_data(self):
        # Only download the dataset and tokenizer files to the local cache.
        # Lightning calls this on a single process (per node by default), so
        # attributes assigned here would be missing on other processes.
        load_dataset("rotten_tomatoes")
        AutoTokenizer.from_pretrained(self.tokenizer_name)

    def setup(self, stage=None):
        # Runs on every process: load the cached tokenizer and dataset
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
        dataset = load_dataset("rotten_tomatoes")
        # The model's forward() expects the targets under the name "labels"
        dataset = dataset.rename_column("label", "labels")
        self.dataset = dataset.map(self._tokenize_function, batched=True)
        # Return PyTorch tensors for the columns the model needs
        # (token_type_ids is produced by BERT-style tokenizers)
        self.dataset.set_format(
            type="torch",
            columns=["input_ids", "token_type_ids", "attention_mask", "labels"]
        )

    def train_dataloader(self):
        return DataLoader(self.dataset["train"],
                          batch_size=32,
                          shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.dataset["validation"],
                          batch_size=32,
                          shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.dataset["test"],
                          batch_size=32,
                          shuffle=False)

The HFDataModule class takes the name of the Hugging Face tokenizer (which is usually the same as the model's name) and loads the "rotten_tomatoes" dataset [18] from Hugging Face. In the __init__ method, the most important part is calling super().__init__() to set up the base LightningDataModule class.

The prepare_data() method is called by the Lightning Trainer (which we'll see in the next sections) and is meant for tasks such as downloading data. Here it calls Hugging Face's load_dataset() only to download the dataset's files to the local cache, and it also fetches the tokenizer's files.

setup() is another method called by the Trainer, usually for pre-processing and splitting the data. It takes a stage argument that can be 'fit', 'validate', 'test', or 'predict', depending on what the Trainer is about to do. In this case we ignore it and prepare the whole Hugging Face dataset, which already comes split into train, validation and test.

The method loads the dataset, renames the "label" column to "labels", tokenizes the text, and sets the format so that the columns the Hugging Face model needs are returned as PyTorch tensors. The rename is needed because the model's forward() expects its targets under the name "labels"; when you use Hugging Face's own Trainer, its data collators handle that rename internally.

Finally, each <split>_dataloader method returns a DataLoader for the corresponding split (train, validation or test).
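These DataLoaders work without a custom collate function because PyTorch's default collation turns a list of dictionaries of tensors into one dictionary of batched tensors, which is the format the model's forward() accepts as keyword arguments. A minimal sketch with hand-made samples (the token ids are arbitrary) shows the resulting shapes:

```python
import torch
from torch.utils.data import DataLoader

# Hand-made samples shaped like a tokenized split after set_format(type="torch"):
# each one is a dict of tensors (token ids are arbitrary)
samples = [
    {"input_ids": torch.tensor([2, 5 + i, 6, 3]),
     "attention_mask": torch.tensor([1, 1, 1, 1]),
     "labels": torch.tensor(i % 2)}
    for i in range(5)
]

loader = DataLoader(samples, batch_size=2, shuffle=False)
batches = list(loader)

# The default collate_fn stacks each key separately and keeps the dict structure
print({key: tuple(value.shape) for key, value in batches[0].items()})
# {'input_ids': (2, 4), 'attention_mask': (2, 4), 'labels': (2,)}
print(len(batches))  # 3 (the last batch holds the leftover sample)
```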


LightningModule


Next we build the LightningModule [19], which is the model we are going to train. Here it is essentially a wrapper around the Hugging Face model for sequence classification [20]. Because LightningModule is itself a subclass of PyTorch's nn.Module, we have complete freedom over what to use with the model: extra layers, optimizers, schedulers, loss functions, etc.
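Before wrapping the model, it helps to see what it returns. This offline sketch uses a tiny, randomly initialized BertForSequenceClassification built from a BertConfig (all sizes are arbitrary assumptions) instead of downloading bert-base-uncased; the behavior shown, a loss only when "labels" is passed, is what the training steps below rely on.

```python
import torch
from transformers import BertConfig, BertForSequenceClassification

# Tiny, randomly initialized BERT (arbitrary sizes) so the sketch runs offline;
# the article loads pre-trained weights with from_pretrained() instead
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=4, intermediate_size=64, num_labels=2)
model = BertForSequenceClassification(config).eval()

batch = {
    "input_ids": torch.randint(0, config.vocab_size, (4, 10)),
    "token_type_ids": torch.zeros(4, 10, dtype=torch.long),
    "attention_mask": torch.ones(4, 10, dtype=torch.long),
    "labels": torch.tensor([0, 1, 1, 0]),
}

with torch.no_grad():
    # With "labels" present, the output carries the cross-entropy loss and the logits
    out = model(**batch)
    # Without "labels", only the logits are computed and loss is None
    out_no_labels = model(**{k: v for k, v in batch.items() if k != "labels"})

print(out.loss.shape, out.logits.shape)  # torch.Size([]) torch.Size([4, 2])
print(out_no_labels.loss)  # None
```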

from lightning import LightningModule
from torch.optim import AdamW
from transformers import AutoModelForSequenceClassification

class TransformerModule(LightningModule):
    def __init__(self, model_name, num_labels=2, lr=2e-5):
        super().__init__()
        # Store the init arguments in self.hparams (and in every checkpoint)
        self.save_hyperparameters()
        # Pre-trained encoder plus a freshly initialized classification head
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels
        )

    def forward(self, **inputs):
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        # The batch contains "labels", so the model also returns the loss
        loss = self(**batch).loss
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self(**batch).loss
        self.log("val_loss", loss, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss = self(**batch).loss
        self.log("test_loss", loss, on_epoch=True)
        return loss

    def configure_optimizers(self):
        # The BERT authors recommend 2e-5 to 5e-5 for fine-tuning; 1e-3 tends to diverge
        return AdamW(self.parameters(), lr=self.hparams.lr)

With the aid of Lightning, we have a fully functional model ready for training. It loads the Hugging Face model and defines the training, validation and test steps, logging the loss in each of them. It also sets up the optimizer to use.

Training and Evaluation

Now that we have the data module and the model itself, we need to set up training and evaluation. This is where all of PyTorch Lightning's magic happens:

from lightning import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

# HFDataModule and TransformerModule are the classes defined above
MODEL_NAME = "bert-base-uncased"

dm = HFDataModule(MODEL_NAME)
model = TransformerModule(MODEL_NAME)

# Stop if val_loss has not improved for 3 consecutive validation checks
early_stopping = EarlyStopping(monitor="val_loss", patience=3)
# Keep only the checkpoint with the lowest val_loss
model_checkpoints = ModelCheckpoint("./checkpoints",
                                    monitor="val_loss",
                                    save_top_k=1)
trainer = Trainer(max_epochs=5, callbacks=[early_stopping, model_checkpoints])

trainer.fit(model, datamodule=dm)
# Evaluate the best checkpoint (not the last weights) on the test split
trainer.test(ckpt_path=model_checkpoints.best_model_path, datamodule=dm)

The previous snippet runs the whole training process using the "bert-base-uncased" model from Hugging Face, fine-tuning it on the "rotten_tomatoes" dataset, a sentiment classification dataset of movie reviews.

First we set up the data module and the model using "bert-base-uncased" as our base model, which loads the matching model and tokenizer.

Next, we create an instance of the Trainer [21]. This class is behind all of PyTorch Lightning's magic: it takes the LightningModule and runs the training, validation and test loops over the splits provided by the LightningDataModule. Here we only set the maximum number of epochs and add a couple of callbacks. The EarlyStopping callback monitors the metric it is given, which must be logged with the LightningModule's log() method (in this case it looks for the "val_loss" logged in the validation_step of the TransformerModule class defined above), and stops training once that metric has not improved for patience (here, 3) consecutive checks; by default, one check happens per epoch. The ModelCheckpoint callback stores the best checkpoint according to the monitored metric (again "val_loss").
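To make "patience" concrete, here is a plain-Python simplification of the rule for a metric that should decrease, with one check per epoch. early_stop_epoch is a hypothetical helper written for illustration, not Lightning's implementation, and the loss values are made up.

```python
def early_stop_epoch(val_losses, patience=3, min_delta=0.0):
    """Return the epoch index at which a 'min'-mode patience rule would stop
    training, or None if it never triggers. Simplified illustration only."""
    best = float("inf")
    wait = 0  # consecutive checks without improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # Improvement: remember the new best value and reset the counter
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Improves until epoch 2, then fails to improve for 3 checks in a row
print(early_stop_epoch([0.70, 0.55, 0.52, 0.53, 0.54, 0.56, 0.50]))  # 5
```

Note that the late improvement to 0.50 is never seen: once patience runs out, training stops, which is why ModelCheckpoint is used to keep the best weights found before that point.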

The Trainer is a very versatile and powerful abstraction of the training loop: it provides the tools to train on multiple GPUs (or other types of accelerators), set up different loggers, cap the training loop by a maximum time, number of steps or number of epochs, and more.

The fit() method takes the model and the data module we created before and trains for at most 5 epochs (stopping earlier if the early stopping condition is met).

We could use the Trainer's save_checkpoint() method to save a checkpoint after training stops. In this case, however, the ModelCheckpoint callback already stores the best model, and its path is available in model_checkpoints.best_model_path.

Finally, we call the test() method with the best checkpoint to evaluate it on the test split.

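Although Lightning supports many different loggers, by default it uses a TensorBoardLogger [22] that writes to a "lightning_logs" directory (inside the directory where the script was run). Note that recent versions of Lightning no longer install TensorBoard themselves and fall back to a CSVLogger when neither tensorboard nor tensorboardX is available, so you may need to pip install tensorboard first. We can then check the logged results: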

tensorboard --logdir ./lightning_logs

We can now run multiple experiments and easily compare their results with the help of TensorBoard.

Final Thoughts


Hugging Face has an excellent community and a repository of pre-trained models of very different kinds, with a lot of possibilities to explore. It also offers a collection of datasets ready to download, most of them pre-processed and already split into train, validation and test sets. It offers a nice set of APIs for running pipelines and fine-tuning any of the models it provides. However, in my experience, its API is a little tricky to use, especially when you need to go beyond what is possible out of the box. Besides, supporting multiple backends makes it harder to take advantage of everything a specific backend (PyTorch, in our case) has to offer.

PyTorch Lightning provides a very good wrapper around PyTorch, giving its users access to lots of abstractions that simplify the training, validation and testing process. Its LightningModule has access to everything PyTorch offers, which gives full flexibility to use any loss, optimizer, scheduler, etc. Its Trainer gives the user access to features that streamline the deep learning workflow, such as multi-GPU training, early stopping and model checkpointing.

By combining them, we get all the power of Hugging Face's Transformers together with the flexibility and scalability of the PyTorch Lightning framework. A real winning combination.


References

[1] Hugging Face Inc. https://huggingface.co/

[2] Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Clement Delangue, Anthony Moi, Pierric Cistac, Tim Rault, Remi Louf, Morgan Funtowicz, Joe Davison, Sam Shleifer, Patrick von Platen, Clara Ma, Yacine Jernite, Julien Plu, Canwen Xu, Teven Le Scao, Sylvain Gugger, et al. 2020. Transformers: State-of-the-Art Natural Language Processing. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: System Demonstrations, pages 38–45, Online. Association for Computational Linguistics.

[3] Hugging Face Inc. Models. https://huggingface.co/models

[5] JAX: High-Performance Array Computing. https://jax.readthedocs.io/en/latest/

[7] Hugging Face Inc. Documentation. https://huggingface.co/docs

[8] Lightning AI. PyTorch Lightning Documentation. https://lightning.ai/docs/pytorch/stable/

[9] Lightning AI. https://lightning.ai/

[11] Lightning AI. Track and Visualize Experiments (Advanced). https://lightning.ai/docs/pytorch/stable/visualize/logging_advanced.html#logger

[12] Hugging Face Inc. Fine-tune a pretrained model. https://huggingface.co/docs/transformers/training

[13] Hugging Face Inc. Pipelines for inference. https://huggingface.co/docs/transformers/pipeline_tutorial

[15] MLFlow. ML and GenAI made simple. https://mlflow.org/

[16] Hugging Face Inc. Datasets. https://huggingface.co/datasets

[18] Hugging Face Inc. Datasets: rotten_tomatoes. https://huggingface.co/datasets/rotten_tomatoes
