7.1. Fine-Tuning for Classification

Reading time: 4 minutes

What is

Fine-tuning은 방대한 양의 데이터에서 일반 언어 패턴을 학습한 사전 훈련된 모델을 가져와서 특정 작업을 수행하거나 도메인 특정 언어를 이해하도록 조정하는 과정입니다. 이는 모델의 훈련을 더 작고 작업 특정 데이터 세트에서 계속 진행하여 새로운 데이터의 뉘앙스에 더 잘 맞도록 매개변수를 조정할 수 있게 하며, 이미 습득한 폭넓은 지식을 활용할 수 있게 합니다. Fine-tuning은 새로운 모델을 처음부터 훈련할 필요 없이 전문화된 애플리케이션에서 더 정확하고 관련성 있는 결과를 제공할 수 있게 합니다.

tip

LLM을 "이해하는" 텍스트로 사전 훈련하는 것이 상당히 비용이 많이 들기 때문에, 일반적으로 우리가 원하는 특정 작업을 수행하도록 오픈 소스 사전 훈련된 모델을 fine-tune하는 것이 더 쉽고 저렴합니다.

tip

이 섹션의 목표는 이미 사전 훈련된 모델을 fine-tune하는 방법을 보여주는 것입니다. 따라서 새로운 텍스트를 생성하는 대신 LLM은 주어진 텍스트가 주어진 각 카테고리에 분류될 확률을 선택하게 됩니다 (예: 텍스트가 스팸인지 아닌지).

Preparing the data set

Data set size

물론, 모델을 fine-tune하기 위해서는 LLM을 전문화하는 데 사용할 구조화된 데이터가 필요합니다. https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb에서 제안된 예제에서, GPT2는 https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip 데이터를 사용하여 이메일이 스팸인지 아닌지를 감지하도록 fine-tune됩니다.

이 데이터 세트는 "스팸이 아님"의 예제가 "스팸"보다 훨씬 더 많기 때문에, 책에서는 "스팸"의 예제 수만큼만 "스팸이 아님"의 예제를 사용하라고 제안합니다 (따라서 훈련 데이터에서 모든 추가 예제를 제거합니다). 이 경우, 각 747개의 예제가 있었습니다.

그런 다음, **70%**의 데이터 세트는 훈련에, **10%**는 검증에, **20%**는 테스트에 사용됩니다.

  • 검증 세트는 훈련 단계에서 모델의 하이퍼파라미터를 fine-tune하고 모델 아키텍처에 대한 결정을 내리는 데 사용되며, 보지 못한 데이터에서 모델의 성능에 대한 피드백을 제공하여 과적합을 방지하는 데 효과적으로 도움을 줍니다. 이는 최종 평가에 편향을 주지 않고 반복적인 개선을 가능하게 합니다.
  • 이는 이 데이터 세트에 포함된 데이터가 직접적으로 훈련에 사용되지 않지만, 최상의 하이퍼파라미터를 조정하는 데 사용되므로, 이 세트는 테스트 세트처럼 모델의 성능을 평가하는 데 사용할 수 없음을 의미합니다.
  • 반면, 테스트 세트는 모델이 완전히 훈련되고 모든 조정이 완료된 후에만 사용됩니다; 이는 모델이 새로운 보지 못한 데이터에 일반화할 수 있는 능력에 대한 편향 없는 평가를 제공합니다. 테스트 세트에 대한 이 최종 평가는 모델이 실제 애플리케이션에서 어떻게 수행될 것으로 예상되는지를 현실적으로 나타냅니다.

Entries length

훈련 예제가 동일한 길이의 항목(이 경우 이메일 텍스트)을 기대하므로, 가장 큰 항목의 크기만큼 모든 항목을 만들기로 결정하고 <|endoftext|>의 ID를 패딩으로 추가했습니다.

Initialize the model

오픈 소스 사전 훈련된 가중치를 사용하여 모델을 초기화하여 훈련합니다. 우리는 이미 이전에 이를 수행했으며 https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb의 지침을 따르면 쉽게 할 수 있습니다.

Classification head

이 특정 예제(텍스트가 스팸인지 아닌지 예측)에서는 GPT2의 전체 어휘에 따라 fine-tune하는 것에 관심이 없으며, 새 모델이 이메일이 스팸(1)인지 아닌지(0)만 말하도록 하기를 원합니다. 따라서 우리는 스팸 여부의 확률만 제공하는 최종 레이어를 수정할 것입니다 (즉, 2개의 단어로 구성된 어휘처럼).

python
# This code modified the final layer with a Linear one with 2 outs
num_classes = 2
model.out_head = torch.nn.Linear(

in_features=BASE_CONFIG["emb_dim"],

out_features=num_classes
)

조정할 매개변수

빠르게 미세 조정하기 위해서는 모든 매개변수를 조정하는 것보다 일부 최종 매개변수만 조정하는 것이 더 쉽습니다. 이는 하위 레이어가 일반적으로 기본 언어 구조와 적용 가능한 의미를 포착한다는 것이 알려져 있기 때문입니다. 따라서, 마지막 레이어만 미세 조정하는 것이 일반적으로 충분하고 더 빠릅니다.

python
# This code makes all the parameters of the model unrtainable
for param in model.parameters():
param.requires_grad = False

# Allow to fine tune the last layer in the transformer block
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True

# Allow to fine tune the final layer norm
for param in model.final_norm.parameters():

param.requires_grad = True

Entries to use for training

이전 섹션에서는 LLM이 입력 문장에서 거의 모든 예측된 토큰이 포함되어 있음에도 불구하고(실제로 예측된 것은 끝의 1개뿐임) 모든 예측된 토큰의 손실을 줄이도록 훈련되었습니다. 이는 모델이 언어를 더 잘 이해할 수 있도록 하기 위함입니다.

이번 경우에는 모델이 스팸인지 아닌지를 예측할 수 있는 것만 중요하므로, 우리는 마지막으로 예측된 토큰만 신경 쓰면 됩니다. 따라서 이전의 훈련 손실 함수를 수정하여 그 토큰만 고려하도록 해야 합니다.

이는 https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb에서 다음과 같이 구현됩니다:

python
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
model.eval()
correct_predictions, num_examples = 0, 0

if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)

with torch.no_grad():
logits = model(input_batch)[:, -1, :]  # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)

num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples


def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :]  # Logits of last output token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss

각 배치에 대해 우리는 예측된 마지막 토큰의 로짓에만 관심이 있음을 주목하세요.

완전한 GPT2 미세 조정 분류 코드

스팸 분류기로 GPT2를 미세 조정하는 모든 코드는 https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb에서 찾을 수 있습니다.

참고문헌