7.1. Fine-Tuning for Classification
Reading time: 5 minutes
What is
Η προσαρμογή είναι η διαδικασία λήψης ενός προεκπαιδευμένου μοντέλου που έχει μάθει γενικά γλωσσικά μοτίβα από τεράστιες ποσότητες δεδομένων και προσαρμογής του για να εκτελεί μια συγκεκριμένη εργασία ή να κατανοεί γλώσσα συγκεκριμένης περιοχής. Αυτό επιτυγχάνεται με τη συνέχιση της εκπαίδευσης του μοντέλου σε ένα μικρότερο, ειδικό για την εργασία σύνολο δεδομένων, επιτρέποντάς του να προσαρμόσει τις παραμέτρους του ώστε να ταιριάζει καλύτερα στις αποχρώσεις των νέων δεδομένων, ενώ αξιοποιεί τη γενική γνώση που έχει ήδη αποκτήσει. Η προσαρμογή επιτρέπει στο μοντέλο να παρέχει πιο ακριβή και σχετικά αποτελέσματα σε εξειδικευμένες εφαρμογές χωρίς την ανάγκη εκπαίδευσης ενός νέου μοντέλου από την αρχή.
tip
Καθώς η προεκπαίδευση ενός LLM που "κατανοεί" το κείμενο είναι αρκετά δαπανηρή, είναι συνήθως πιο εύκολο και φθηνότερο να προσαρμόσουμε ανοιχτού κώδικα προεκπαιδευμένα μοντέλα για να εκτελούν μια συγκεκριμένη εργασία που θέλουμε να εκτελούν.
tip
Ο στόχος αυτής της ενότητας είναι να δείξει πώς να προσαρμόσουμε ένα ήδη προεκπαιδευμένο μοντέλο, έτσι ώστε αντί να δημιουργεί νέο κείμενο, το LLM να επιλέγει να δώσει τις πιθανότητες του δεδομένου κειμένου να κατηγοριοποιηθεί σε κάθε μία από τις δεδομένες κατηγορίες (όπως αν ένα κείμενο είναι spam ή όχι).
Preparing the data set
Data set size
Φυσικά, για να προσαρμόσετε ένα μοντέλο χρειάζεστε κάποια δομημένα δεδομένα για να εξειδικεύσετε το LLM σας. Στο παράδειγμα που προτείνεται στο https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb, το GPT2 προσαρμόζεται για να ανιχνεύει αν ένα email είναι spam ή όχι χρησιμοποιώντας τα δεδομένα από https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip.
Αυτό το σύνολο δεδομένων περιέχει πολύ περισσότερα παραδείγματα "όχι spam" από "spam", επομένως το βιβλίο προτείνει να χρησιμοποιήσετε μόνο τόσα παραδείγματα "όχι spam" όσο και "spam" (αφαιρώντας έτσι από τα δεδομένα εκπαίδευσης όλα τα επιπλέον παραδείγματα). Σε αυτή την περίπτωση, αυτό ήταν 747 παραδείγματα από το καθένα.
Στη συνέχεια, το 70% του συνόλου δεδομένων χρησιμοποιείται για εκπαίδευση, το 10% για επικύρωση και το 20% για δοκιμή.
- Το σύνολο επικύρωσης χρησιμοποιείται κατά τη διάρκεια της φάσης εκπαίδευσης για να προσαρμόσει τις υπερπαραμέτρους του μοντέλου και να λάβει αποφάσεις σχετικά με την αρχιτεκτονική του μοντέλου, βοηθώντας αποτελεσματικά στην πρόληψη της υπερβολικής προσαρμογής παρέχοντας ανατροφοδότηση σχετικά με την απόδοση του μοντέλου σε αόρατα δεδομένα. Επιτρέπει επαναληπτικές βελτιώσεις χωρίς να προκαλεί προκατάληψη στην τελική αξιολόγηση.
- Αυτό σημαίνει ότι αν και τα δεδομένα που περιλαμβάνονται σε αυτό το σύνολο δεδομένων δεν χρησιμοποιούνται άμεσα για την εκπαίδευση, χρησιμοποιούνται για να ρυθμίσουν τις καλύτερες υπερπαραμέτρους, οπότε αυτό το σύνολο δεν μπορεί να χρησιμοποιηθεί για την αξιολόγηση της απόδοσης του μοντέλου όπως το σύνολο δοκιμής.
- Αντίθετα, το σύνολο δοκιμής χρησιμοποιείται μόνο μετά την πλήρη εκπαίδευση του μοντέλου και την ολοκλήρωση όλων των ρυθμίσεων; παρέχει μια αμερόληπτη εκτίμηση της ικανότητας του μοντέλου να γενικεύει σε νέα, αόρατα δεδομένα. Αυτή η τελική αξιολόγηση στο σύνολο δοκιμής δίνει μια ρεαλιστική ένδειξη του πώς αναμένεται να αποδώσει το μοντέλο σε πραγματικές εφαρμογές.
Entries length
Καθώς το παράδειγμα εκπαίδευσης αναμένει καταχωρίσεις (κείμενα email σε αυτή την περίπτωση) της ίδιας μήκους, αποφασίστηκε να γίνει κάθε καταχώριση όσο μεγάλη είναι η μεγαλύτερη προσθέτοντας τα ids του <|endoftext|>
ως padding.
Initialize the model
Χρησιμοποιώντας τα προεκπαιδευμένα βάρη ανοιχτού κώδικα, αρχικοποιήστε το μοντέλο για εκπαίδευση. Έχουμε ήδη κάνει αυτό πριν και ακολουθώντας τις οδηγίες του https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb μπορείτε εύκολα να το κάνετε.
Classification head
Σε αυτό το συγκεκριμένο παράδειγμα (προβλέποντας αν ένα κείμενο είναι spam ή όχι), δεν μας ενδιαφέρει να προσαρμόσουμε σύμφωνα με το πλήρες λεξιλόγιο του GPT2 αλλά θέλουμε μόνο το νέο μοντέλο να λέει αν το email είναι spam (1) ή όχι (0). Επομένως, θα τροποποιήσουμε την τελική στρώση που δίνει τις πιθανότητες ανά token του λεξιλογίου για μία που δίνει μόνο τις πιθανότητες του να είναι spam ή όχι (οπότε όπως ένα λεξιλόγιο 2 λέξεων).
# This code modified the final layer with a Linear one with 2 outs
num_classes = 2
model.out_head = torch.nn.Linear(
in_features=BASE_CONFIG["emb_dim"],
out_features=num_classes
)
Παράμετροι προς ρύθμιση
Για να γίνει η ρύθμιση γρήγορα, είναι πιο εύκολο να μην ρυθμιστούν όλες οι παράμετροι αλλά μόνο μερικές τελικές. Αυτό συμβαίνει επειδή είναι γνωστό ότι τα κατώτερα επίπεδα γενικά καταγράφουν βασικές γλωσσικές δομές και εφαρμοσμένη σημασιολογία. Έτσι, απλώς η ρύθμιση των τελευταίων επιπέδων είναι συνήθως αρκετή και πιο γρήγορη.
# This code makes all the parameters of the model unrtainable
for param in model.parameters():
param.requires_grad = False
# Allow to fine tune the last layer in the transformer block
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True
# Allow to fine tune the final layer norm
for param in model.final_norm.parameters():
param.requires_grad = True
Εγγραφές για χρήση στην εκπαίδευση
Στις προηγούμενες ενότητες, το LLM εκπαιδεύτηκε μειώνοντας την απώλεια κάθε προβλεπόμενου τόνου, αν και σχεδόν όλοι οι προβλεπόμενοι τόνοι ήταν στην είσοδο της πρότασης (μόνο 1 στο τέλος ήταν πραγματικά προβλεπόμενος) προκειμένου το μοντέλο να κατανοήσει καλύτερα τη γλώσσα.
Σε αυτή την περίπτωση, μας ενδιαφέρει μόνο το μοντέλο να είναι ικανό να προβλέψει αν το μοντέλο είναι spam ή όχι, οπότε μας ενδιαφέρει μόνο ο τελευταίος τόνος που προβλέπεται. Επομένως, είναι απαραίτητο να τροποποιήσουμε τις προηγούμενες συναρτήσεις απώλειας εκπαίδευσης μας ώστε να λαμβάνουν υπόψη μόνο αυτόν τον τόνο.
Αυτό υλοποιείται στο https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb ως:
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
model.eval()
correct_predictions, num_examples = 0, 0
if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
with torch.no_grad():
logits = model(input_batch)[:, -1, :] # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)
num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples
def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :] # Logits of last output token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss
Σημειώστε ότι για κάθε παρτίδα μας ενδιαφέρει μόνο οι logits του τελευταίου προβλεπόμενου token.
Πλήρης κώδικας ταξινόμησης fine-tune GPT2
Μπορείτε να βρείτε όλο τον κώδικα για να κάνετε fine-tune το GPT2 ώστε να είναι ταξινομητής spam στο https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb