7.1. Fyn-afstemming vir Kategorisering
Reading time: 6 minutes
{{#include /banners/hacktricks-training.md}}
Wat is
Fyn-afstemming is die proses om 'n vooraf-geleerde model te neem wat algemene taalpatrone uit groot hoeveelhede data geleer het en dit aan te pas om 'n spesifieke taak uit te voer of om domein-spesifieke taal te verstaan. Dit word bereik deur die opleiding van die model voort te sit op 'n kleiner, taak-spesifieke dataset, wat dit toelaat om sy parameters aan te pas om beter by die nuanses van die nuwe data te pas terwyl dit die breë kennis wat dit reeds verwerf het, benut. Fyn-afstemming stel die model in staat om meer akkurate en relevante resultate in gespesialiseerde toepassings te lewer sonder die behoefte om 'n nuwe model van nuuts af op te lei.
tip
Aangesien dit redelik duur is om 'n LLM wat die teks "begryp" vooraf te leer, is dit gewoonlik makliker en goedkoper om oopbron vooraf-geleerde modelle fyn-af te stem om 'n spesifieke taak uit te voer wat ons wil hê dit moet uitvoer.
tip
Die doel van hierdie afdeling is om te wys hoe om 'n reeds vooraf-geleerde model fyn-af te stem sodat die LLM, in plaas daarvan om nuwe teks te genereer, die waarskynlikhede van die gegewe teks wat in elkeen van die gegewe kategorieë gekategoriseer word (soos of 'n teks spam is of nie) sal gee.
Voorbereiding van die dataset
Dataset grootte
Natuurlik, om 'n model fyn-af te stem, benodig jy 'n paar gestruktureerde data om jou LLM te spesialiseer. In die voorbeeld wat voorgestel word in https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb, word GPT2 fyn-afgestem om te detecteer of 'n e-pos spam is of nie met die data van https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip.
Hierdie dataset bevat baie meer voorbeelde van "nie spam" as van "spam", daarom stel die boek voor om slegs soveel voorbeelde van "nie spam" as van "spam" te gebruik (daarom, om al die ekstra voorbeelde uit die opleidingsdata te verwyder). In hierdie geval was dit 747 voorbeelde van elkeen.
Toe, 70% van die dataset word gebruik vir opleiding, 10% vir validasie en 20% vir toetsing.
- Die validasieset word tydens die opleidingsfase gebruik om die model se hiperparameters fyn-af te stem en besluite te neem oor modelargitektuur, wat effektief help om oorpassing te voorkom deur terugvoer te gee oor hoe die model presteer op ongekende data. Dit stel iteratiewe verbeterings in staat sonder om die finale evaluasie te bevoordeel.
- Dit beteken dat alhoewel die data wat in hierdie dataset ingesluit is nie direk vir die opleiding gebruik word nie, dit gebruik word om die beste hiperparameters te stem, so hierdie stel kan nie gebruik word om die model se prestasie te evalueer soos die toetsstel nie.
- In teenstelling hiermee, die toetsstel word slegs na die model ten volle opgelei is en al die aanpassings voltooi is, gebruik; dit bied 'n onbevooroordeelde beoordeling van die model se vermoë om te generaliseer na nuwe, ongekende data. Hierdie finale evaluasie op die toetsstel gee 'n realistiese aanduiding van hoe die model verwag word om in werklike toepassings te presteer.
Inskrywings lengte
Aangesien die opleidingsvoorbeeld inskrywings (e-pos teks in hierdie geval) van dieselfde lengte verwag, is daar besluit om elke inskrywing so groot te maak soos die grootste een deur die id's van <|endoftext|>
as opvulling by te voeg.
Begin die model
Gebruik die oopbron vooraf-geleerde gewigte om die model te begin oplei. Ons het dit al voorheen gedoen en volg die instruksies van https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb kan jy dit maklik doen.
Kategorisering kop
In hierdie spesifieke voorbeeld (voorspel of 'n teks spam is of nie), is ons nie geïnteresseerd in fyn-afstemming volgens die volledige woordeskat van GPT2 nie, maar ons wil net hê die nuwe model moet sê of die e-pos spam is (1) of nie (0). Daarom gaan ons die laaste laag wat die waarskynlikhede per token van die woordeskat gee, aanpas vir een wat slegs die waarskynlikhede van spam of nie spam gee (soos 'n woordeskat van 2 woorde).
# This code modified the final layer with a Linear one with 2 outs
num_classes = 2
model.out_head = torch.nn.Linear(
in_features=BASE_CONFIG["emb_dim"],
out_features=num_classes
)
Parameters om te stel
Om vinnig te fynstel is dit makliker om nie al die parameters te fynstel nie, maar slegs 'n paar finale. Dit is omdat dit bekend is dat die laer lae oor die algemeen basiese taalstrukture en toepaslike semantiek vasvang. So, net die laaste lae fynstel is gewoonlik genoeg en vinniger.
# This code makes all the parameters of the model unrtainable
for param in model.parameters():
param.requires_grad = False
# Allow to fine tune the last layer in the transformer block
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True
# Allow to fine tune the final layer norm
for param in model.final_norm.parameters():
param.requires_grad = True
Entries to use for training
In vorige afdelings is die LLM opgelei deur die verlies van elke voorspelde token te verminder, alhoewel byna al die voorspelde tokens in die invoer sin was (slegs 1 aan die einde was werklik voorspel) sodat die model die taal beter kan verstaan.
In hierdie geval is ons net geïnteresseerd in die model se vermoë om te voorspel of die model spam is of nie, so ons is net geïnteresseerd in die laaste voorspelde token. Daarom is dit nodig om ons vorige opleidingsverlies funksies te wysig om slegs daardie token in ag te neem.
Dit is geïmplementeer in https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb as:
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
model.eval()
correct_predictions, num_examples = 0, 0
if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
with torch.no_grad():
logits = model(input_batch)[:, -1, :] # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)
num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples
def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :] # Logits of last output token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss
Let op hoe ons vir elke bondel slegs belangstel in die logits van die laaste token wat voorspel is.
Volledige GPT2 fyn-afstemming klassifikasie kode
Jy kan al die kode vind om GPT2 te fyn-afstem om 'n spam klassifiseerder te wees in https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb
Verwysings
{{#include /banners/hacktricks-training.md}}