7.1. Fine-Tuning for Classification
Reading time: 5 minutes
What is
Il fine-tuning è il processo di prendere un modello pre-addestrato che ha appreso modelli linguistici generali da enormi quantità di dati e adattarlo per eseguire un compito specifico o per comprendere il linguaggio specifico di un dominio. Questo si ottiene continuando l'addestramento del modello su un dataset più piccolo e specifico per il compito, permettendogli di regolare i suoi parametri per adattarsi meglio alle sfumature dei nuovi dati, sfruttando al contempo la vasta conoscenza che ha già acquisito. Il fine-tuning consente al modello di fornire risultati più accurati e pertinenti in applicazioni specializzate senza la necessità di addestrare un nuovo modello da zero.
tip
Poiché il pre-addestramento di un LLM che "comprende" il testo è piuttosto costoso, di solito è più facile ed economico fare fine-tuning su modelli pre-addestrati open source per eseguire un compito specifico che vogliamo che esegua.
tip
L'obiettivo di questa sezione è mostrare come fare fine-tuning su un modello già pre-addestrato, in modo che invece di generare nuovo testo, l'LLM selezioni e fornisca le probabilità che il testo fornito venga categorizzato in ciascuna delle categorie date (come se un testo fosse spam o meno).
Preparing the data set
Data set size
Naturalmente, per fare fine-tuning su un modello è necessario avere dei dati strutturati da utilizzare per specializzare il tuo LLM. Nell'esempio proposto in https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb, GPT2 è fine-tuned per rilevare se un'email è spam o meno utilizzando i dati di https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip.
Questo dataset contiene molti più esempi di "non spam" che di "spam", quindi il libro suggerisce di utilizzare solo tanti esempi di "non spam" quanti di "spam" (rimuovendo quindi dal set di addestramento tutti gli esempi extra). In questo caso, erano 747 esempi di ciascuno.
Poi, il 70% del dataset è utilizzato per l'addestramento, il 10% per la validazione e il 20% per il test.
- Il set di validazione è utilizzato durante la fase di addestramento per fare fine-tuning degli iperparametri del modello e prendere decisioni sull'architettura del modello, aiutando effettivamente a prevenire l'overfitting fornendo feedback su come il modello si comporta su dati non visti. Permette miglioramenti iterativi senza pregiudicare la valutazione finale.
- Questo significa che, sebbene i dati inclusi in questo dataset non siano utilizzati direttamente per l'addestramento, vengono utilizzati per ottimizzare i migliori iperparametri, quindi questo set non può essere utilizzato per valutare le prestazioni del modello come quello di test.
- Al contrario, il set di test è utilizzato solo dopo che il modello è stato completamente addestrato e tutti gli aggiustamenti sono stati completati; fornisce una valutazione imparziale della capacità del modello di generalizzare a nuovi dati non visti. Questa valutazione finale sul set di test fornisce un'indicazione realistica di come ci si aspetta che il modello si comporti nelle applicazioni del mondo reale.
Entries length
Poiché l'esempio di addestramento si aspetta voci (testo delle email in questo caso) della stessa lunghezza, è stato deciso di rendere ogni voce grande quanto la più grande aggiungendo gli id di <|endoftext|>
come padding.
Initialize the model
Utilizzando i pesi pre-addestrati open-source, inizializza il modello per l'addestramento. Abbiamo già fatto questo prima e seguendo le istruzioni di https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb puoi farlo facilmente.
Classification head
In questo esempio specifico (predire se un testo è spam o meno), non siamo interessati a fare fine-tuning secondo il vocabolario completo di GPT2, ma vogliamo solo che il nuovo modello dica se l'email è spam (1) o meno (0). Pertanto, andremo a modificare l'ultimo strato che fornisce le probabilità per token del vocabolario per uno che fornisce solo le probabilità di essere spam o meno (quindi come un vocabolario di 2 parole).
# This code modified the final layer with a Linear one with 2 outs
num_classes = 2
model.out_head = torch.nn.Linear(
in_features=BASE_CONFIG["emb_dim"],
out_features=num_classes
)
Parametri da ottimizzare
Per ottimizzare rapidamente, è più facile non ottimizzare tutti i parametri ma solo alcuni finali. Questo perché è noto che i livelli inferiori catturano generalmente strutture linguistiche di base e semantiche applicabili. Quindi, ottimizzare solo gli ultimi livelli è di solito sufficiente e più veloce.
# This code makes all the parameters of the model unrtainable
for param in model.parameters():
param.requires_grad = False
# Allow to fine tune the last layer in the transformer block
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True
# Allow to fine tune the final layer norm
for param in model.final_norm.parameters():
param.requires_grad = True
Entries to use for training
Nelle sezioni precedenti, il LLM è stato addestrato riducendo la perdita di ogni token previsto, anche se quasi tutti i token previsti erano nella frase di input (solo 1 alla fine era realmente previsto) affinché il modello comprendesse meglio la lingua.
In questo caso ci interessa solo che il modello sia in grado di prevedere se il modello è spam o meno, quindi ci interessa solo l'ultimo token previsto. Pertanto, è necessario modificare le nostre precedenti funzioni di perdita di addestramento per tenere conto solo di quel token.
Questo è implementato in https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb come:
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
model.eval()
correct_predictions, num_examples = 0, 0
if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
with torch.no_grad():
logits = model(input_batch)[:, -1, :] # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)
num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples
def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :] # Logits of last output token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss
Nota come per ogni batch siamo interessati solo ai logits dell'ultimo token previsto.
Codice completo per la classificazione fine-tune di GPT2
Puoi trovare tutto il codice per fine-tunare GPT2 come classificatore di spam in https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb