7.1. Fine-Tuning for Classification

tip

AWS 해킹 배우기 및 연습하기:HackTricks Training AWS Red Team Expert (ARTE)
GCP 해킹 배우기 및 연습하기: HackTricks Training GCP Red Team Expert (GRTE) Azure 해킹 배우기 및 연습하기: HackTricks Training Azure Red Team Expert (AzRTE)

HackTricks 지원하기

What is

Fine-tuning은 방대한 양의 데이터에서 일반 언어 패턴을 학습한 사전 훈련된 모델을 가져와서 특정 작업을 수행하거나 도메인 특정 언어를 이해하도록 조정하는 과정입니다. 이는 모델의 훈련을 더 작고 작업 특정 데이터 세트에서 계속 진행하여 새로운 데이터의 뉘앙스에 더 잘 맞도록 매개변수를 조정하면서 이미 습득한 폭넓은 지식을 활용할 수 있게 합니다. Fine-tuning은 새로운 모델을 처음부터 훈련할 필요 없이 전문화된 응용 프로그램에서 더 정확하고 관련성 있는 결과를 제공할 수 있게 합니다.

tip

LLM을 "이해하는" 텍스트로 사전 훈련하는 것이 상당히 비용이 많이 들기 때문에, 일반적으로 우리가 원하는 특정 작업을 수행하도록 오픈 소스 사전 훈련된 모델을 fine-tune하는 것이 더 쉽고 저렴합니다.

tip

이 섹션의 목표는 이미 사전 훈련된 모델을 fine-tune하는 방법을 보여주는 것이므로, LLM이 새로운 텍스트를 생성하는 대신 주어진 텍스트가 주어진 각 카테고리에 분류될 확률을 선택하도록 합니다 (예: 텍스트가 스팸인지 아닌지).

Preparing the data set

Data set size

물론, 모델을 fine-tune하기 위해서는 LLM을 전문화하는 데 사용할 구조화된 데이터가 필요합니다. https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb에서 제안된 예제에서, GPT2는 https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip 데이터를 사용하여 이메일이 스팸인지 아닌지를 감지하도록 fine-tune됩니다.

이 데이터 세트는 "스팸이 아님"의 예제가 "스팸"보다 훨씬 더 많기 때문에, 책에서는 "스팸"의 예제 수만큼만 "스팸이 아님"의 예제를 사용하라고 제안합니다 (따라서 훈련 데이터에서 모든 추가 예제를 제거합니다). 이 경우, 각 747개의 예제가 있었습니다.

그런 다음, **70%**의 데이터 세트는 훈련에, **10%**는 검증에, **20%**는 테스트에 사용됩니다.

  • 검증 세트는 훈련 단계에서 모델의 하이퍼파라미터를 fine-tune하고 모델 아키텍처에 대한 결정을 내리는 데 사용되며, 보지 못한 데이터에서 모델의 성능에 대한 피드백을 제공하여 과적합을 방지하는 데 효과적으로 도움을 줍니다. 이는 최종 평가에 편향을 주지 않고 반복적인 개선을 가능하게 합니다.
  • 이는 이 데이터 세트에 포함된 데이터가 직접적으로 훈련에 사용되지 않지만, 최상의 하이퍼파라미터를 조정하는 데 사용되므로, 이 세트는 테스트 세트처럼 모델의 성능을 평가하는 데 사용할 수 없음을 의미합니다.
  • 반면, 테스트 세트는 모델이 완전히 훈련되고 모든 조정이 완료된 후에만 사용됩니다; 이는 모델이 새로운 보지 못한 데이터에 일반화할 수 있는 능력에 대한 편향 없는 평가를 제공합니다. 테스트 세트에 대한 이 최종 평가는 모델이 실제 응용 프로그램에서 어떻게 수행될 것으로 예상되는지를 현실적으로 나타냅니다.

Entries length

훈련 예제가 동일한 길이의 항목(이 경우 이메일 텍스트)을 기대하므로, 가장 큰 항목과 동일한 크기로 모든 항목을 만들기로 결정하였으며, <|endoftext|>의 ID를 패딩으로 추가합니다.

Initialize the model

오픈 소스 사전 훈련된 가중치를 사용하여 모델을 초기화하여 훈련합니다. 우리는 이미 이 작업을 수행했으며 https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb의 지침을 따르면 쉽게 할 수 있습니다.

Classification head

이 특정 예제(텍스트가 스팸인지 아닌지 예측)에서는 GPT2의 전체 어휘에 따라 fine-tune하는 것에 관심이 없으며, 새 모델이 이메일이 스팸(1)인지 아닌지(0)만 말하도록 하기를 원합니다. 따라서 우리는 스팸 여부의 확률만 제공하는 최종 레이어를 수정할 것입니다 (즉, 2개의 단어로 구성된 어휘처럼).

python
# This code modified the final layer with a Linear one with 2 outs
num_classes = 2
model.out_head = torch.nn.Linear(

in_features=BASE_CONFIG["emb_dim"],

out_features=num_classes
)

조정할 매개변수

빠르게 미세 조정하기 위해서는 모든 매개변수를 조정하는 것보다 일부 최종 매개변수만 조정하는 것이 더 쉽습니다. 이는 하위 계층이 일반적으로 적용 가능한 기본 언어 구조와 의미를 포착한다는 것이 알려져 있기 때문입니다. 따라서, 마지막 계층만 미세 조정하는 것으로 보통 충분하고 더 빠릅니다.

python
# This code makes all the parameters of the model unrtainable
for param in model.parameters():
param.requires_grad = False

# Allow to fine tune the last layer in the transformer block
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True

# Allow to fine tune the final layer norm
for param in model.final_norm.parameters():

param.requires_grad = True

Entries to use for training

이전 섹션에서는 LLM이 입력 문장에 거의 모든 예측된 토큰이 포함되어 있음에도 불구하고(실제로 예측된 것은 끝에 있는 1개뿐임) 각 예측된 토큰의 손실을 줄이는 방식으로 훈련되었습니다. 이는 모델이 언어를 더 잘 이해할 수 있도록 하기 위함입니다.

이 경우 우리는 모델이 스팸인지 아닌지를 예측할 수 있는 것만 중요하므로, 마지막으로 예측된 토큰만 고려합니다. 따라서 이전의 훈련 손실 함수를 수정하여 해당 토큰만 고려하도록 해야 합니다.

이는 https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb에서 다음과 같이 구현됩니다:

python
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
model.eval()
correct_predictions, num_examples = 0, 0

if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)

with torch.no_grad():
logits = model(input_batch)[:, -1, :]  # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)

num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples


def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :]  # Logits of last output token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss

각 배치에 대해 우리는 예측된 마지막 토큰의 로짓에만 관심이 있음을 주목하세요.

완전한 GPT2 미세 조정 분류 코드

스팸 분류기로 GPT2를 미세 조정하는 모든 코드는 https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb에서 찾을 수 있습니다.

참고문헌

tip

AWS 해킹 배우기 및 연습하기:HackTricks Training AWS Red Team Expert (ARTE)
GCP 해킹 배우기 및 연습하기: HackTricks Training GCP Red Team Expert (GRTE) Azure 해킹 배우기 및 연습하기: HackTricks Training Azure Red Team Expert (AzRTE)

HackTricks 지원하기