7.1. Fine-Tuning for Classification
Reading time: 7 minutes
tip
Jifunze na fanya mazoezi ya AWS Hacking:HackTricks Training AWS Red Team Expert (ARTE)
Jifunze na fanya mazoezi ya GCP Hacking: HackTricks Training GCP Red Team Expert (GRTE)
Jifunze na fanya mazoezi ya Azure Hacking:
HackTricks Training Azure Red Team Expert (AzRTE)
Support HackTricks
- Angalia mpango wa usajili!
- Jiunge na 💬 kikundi cha Discord au kikundi cha telegram au tufuatilie kwenye Twitter 🐦 @hacktricks_live.
- Shiriki mbinu za hacking kwa kuwasilisha PRs kwa HackTricks na HackTricks Cloud repos za github.
What is
Fine-tuning ni mchakato wa kuchukua modeli iliyofundishwa awali ambayo imejifunza mifumo ya lugha ya jumla kutoka kwa kiasi kikubwa cha data na kuirekebisha ili ifanye kazi maalum au kuelewa lugha maalum ya eneo. Hii inafikiwa kwa kuendelea na mafunzo ya modeli kwenye seti ndogo ya data maalum ya kazi, ikiruhusu kurekebisha vigezo vyake ili kufaa zaidi nuances za data mpya huku ikitumia maarifa makubwa ambayo tayari imepata. Fine-tuning inaruhusu modeli kutoa matokeo sahihi na yanayohusiana zaidi katika matumizi maalum bila haja ya kufundisha modeli mpya kutoka mwanzo.
tip
Kwa kuwa kufundisha awali LLM ambayo "inaelewa" maandiko ni ghali sana, mara nyingi ni rahisi na nafuu kurekebisha modeli za wazi zilizofundishwa awali ili kufanya kazi maalum tunayotaka ifanye.
tip
Lengo la sehemu hii ni kuonyesha jinsi ya kurekebisha modeli iliyofundishwa awali ili badala ya kuzalisha maandiko mapya, LLM itachagua kutoa uwezekano wa maandiko yaliyotolewa kuainishwa katika kila moja ya makundi yaliyotolewa (kama maandiko ni spam au la).
Preparing the data set
Data set size
Bila shaka, ili kurekebisha modeli unahitaji data iliyopangwa ili kutumia kupecialize LLM yako. Katika mfano ulioanzishwa katika https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb, GPT2 inarekebishwa kugundua kama barua pepe ni spam au la kwa kutumia data kutoka https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip.
Seti hii ya data ina mifano mingi zaidi ya "sio spam" kuliko "spam", kwa hivyo kitabu kinapendekeza kutumia mifano ya "sio spam" sawa na ile ya "spam" (hivyo, kuondoa mifano yote ya ziada kutoka kwa data ya mafunzo). Katika kesi hii, hii ilikuwa mifano 747 ya kila mmoja.
Kisha, 70% ya seti ya data inatumika kwa mafunzo, 10% kwa uthibitisho na 20% kwa kujaribu.
- Seti ya uthibitisho inatumika wakati wa awamu ya mafunzo ili kurekebisha vigezo vya hyper vya modeli na kufanya maamuzi kuhusu usanifu wa modeli, kwa ufanisi kusaidia kuzuia overfitting kwa kutoa mrejesho juu ya jinsi modeli inavyofanya kwenye data isiyoonekana. Inaruhusu maboresho ya kurudi nyuma bila kupendelea tathmini ya mwisho.
- Hii ina maana kwamba ingawa data iliyojumuishwa katika seti hii ya data haitumiki kwa mafunzo moja kwa moja, inatumika kurekebisha vigezo bora vya hyper, hivyo seti hii haiwezi kutumika kutathmini utendaji wa modeli kama ile ya kujaribu.
- Kinyume chake, seti ya kujaribu inatumika tu baada ya modeli kufundishwa kikamilifu na marekebisho yote kukamilika; inatoa tathmini isiyo na upendeleo ya uwezo wa modeli kuweza kujumlisha kwa data mpya, isiyoonekana. Tathmini hii ya mwisho kwenye seti ya kujaribu inatoa dalili halisi ya jinsi modeli inavyotarajiwa kufanya katika matumizi halisi.
Entries length
Kama mfano wa mafunzo unavyotarajia entries (maandishi ya barua pepe katika kesi hii) za urefu sawa, iliamuliwa kufanya kila entry kuwa kubwa kama ile kubwa zaidi kwa kuongeza vitambulisho vya <|endoftext|>
kama padding.
Initialize the model
Kwa kutumia uzito wa wazi wa awali, anza modeli kwa mafunzo. Tayari tumefanya hivi kabla na kufuata maelekezo ya https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb unaweza kufanya hivyo kwa urahisi.
Classification head
Katika mfano huu maalum (kubashiri kama maandiko ni spam au la), hatuhitaji kurekebisha kulingana na msamiati kamili wa GPT2 bali tunataka tu modeli mpya kusema kama barua pepe ni spam (1) au la (0). Hivyo, tunakwenda kubadilisha safu ya mwisho ambayo inatoa uwezekano kwa kila token ya msamiati kwa ile inayotoa tu uwezekano wa kuwa spam au la (hivyo kama msamiati wa maneno 2).
# This code modified the final layer with a Linear one with 2 outs
num_classes = 2
model.out_head = torch.nn.Linear(
in_features=BASE_CONFIG["emb_dim"],
out_features=num_classes
)
Parameters to tune
Ili kuboresha haraka ni rahisi kutokuboresha vigezo vyote bali baadhi ya vigezo vya mwisho tu. Hii ni kwa sababu inajulikana kwamba tabaka za chini kwa ujumla zinashughulikia muundo wa lugha wa kimsingi na maana zinazotumika. Hivyo, tu kuboresha tabaka za mwisho mara nyingi inatosha na ni ya haraka zaidi.
# This code makes all the parameters of the model unrtainable
for param in model.parameters():
param.requires_grad = False
# Allow to fine tune the last layer in the transformer block
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True
# Allow to fine tune the final layer norm
for param in model.final_norm.parameters():
param.requires_grad = True
Entries to use for training
Katika sehemu za awali, LLM ilifundishwa kupunguza hasara ya kila token iliyotabiriwa, ingawa karibu token zote zilizotabiriwa zilikuwa katika sentensi ya ingizo (moja tu mwishoni ilitabiriwa kwa kweli) ili modeli iweze kuelewa lugha vizuri zaidi.
Katika kesi hii, tunajali tu kuhusu uwezo wa modeli kutabiri ikiwa modeli ni spam au la, hivyo tunajali tu kuhusu token ya mwisho iliyotabiriwa. Kwa hiyo, inahitajika kubadilisha kazi zetu za hasara za mafunzo ya awali ili kuchukua tu token hiyo katika akaunti.
Hii imewekwa katika https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb kama:
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
model.eval()
correct_predictions, num_examples = 0, 0
if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
with torch.no_grad():
logits = model(input_batch)[:, -1, :] # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)
num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples
def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :] # Logits of last output token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss
Note jinsi kwa kila kundi tunavokuwa na nia tu na logits za token ya mwisho iliyotabiriwa.
Kamili GPT2 fine-tune classification code
Unaweza kupata msimbo wote wa fine-tune GPT2 kuwa mchanganuzi wa spam katika https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb
Marejeo
tip
Jifunze na fanya mazoezi ya AWS Hacking:HackTricks Training AWS Red Team Expert (ARTE)
Jifunze na fanya mazoezi ya GCP Hacking: HackTricks Training GCP Red Team Expert (GRTE)
Jifunze na fanya mazoezi ya Azure Hacking:
HackTricks Training Azure Red Team Expert (AzRTE)
Support HackTricks
- Angalia mpango wa usajili!
- Jiunge na 💬 kikundi cha Discord au kikundi cha telegram au tufuatilie kwenye Twitter 🐦 @hacktricks_live.
- Shiriki mbinu za hacking kwa kuwasilisha PRs kwa HackTricks na HackTricks Cloud repos za github.