7.1. Dostosowywanie do klasyfikacji
Reading time: 5 minutes
Czym jest
Dostosowywanie to proces, w którym bierze się wstępnie wytrenowany model, który nauczył się ogólnych wzorców językowych z ogromnych ilości danych i dostosowuje go do wykonywania specyficznego zadania lub rozumienia języka specyficznego dla danej dziedziny. Osiąga się to poprzez kontynuowanie treningu modelu na mniejszym, specyficznym dla zadania zbiorze danych, co pozwala mu dostosować swoje parametry, aby lepiej odpowiadały niuansom nowych danych, jednocześnie wykorzystując szeroką wiedzę, którą już zdobył. Dostosowywanie umożliwia modelowi dostarczanie dokładniejszych i bardziej odpowiednich wyników w specjalistycznych zastosowaniach bez potrzeby trenowania nowego modelu od podstaw.
tip
Ponieważ wstępne trenowanie LLM, który "rozumie" tekst, jest dość kosztowne, zazwyczaj łatwiej i taniej jest dostosować otwarte modele wstępnie wytrenowane do wykonywania konkretnego zadania, które chcemy, aby realizował.
tip
Celem tej sekcji jest pokazanie, jak dostosować już wstępnie wytrenowany model, aby zamiast generować nowy tekst, LLM wybierał prawdopodobieństwa przypisania danego tekstu do każdej z podanych kategorii (na przykład, czy tekst jest spamem, czy nie).
Przygotowanie zbioru danych
Rozmiar zbioru danych
Oczywiście, aby dostosować model, potrzebujesz pewnych uporządkowanych danych do specjalizacji swojego LLM. W przykładzie zaproponowanym w https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb, GPT2 jest dostosowywany do wykrywania, czy e-mail jest spamem, czy nie, przy użyciu danych z https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip.
Ten zbiór danych zawiera znacznie więcej przykładów "nie spam" niż "spam", dlatego książka sugeruje, aby używać tylko tylu przykładów "nie spam", co "spam" (w związku z tym, usuwając z danych treningowych wszystkie dodatkowe przykłady). W tym przypadku było to 747 przykładów każdego.
Następnie, 70% zbioru danych jest używane do treningu, 10% do walidacji, a 20% do testowania.
- Zbiór walidacyjny jest używany podczas fazy treningu do dostosowywania hiperparametrów modelu i podejmowania decyzji dotyczących architektury modelu, skutecznie pomagając zapobiegać przeuczeniu, dostarczając informacji zwrotnej na temat tego, jak model radzi sobie z nieznanymi danymi. Umożliwia to iteracyjne poprawki bez wprowadzania stronniczości w końcowej ocenie.
- Oznacza to, że chociaż dane zawarte w tym zbiorze danych nie są używane bezpośrednio do treningu, są używane do dostrojenia najlepszych hiperparametrów, więc ten zbiór nie może być używany do oceny wydajności modelu, jak zbiór testowy.
- W przeciwieństwie do tego, zbiór testowy jest używany tylko po pełnym wytrenowaniu modelu i zakończeniu wszystkich dostosowań; zapewnia bezstronną ocenę zdolności modelu do generalizacji na nowe, nieznane dane. Ta końcowa ocena na zbiorze testowym daje realistyczne wskazanie, jak model ma się sprawować w rzeczywistych zastosowaniach.
Długość wpisów
Ponieważ przykład treningowy oczekuje wpisów (tekstów e-maili w tym przypadku) o tej samej długości, zdecydowano się, aby każdy wpis był tak duży, jak największy, dodając identyfikatory <|endoftext|>
jako wypełnienie.
Inicjalizacja modelu
Używając otwartych, wstępnie wytrenowanych wag, inicjalizuj model do treningu. Już to zrobiliśmy wcześniej i postępując zgodnie z instrukcjami z https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb, możesz to łatwo zrobić.
Głowa klasyfikacji
W tym konkretnym przykładzie (przewidywanie, czy tekst jest spamem, czy nie) nie interesuje nas dostosowywanie zgodnie z pełnym słownictwem GPT2, ale chcemy, aby nowy model tylko określał, czy e-mail jest spamem (1), czy nie (0). Dlatego zamierzamy zmodyfikować ostatnią warstwę, która podaje prawdopodobieństwa dla tokenów słownictwa, na taką, która podaje tylko prawdopodobieństwa bycia spamem lub nie (więc jakby słownictwo składające się z 2 słów).
# This code modified the final layer with a Linear one with 2 outs
num_classes = 2
model.out_head = torch.nn.Linear(
in_features=BASE_CONFIG["emb_dim"],
out_features=num_classes
)
Parametry do dostrojenia
Aby szybko dostroić model, łatwiej jest nie dostrajać wszystkich parametrów, a jedynie niektóre końcowe. Dzieje się tak, ponieważ wiadomo, że niższe warstwy zazwyczaj uchwycają podstawowe struktury językowe i semantykę. Dlatego dostrojenie tylko ostatnich warstw jest zazwyczaj wystarczające i szybsze.
# This code makes all the parameters of the model unrtainable
for param in model.parameters():
param.requires_grad = False
# Allow to fine tune the last layer in the transformer block
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True
# Allow to fine tune the final layer norm
for param in model.final_norm.parameters():
param.requires_grad = True
Entries to use for training
W poprzednich sekcjach LLM był trenowany, redukując stratę każdego przewidywanego tokena, mimo że prawie wszystkie przewidywane tokeny znajdowały się w zdaniu wejściowym (tylko 1 na końcu był naprawdę przewidywany), aby model lepiej zrozumiał język.
W tym przypadku interesuje nas tylko to, czy model jest spamem, czy nie, więc interesuje nas tylko ostatni przewidywany token. Dlatego konieczne jest zmodyfikowanie naszych wcześniejszych funkcji straty treningowej, aby uwzględniały tylko ten token.
To jest zaimplementowane w https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb jako:
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
model.eval()
correct_predictions, num_examples = 0, 0
if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
with torch.no_grad():
logits = model(input_batch)[:, -1, :] # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)
num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples
def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :] # Logits of last output token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss
Zauważ, że dla każdej partii interesują nas tylko logity ostatniego przewidywanego tokena.
Pełny kod klasyfikacji fine-tune GPT2
Możesz znaleźć cały kod do fine-tuningu GPT2 jako klasyfikatora spamu w https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb