7.1. Feinabstimmung für die Klassifikation
Reading time: 5 minutes
Was ist
Feinabstimmung ist der Prozess, ein vortrainiertes Modell zu nehmen, das allgemeine Sprachmuster aus großen Datenmengen gelernt hat, und es anzupassen, um eine spezifische Aufgabe auszuführen oder domänenspezifische Sprache zu verstehen. Dies wird erreicht, indem das Training des Modells auf einem kleineren, aufgabenbezogenen Datensatz fortgesetzt wird, wodurch es seine Parameter anpassen kann, um besser auf die Nuancen der neuen Daten einzugehen, während es das breite Wissen, das es bereits erworben hat, nutzt. Feinabstimmung ermöglicht es dem Modell, genauere und relevantere Ergebnisse in spezialisierten Anwendungen zu liefern, ohne ein neues Modell von Grund auf neu trainieren zu müssen.
tip
Da das Vortrainieren eines LLM, das den Text "versteht", ziemlich teuer ist, ist es normalerweise einfacher und günstiger, Open-Source-vortrainierte Modelle für eine spezifische Aufgabe, die wir möchten, dass es ausführt, fein abzustimmen.
tip
Das Ziel dieses Abschnitts ist es zu zeigen, wie man ein bereits vortrainiertes Modell fein abstimmt, sodass das LLM anstelle von neuem Text die Wahrscheinlichkeiten angibt, dass der gegebene Text in jede der gegebenen Kategorien eingeordnet wird (zum Beispiel, ob ein Text Spam ist oder nicht).
Vorbereitung des Datensatzes
Größe des Datensatzes
Um ein Modell fein abzustimmen, benötigt man natürlich einige strukturierte Daten, um das LLM zu spezialisieren. Im vorgeschlagenen Beispiel in https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb wird GPT2 fein abgestimmt, um zu erkennen, ob eine E-Mail Spam ist oder nicht, unter Verwendung der Daten von https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip.
Dieser Datensatz enthält viel mehr Beispiele von "nicht Spam" als von "Spam", daher schlägt das Buch vor, nur so viele Beispiele von "nicht Spam" wie von "Spam" zu verwenden (daher werden alle zusätzlichen Beispiele aus den Trainingsdaten entfernt). In diesem Fall waren es 747 Beispiele von jedem.
Dann werden 70% des Datensatzes für das Training, 10% für die Validierung und 20% für das Testen verwendet.
- Der Validierungsdatensatz wird während der Trainingsphase verwendet, um die Hyperparameter des Modells fein abzustimmen und Entscheidungen über die Modellarchitektur zu treffen, wodurch effektiv Überanpassung verhindert wird, indem Feedback darüber gegeben wird, wie das Modell bei ungesehenen Daten abschneidet. Er ermöglicht iterative Verbesserungen, ohne die endgültige Bewertung zu verzerren.
- Das bedeutet, dass obwohl die in diesem Datensatz enthaltenen Daten nicht direkt für das Training verwendet werden, sie dazu dienen, die besten Hyperparameter zu optimieren, sodass dieser Datensatz nicht zur Bewertung der Leistung des Modells wie der Testdatensatz verwendet werden kann.
- Im Gegensatz dazu wird der Testdatensatz nur nach dem vollständigen Training des Modells und nach Abschluss aller Anpassungen verwendet; er bietet eine unvoreingenommene Bewertung der Fähigkeit des Modells, auf neue, ungesehene Daten zu verallgemeinern. Diese endgültige Bewertung des Testdatensatzes gibt einen realistischen Hinweis darauf, wie das Modell in realen Anwendungen abschneiden wird.
Länge der Einträge
Da das Trainingsbeispiel Einträge (E-Mail-Text in diesem Fall) derselben Länge erwartet, wurde beschlossen, jeden Eintrag so groß wie den größten zu machen, indem die IDs von <|endoftext|>
als Padding hinzugefügt werden.
Modell initialisieren
Verwenden Sie die Open-Source-vortrainierten Gewichte, um das Modell für das Training zu initialisieren. Wir haben dies bereits zuvor getan und folgen den Anweisungen von https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb, können Sie dies leicht tun.
Klassifikationskopf
In diesem speziellen Beispiel (Vorhersage, ob ein Text Spam ist oder nicht) sind wir nicht daran interessiert, die Feinabstimmung gemäß dem vollständigen Vokabular von GPT2 vorzunehmen, sondern wir möchten nur, dass das neue Modell sagt, ob die E-Mail Spam (1) oder nicht (0) ist. Daher werden wir die letzte Schicht, die die Wahrscheinlichkeiten pro Token des Vokabulars angibt, so ändern, dass sie nur die Wahrscheinlichkeiten angibt, Spam oder nicht zu sein (also wie ein Vokabular von 2 Wörtern).
# This code modified the final layer with a Linear one with 2 outs
num_classes = 2
model.out_head = torch.nn.Linear(
in_features=BASE_CONFIG["emb_dim"],
out_features=num_classes
)
Parameter zum Abstimmen
Um schnell abzustimmen, ist es einfacher, nicht alle Parameter, sondern nur einige finale zu optimieren. Das liegt daran, dass bekannt ist, dass die unteren Schichten im Allgemeinen grundlegende Sprachstrukturen und anwendbare Semantiken erfassen. Daher ist es in der Regel ausreichend und schneller, nur die letzten Schichten abzustimmen.
# This code makes all the parameters of the model unrtainable
for param in model.parameters():
param.requires_grad = False
# Allow to fine tune the last layer in the transformer block
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True
# Allow to fine tune the final layer norm
for param in model.final_norm.parameters():
param.requires_grad = True
Einträge zur Verwendung für das Training
In den vorherigen Abschnitten wurde das LLM trainiert, indem der Verlust jedes vorhergesagten Tokens reduziert wurde, obwohl fast alle vorhergesagten Tokens im Eingabesatz waren (nur 1 am Ende wurde wirklich vorhergesagt), um dem Modell zu helfen, die Sprache besser zu verstehen.
In diesem Fall interessiert uns nur, ob das Modell Spam ist oder nicht, daher konzentrieren wir uns nur auf das letzte vorhergesagte Token. Daher ist es notwendig, unsere vorherigen Trainingsverlustfunktionen so zu modifizieren, dass nur dieses Token berücksichtigt wird.
Dies ist implementiert in https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb als:
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
model.eval()
correct_predictions, num_examples = 0, 0
if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
with torch.no_grad():
logits = model(input_batch)[:, -1, :] # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)
num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples
def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :] # Logits of last output token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss
Beachten Sie, dass wir für jede Charge nur an den Logits des letzten vorhergesagten Tokens interessiert sind.
Vollständiger GPT2 Feinabstimmungs-Klassifizierungscode
Sie finden den gesamten Code zur Feinabstimmung von GPT2 als Spam-Klassifizierer in https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb