7.1. Ajustement fin pour la classification
Reading time: 8 minutes
tip
Apprenez et pratiquez le hacking AWS :HackTricks Training AWS Red Team Expert (ARTE)
Apprenez et pratiquez le hacking GCP : HackTricks Training GCP Red Team Expert (GRTE)
Apprenez et pratiquez le hacking Azure :
HackTricks Training Azure Red Team Expert (AzRTE)
Soutenir HackTricks
- Vérifiez les plans d'abonnement !
- Rejoignez le 💬 groupe Discord ou le groupe telegram ou suivez-nous sur Twitter 🐦 @hacktricks_live.
- Partagez des astuces de hacking en soumettant des PR au HackTricks et HackTricks Cloud dépôts github.
Qu'est-ce que c'est
L'ajustement fin est le processus qui consiste à prendre un modèle pré-entraîné ayant appris des schémas linguistiques généraux à partir de vastes quantités de données et à l'adapter pour effectuer une tâche spécifique ou comprendre un langage spécifique à un domaine. Cela se fait en poursuivant l'entraînement du modèle sur un ensemble de données plus petit et spécifique à la tâche, lui permettant d'ajuster ses paramètres pour mieux s'adapter aux nuances des nouvelles données tout en tirant parti des vastes connaissances qu'il a déjà acquises. L'ajustement fin permet au modèle de fournir des résultats plus précis et pertinents dans des applications spécialisées sans avoir besoin d'entraîner un nouveau modèle depuis le début.
tip
Comme pré-entraîner un LLM qui "comprend" le texte est assez coûteux, il est généralement plus facile et moins cher d'ajuster des modèles pré-entraînés open source pour effectuer une tâche spécifique que nous souhaitons qu'il réalise.
tip
L'objectif de cette section est de montrer comment ajuster un modèle déjà pré-entraîné afin qu'au lieu de générer un nouveau texte, le LLM sélectionne et donne les probabilités que le texte donné soit catégorisé dans chacune des catégories données (comme si un texte est un spam ou non).
Préparation de l'ensemble de données
Taille de l'ensemble de données
Bien sûr, pour ajuster un modèle, vous avez besoin de données structurées à utiliser pour spécialiser votre LLM. Dans l'exemple proposé dans https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb, GPT2 est ajusté pour détecter si un email est un spam ou non en utilisant les données de https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip.
Cet ensemble de données contient beaucoup plus d'exemples de "non spam" que de "spam", donc le livre suggère de n'utiliser que autant d'exemples de "non spam" que de "spam" (en supprimant ainsi tous les exemples supplémentaires des données d'entraînement). Dans ce cas, il y avait 747 exemples de chaque.
Ensuite, 70% de l'ensemble de données est utilisé pour l'entraînement, 10% pour la validation et 20% pour les tests.
- L'ensemble de validation est utilisé pendant la phase d'entraînement pour ajuster les hyperparamètres du modèle et prendre des décisions concernant l'architecture du modèle, aidant ainsi à prévenir le surajustement en fournissant des retours sur la performance du modèle sur des données non vues. Cela permet des améliorations itératives sans biaiser l'évaluation finale.
- Cela signifie que bien que les données incluses dans cet ensemble de données ne soient pas utilisées directement pour l'entraînement, elles sont utilisées pour ajuster les meilleurs hyperparamètres, donc cet ensemble ne peut pas être utilisé pour évaluer la performance du modèle comme celui des tests.
- En revanche, l'ensemble de test est utilisé uniquement après que le modèle a été entièrement entraîné et que tous les ajustements sont terminés ; il fournit une évaluation impartiale de la capacité du modèle à généraliser sur de nouvelles données non vues. Cette évaluation finale sur l'ensemble de test donne une indication réaliste de la façon dont le modèle est censé performer dans des applications réelles.
Longueur des entrées
Comme l'exemple d'entraînement attend des entrées (texte des emails dans ce cas) de la même longueur, il a été décidé de rendre chaque entrée aussi grande que la plus grande en ajoutant les ids de <|endoftext|>
comme remplissage.
Initialiser le modèle
En utilisant les poids pré-entraînés open source, initialisez le modèle pour l'entraînement. Nous avons déjà fait cela auparavant et en suivant les instructions de https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb, vous pouvez facilement le faire.
Tête de classification
Dans cet exemple spécifique (prédire si un texte est un spam ou non), nous ne sommes pas intéressés à ajuster selon le vocabulaire complet de GPT2 mais nous voulons seulement que le nouveau modèle indique si l'email est un spam (1) ou non (0). Par conséquent, nous allons modifier la couche finale qui donne les probabilités par token du vocabulaire pour une qui ne donne que les probabilités d'être un spam ou non (comme un vocabulaire de 2 mots).
# This code modified the final layer with a Linear one with 2 outs
num_classes = 2
model.out_head = torch.nn.Linear(
in_features=BASE_CONFIG["emb_dim"],
out_features=num_classes
)
Paramètres à ajuster
Afin d'ajuster rapidement, il est plus facile de ne pas ajuster tous les paramètres mais seulement certains derniers. Cela est dû au fait qu'il est connu que les couches inférieures capturent généralement des structures linguistiques de base et des sémantiques applicables. Donc, juste ajuster les dernières couches est généralement suffisant et plus rapide.
# This code makes all the parameters of the model unrtainable
for param in model.parameters():
param.requires_grad = False
# Allow to fine tune the last layer in the transformer block
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True
# Allow to fine tune the final layer norm
for param in model.final_norm.parameters():
param.requires_grad = True
Entrées à utiliser pour l'entraînement
Dans les sections précédentes, le LLM a été entraîné en réduisant la perte de chaque jeton prédit, même si presque tous les jetons prédits étaient dans la phrase d'entrée (seul 1 à la fin était vraiment prédit) afin que le modèle comprenne mieux la langue.
Dans ce cas, nous ne nous soucions que de la capacité du modèle à prédire si le modèle est un spam ou non, donc nous ne nous soucions que du dernier jeton prédit. Par conséquent, il est nécessaire de modifier nos précédentes fonctions de perte d'entraînement pour ne prendre en compte que ce jeton.
Ceci est implémenté dans https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb comme :
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
model.eval()
correct_predictions, num_examples = 0, 0
if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
with torch.no_grad():
logits = model(input_batch)[:, -1, :] # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)
num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples
def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :] # Logits of last output token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss
Notez que pour chaque lot, nous ne sommes intéressés que par les logits du dernier token prédit.
Code complet de classification fine-tuné pour GPT2
Vous pouvez trouver tout le code pour fine-tuner GPT2 en tant que classificateur de spam dans https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb
Références
tip
Apprenez et pratiquez le hacking AWS :HackTricks Training AWS Red Team Expert (ARTE)
Apprenez et pratiquez le hacking GCP : HackTricks Training GCP Red Team Expert (GRTE)
Apprenez et pratiquez le hacking Azure :
HackTricks Training Azure Red Team Expert (AzRTE)
Soutenir HackTricks
- Vérifiez les plans d'abonnement !
- Rejoignez le 💬 groupe Discord ou le groupe telegram ou suivez-nous sur Twitter 🐦 @hacktricks_live.
- Partagez des astuces de hacking en soumettant des PR au HackTricks et HackTricks Cloud dépôts github.