7.1. Fine-Tuning for Classification

Tip

Learn & practice AWS Hacking: HackTricks Training AWS Red Team Expert (ARTE)
Learn & practice GCP Hacking: HackTricks Training GCP Red Team Expert (GRTE)
Learn & practice Azure Hacking: HackTricks Training Azure Red Team Expert (AzRTE)

Support HackTricks

What is fine-tuning

Fine-tuning is the process of taking a pre-trained model, which has learned general language patterns from vast amounts of data, and adapting it to perform a specific task or to understand domain-specific language. This is done by continuing the model's training on a smaller, task-specific dataset. The model adjusts its parameters to better fit the nuances of the new data while still leveraging the broad knowledge it has already acquired. As a result, fine-tuning lets the model deliver more accurate and relevant results in specialized applications without training a new model from scratch.

Tip

Pre-training an LLM that "understands" text is quite expensive. It is therefore usually easier and cheaper to fine-tune an open-source pre-trained model to perform the specific task we want.

Tip

The goal of this section is to show how to fine-tune an already pre-trained model. Instead of generating new text, the LLM will output the probability that the given text belongs to each of the given categories (for example, whether a text is spam or not).

Preparing the data set

Data set size

Of course, to fine-tune a model you need structured data with which to specialise your LLM. The example proposed in https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb fine-tunes GPT2 to detect whether an SMS message is spam or not. It uses the SMS Spam Collection data from https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip.

This dataset contains many more "not spam" examples than "spam" examples. The book therefore suggests using only as many "not spam" examples as there are "spam" examples, removing all the extra "not spam" examples from the training data. In this case, that left 747 examples of each class.
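Below is a minimal sketch of this undersampling step in plain Python. It assumes each row is a (label, text) tuple using the "ham"/"spam" labels of the SMS Spam Collection file. The helper name `balance_dataset` and the toy rows are hypothetical, not taken from the notebook:

```python
import random


def balance_dataset(rows, seed=123):
    """Undersample the majority "ham" class so both classes have the same count.

    rows: list of (label, text) tuples with label "spam" or "ham".
    Assumes there are at least as many "ham" rows as "spam" rows.
    """
    spam = [r for r in rows if r[0] == "spam"]
    ham = [r for r in rows if r[0] == "ham"]
    rng = random.Random(seed)  # fixed seed so the kept subset is reproducible
    ham_subset = rng.sample(ham, len(spam))  # keep only as many "ham" as "spam"
    balanced = ham_subset + spam
    rng.shuffle(balanced)  # mix the two classes together
    return balanced


# Toy data: 5 "ham" messages and 2 "spam" messages
rows = [("ham", f"hi {i}") for i in range(5)] + [("spam", "WIN $$$"), ("spam", "Free prize")]
balanced = balance_dataset(rows)
print(sum(r[0] == "spam" for r in balanced), sum(r[0] == "ham" for r in balanced))  # prints "2 2"
```

The fixed seed makes the kept subset of "ham" messages identical between runs.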

Then **70%** of the dataset is used for training, **10%** for validation and **20%** for testing.

  • The validation set is used during the training phase to tune the model's hyperparameters and to make decisions about the model architecture. It gives feedback on how the model performs on unseen data, which helps prevent overfitting and allows iterative improvements without biasing the final evaluation.
  • This means that the data in this set is not used directly for training, but it is used to pick the best hyperparameters. For that reason, it cannot be used like the test set to evaluate the model's performance.
  • In contrast, the test set is used only after the model has been fully trained and all adjustments are complete. It provides an unbiased assessment of the model's ability to generalize to new, unseen data, and so gives a realistic indication of how the model is expected to perform in real-world applications.
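The 70/10/20 split can be sketched as follows, assuming the balanced rows fit in a Python list. `split_dataset` is a hypothetical helper. With 747 + 747 = 1494 balanced examples, truncating each fraction to an integer gives:

```python
import random


def split_dataset(rows, train_frac=0.7, val_frac=0.1, seed=123):
    """Shuffle and split rows into train/validation/test lists (70/10/20 by default)."""
    rows = list(rows)
    random.Random(seed).shuffle(rows)  # avoid splits that are grouped by class
    train_end = int(len(rows) * train_frac)
    val_end = train_end + int(len(rows) * val_frac)
    # Whatever remains after the train and validation parts (~20%) is the test set
    return rows[:train_end], rows[train_end:val_end], rows[val_end:]


train, val, test = split_dataset(range(1494))  # 747 spam + 747 not spam
print(len(train), len(val), len(test))  # prints "1045 149 300"
```

Shuffling before splitting matters when the input rows are ordered by class. Without it, one split could end up containing mostly a single label.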

Entries length

Training batches require every entry (here, the SMS texts) to have the same length. It was therefore decided to make every entry as long as the longest one, appending the token ID of <|endoftext|> as padding.
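Here is a sketch of this padding step. It assumes the texts have already been tokenized into GPT-2 token-ID lists, and the IDs in the example are hypothetical. 50256 is the ID of `<|endoftext|>` in the GPT-2 tokenizer:

```python
PAD_TOKEN_ID = 50256  # GPT-2's token ID for <|endoftext|>


def pad_to_longest(encoded_texts, pad_token_id=PAD_TOKEN_ID):
    """Right-pad every token-ID list to the length of the longest one."""
    max_length = max(len(ids) for ids in encoded_texts)
    return [ids + [pad_token_id] * (max_length - len(ids)) for ids in encoded_texts]


# Hypothetical, already-tokenized messages of different lengths
batch = [[15496, 995], [40, 716, 257, 6097], [17250]]
padded = pad_to_longest(batch)
print(padded)
# prints [[15496, 995, 50256, 50256], [40, 716, 257, 6097], [17250, 50256, 50256, 50256]]
```

A real pipeline would usually also truncate entries longer than the model's context length (1024 tokens for GPT-2).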

Initialize the model

Initialize the model with the open-source pre-trained weights before training it. We have already done this before, and it is easy to do by following the instructions in https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb.

Classification head

In this specific example (predicting whether a text is spam or not), we are not interested in fine-tuning over the complete GPT2 vocabulary. We only want the new model to say whether the message is spam (1) or not (0). We therefore replace the final layer, which produces one logit per vocabulary token, with one that produces just two logits: one for "not spam" and one for "spam". This works like a vocabulary of only 2 words, and a softmax over the two logits gives the class probabilities.

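import torch

# Replace the final output layer (one logit per vocabulary token) with a
# Linear layer that has only 2 outputs: not spam (0) and spam (1).
# `model` and `BASE_CONFIG` come from the pre-trained GPT2 model initialized earlier.
num_classes = 2
model.out_head = torch.nn.Linear(
    in_features=BASE_CONFIG["emb_dim"],
    out_features=num_classes
)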
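The sketch below shows what the replacement changes. It applies a freshly created 2-output head to random stand-in hidden states. The sizes are assumptions: `emb_dim = 768` as in GPT-2 small, with arbitrary batch and sequence lengths. Each position now yields 2 logits instead of one per vocabulary entry (50257 for GPT-2), and a softmax over the last position's logits gives the per-class probabilities:

```python
import torch

torch.manual_seed(123)

# Assumed sizes: GPT-2 small uses emb_dim = 768; batch and sequence sizes are arbitrary
emb_dim, num_classes = 768, 2
batch_size, seq_len = 2, 6

out_head = torch.nn.Linear(in_features=emb_dim, out_features=num_classes)

# Random stand-in for the output of the final layer norm: one vector per token position
hidden_states = torch.randn(batch_size, seq_len, emb_dim)

logits = out_head(hidden_states)                  # shape: (batch_size, seq_len, 2)
last_token_logits = logits[:, -1, :]              # shape: (batch_size, 2)
probs = torch.softmax(last_token_logits, dim=-1)  # each row sums to 1
predicted = torch.argmax(probs, dim=-1)           # 0 = not spam, 1 = spam

print(logits.shape, probs.shape, predicted.shape)
```

The head is untrained here, so the predictions are meaningless. Only the shapes matter for this illustration.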

Parameters to tune

To fine-tune quickly, it is easier to update only some of the final parameters instead of all of them. This works because the lower layers are widely observed to capture basic language structures and semantics that apply broadly. Fine-tuning only the last layers, together with the new classification head, is therefore usually enough and much faster.

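# Freeze all the parameters of the model (make them untrainable)
for param in model.parameters():
    param.requires_grad = False

# Allow fine-tuning of the last transformer block
for param in model.trf_blocks[-1].parameters():
    param.requires_grad = True

# Allow fine-tuning of the final layer norm
for param in model.final_norm.parameters():
    param.requires_grad = True

# Keep the new classification head trainable too: if it was replaced before
# this code ran, the freezing loop above also froze it
for param in model.out_head.parameters():
    param.requires_grad = True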
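You can check the effect of the freezing by counting trainable parameters. The sketch below uses a hypothetical `ToyGPT` module that only reproduces the attribute names used above (`trf_blocks`, `final_norm`, `out_head`). It is not the notebook's GPT model. It also unfreezes the new `out_head`, because a head swapped in before the freeze loop would otherwise stay frozen:

```python
import torch
import torch.nn as nn


class ToyGPT(nn.Module):
    """Hypothetical stand-in that only mirrors the attribute names of the GPT model."""

    def __init__(self, emb_dim=16, n_layers=3, vocab_size=100):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, emb_dim)
        self.trf_blocks = nn.Sequential(*[nn.Linear(emb_dim, emb_dim) for _ in range(n_layers)])
        self.final_norm = nn.LayerNorm(emb_dim)
        self.out_head = nn.Linear(emb_dim, vocab_size)

    def forward(self, idx):
        x = self.trf_blocks(self.tok_emb(idx))
        return self.out_head(self.final_norm(x))


def count_trainable(module):
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


model = ToyGPT()
model.out_head = nn.Linear(16, 2)  # new 2-class head, swapped in before freezing

for param in model.parameters():  # freeze everything (including the new head)
    param.requires_grad = False
for module in (model.trf_blocks[-1], model.final_norm, model.out_head):
    for param in module.parameters():  # unfreeze only the last layers and the head
        param.requires_grad = True

total = sum(p.numel() for p in model.parameters())
print(count_trainable(model), "of", total, "parameters are trainable")  # prints "338 of 2482 ..."
```

In this toy model, 338 of 2482 parameters stay trainable. GPT-2 small has 12 transformer blocks and keeps only the last one trainable, so its trainable share is smaller still.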

Entries to use for training

In previous sections, the LLM was trained by reducing the loss of every predicted token. Almost all of those target tokens were already present in the input sentence, because the targets are the inputs shifted by one position, so only the last one was genuinely new. This was done so that the model would understand the language better.

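In this case, we only care about the model predicting whether the message is spam or not, so we consider only the last predicted token. Because of the causal attention mask, the last token is the only position that can attend to every other token in the input, so its output reflects the whole message. The previous training loss functions therefore need to be modified to take only that token into account.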
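A small hypothetical helper can illustrate the causal-mask visibility. Row i lists the positions that token i may attend to, and only the last row covers the whole message:

```python
def causal_visibility(num_tokens):
    """Row i lists the token positions that position i may attend to under a causal mask."""
    return [[j for j in range(num_tokens) if j <= i] for i in range(num_tokens)]


visible = causal_visibility(5)
for i, row in enumerate(visible):
    print(f"position {i} attends to {row}")
# The last line printed is "position 4 attends to [0, 1, 2, 3, 4]"
```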

This is implemented in https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb as follows:

import torch


def calc_accuracy_loader(data_loader, model, device, num_batches=None):
    model.eval()
    correct_predictions, num_examples = 0, 0

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            input_batch, target_batch = input_batch.to(device), target_batch.to(device)

            with torch.no_grad():
                logits = model(input_batch)[:, -1, :]  # Logits of last output token
            predicted_labels = torch.argmax(logits, dim=-1)

            num_examples += predicted_labels.shape[0]
            correct_predictions += (predicted_labels == target_batch).sum().item()
        else:
            break
    return correct_predictions / num_examples


def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)[:, -1, :]  # Logits of last output token
    loss = torch.nn.functional.cross_entropy(logits, target_batch)
    return loss

Note how, for each batch, we are only interested in the logits of the last predicted token.
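Here is a minimal sketch of how the last-token loss drives training. It makes several assumptions:

- A hypothetical `ToyClassifier` (token embeddings, a running mean over positions as a crude causal mixer, and a 2-class head) stands in for the fine-tuned GPT2.
- The input is a hand-made batch that uses 0 as a hypothetical padding ID.
- The optimizer settings are illustrative.

`calc_loss_batch` is the same function as above:

```python
import torch
import torch.nn as nn

torch.manual_seed(123)


class ToyClassifier(nn.Module):
    """Hypothetical stand-in for the fine-tuned GPT: returns (batch, seq_len, num_classes) logits."""

    def __init__(self, vocab_size=50, emb_dim=16, num_classes=2):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, emb_dim)
        self.out_head = nn.Linear(emb_dim, num_classes)

    def forward(self, idx):
        x = self.tok_emb(idx)
        # Running mean over positions: position t only mixes tokens 0..t (causal)
        steps = torch.arange(1, idx.shape[1] + 1, device=idx.device).view(1, -1, 1)
        return self.out_head(x.cumsum(dim=1) / steps)


def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)[:, -1, :]  # Logits of last output token
    return torch.nn.functional.cross_entropy(logits, target_batch)


device = torch.device("cpu")
model = ToyClassifier().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2, weight_decay=0.1)

# Hypothetical batch: 4 padded "messages" of 5 token IDs; labels 0 = not spam, 1 = spam
inputs = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 0, 0], [10, 11, 12, 13, 0], [20, 21, 0, 0, 0]])
targets = torch.tensor([0, 0, 1, 1])

model.train()
initial_loss = calc_loss_batch(inputs, targets, model, device).item()
for step in range(50):
    optimizer.zero_grad()
    loss = calc_loss_batch(inputs, targets, model, device)  # only the last token contributes
    loss.backward()
    optimizer.step()
final_loss = loss.item()
print(f"loss: {initial_loss:.3f} -> {final_loss:.3f}")
```

In practice, you would iterate over the DataLoader batches for several epochs. You would also monitor `calc_accuracy_loader` on the validation set instead of a single hand-made batch.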

Complete GPT2 fine-tune classification code

All the code to fine-tune GPT2 into a spam classifier can be found in https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb.
