7.1. Fine-Tuning for Classification

Reading time: 5 minutes

What is

Fine-tuning ni mchakato wa kuchukua modeli iliyofundishwa awali ambayo imejifunza mifumo ya lugha ya jumla kutoka kwa kiasi kikubwa cha data na kuirekebisha ili ifanye kazi maalum au kuelewa lugha maalum ya eneo. Hii inafikiwa kwa kuendelea na mafunzo ya modeli kwenye seti ndogo ya data maalum ya kazi, ikiruhusu kurekebisha vigezo vyake ili kufaa zaidi nuances za data mpya huku ikitumia maarifa mapana ambayo tayari imepata. Fine-tuning inaruhusu modeli kutoa matokeo sahihi na yanayohusiana zaidi katika matumizi maalum bila haja ya kufundisha modeli mpya kutoka mwanzo.

tip

Kwa kuwa kufundisha awali LLM ambayo "inaelewa" maandiko ni ghali sana, mara nyingi ni rahisi na nafuu kurekebisha modeli za wazi zilizofundishwa awali ili kufanya kazi maalum tunayotaka ifanye.

tip

Lengo la sehemu hii ni kuonyesha jinsi ya kurekebisha modeli iliyofundishwa awali ili badala ya kuzalisha maandiko mapya, LLM itachagua kutoa uwezekano wa maandiko yaliyotolewa kuainishwa katika kila moja ya makundi yaliyotolewa (kama maandiko ni spam au la).

Preparing the data set

Data set size

Bila shaka, ili kurekebisha modeli unahitaji data iliyopangwa ili kutumia kupecialize LLM yako. Katika mfano ulioanzishwa katika https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb, GPT2 inarekebishwa kugundua kama barua pepe ni spam au la kwa kutumia data kutoka https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip.

Seti hii ya data ina mifano mingi zaidi ya "sio spam" kuliko "spam", kwa hivyo kitabu kinapendekeza kutumia mifano ya "sio spam" sawa na ile ya "spam" (hivyo, kuondoa mifano yote ya ziada kutoka kwa data ya mafunzo). Katika kesi hii, hii ilikuwa mifano 747 ya kila mmoja.

Kisha, 70% ya seti ya data inatumika kwa mafunzo, 10% kwa uthibitisho na 20% kwa kujaribu.

  • Seti ya uthibitisho inatumika wakati wa awamu ya mafunzo ili kurekebisha vigezo vya hyper vya modeli na kufanya maamuzi kuhusu usanifu wa modeli, kwa ufanisi kusaidia kuzuia overfitting kwa kutoa mrejesho juu ya jinsi modeli inavyofanya kwenye data isiyoonekana. Inaruhusu maboresho ya kurudi nyuma bila kupendelea tathmini ya mwisho.
  • Hii inamaanisha kwamba ingawa data iliyojumuishwa katika seti hii ya data haitumiki kwa mafunzo moja kwa moja, inatumika kurekebisha vigezo bora vya hyper, hivyo seti hii haiwezi kutumika kutathmini utendaji wa modeli kama seti ya majaribio.
  • Kinyume chake, seti ya majaribio inatumika tu baada ya modeli kufundishwa kikamilifu na marekebisho yote kukamilika; inatoa tathmini isiyo na upendeleo ya uwezo wa modeli kuweza kujumlisha kwa data mpya, isiyoonekana. Tathmini hii ya mwisho kwenye seti ya majaribio inatoa dalili halisi ya jinsi modeli inavyotarajiwa kufanya katika matumizi halisi.

Entries length

Kama mfano wa mafunzo unavyotarajia entries (maandishi ya barua pepe katika kesi hii) za urefu sawa, iliamuliwa kufanya kila entry kuwa kubwa kama ile kubwa zaidi kwa kuongeza vitambulisho vya <|endoftext|> kama padding.

Initialize the model

Kwa kutumia uzito wa wazi wa awali, anza modeli kwa mafunzo. Tayari tumefanya hivi kabla na kufuata maelekezo ya https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb unaweza kufanya hivyo kwa urahisi.

Classification head

Katika mfano huu maalum (kubashiri kama maandiko ni spam au la), hatuhitaji kurekebisha kulingana na msamiati kamili wa GPT2 bali tunataka tu modeli mpya kusema kama barua pepe ni spam (1) au la (0). Hivyo, tunakwenda kubadilisha safu ya mwisho ambayo inatoa uwezekano kwa kila token ya msamiati kwa ile inayotoa tu uwezekano wa kuwa spam au la (hivyo kama msamiati wa maneno 2).

python
# This code modified the final layer with a Linear one with 2 outs
num_classes = 2
model.out_head = torch.nn.Linear(

in_features=BASE_CONFIG["emb_dim"],

out_features=num_classes
)

Parameters to tune

Ili kuboresha haraka ni rahisi kutokuboresha vigezo vyote bali baadhi ya vigezo vya mwisho tu. Hii ni kwa sababu inajulikana kwamba tabaka za chini kwa ujumla zinashughulikia muundo wa lugha wa kimsingi na maana zinazotumika. Hivyo, tu kuboresha tabaka za mwisho mara nyingi inatosha na ni ya haraka zaidi.

python
# This code makes all the parameters of the model unrtainable
for param in model.parameters():
param.requires_grad = False

# Allow to fine tune the last layer in the transformer block
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True

# Allow to fine tune the final layer norm
for param in model.final_norm.parameters():

param.requires_grad = True

Entries to use for training

Katika sehemu za awali, LLM ilifundishwa kupunguza hasara ya kila token iliyotabiriwa, ingawa karibu token zote zilizotabiriwa zilikuwa katika sentensi ya ingizo (moja tu mwishoni ilitabiriwa kwa kweli) ili mfano uweze kuelewa lugha vizuri zaidi.

Katika kesi hii, tunajali tu kuhusu mfano kuwa na uwezo wa kutabiri ikiwa mfano ni spam au la, hivyo tunajali tu kuhusu token ya mwisho iliyotabiriwa. Kwa hiyo, inahitajika kubadilisha kazi zetu za hasara za mafunzo ya awali ili kuchukua tu token hiyo katika akaunti.

Hii imewekwa katika https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb kama:

python
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
model.eval()
correct_predictions, num_examples = 0, 0

if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)

with torch.no_grad():
logits = model(input_batch)[:, -1, :]  # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)

num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples


def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :]  # Logits of last output token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss

Note how for each batch we are only interested in the logits of the last token predicted.

Complete GPT2 fine-tune classification code

You can find all the code to fine-tune GPT2 to be a spam classifier in https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb

References