7.1. Ajuste Fino para Classificação
Reading time: 5 minutes
O que é
Ajuste fino é o processo de pegar um modelo pré-treinado que aprendeu padrões gerais de linguagem a partir de grandes quantidades de dados e adaptá-lo para realizar uma tarefa específica ou entender a linguagem específica de um domínio. Isso é alcançado continuando o treinamento do modelo em um conjunto de dados menor e específico para a tarefa, permitindo que ele ajuste seus parâmetros para se adequar melhor às nuances dos novos dados, aproveitando o amplo conhecimento que já adquiriu. O ajuste fino permite que o modelo forneça resultados mais precisos e relevantes em aplicações especializadas sem a necessidade de treinar um novo modelo do zero.
tip
Como pré-treinar um LLM que "entende" o texto é bastante caro, geralmente é mais fácil e barato ajustar modelos pré-treinados de código aberto para realizar uma tarefa específica que queremos que ele execute.
tip
O objetivo desta seção é mostrar como ajustar um modelo já pré-treinado, de modo que, em vez de gerar novo texto, o LLM selecione e forneça as probabilidades do texto dado ser categorizado em cada uma das categorias dadas (como se um texto é spam ou não).
Preparando o conjunto de dados
Tamanho do conjunto de dados
Claro, para ajustar um modelo, você precisa de alguns dados estruturados para usar para especializar seu LLM. No exemplo proposto em https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb, o GPT2 é ajustado para detectar se um e-mail é spam ou não usando os dados de https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip.
Este conjunto de dados contém muito mais exemplos de "não spam" do que de "spam", portanto, o livro sugere usar apenas tantos exemplos de "não spam" quanto de "spam" (removendo assim todos os exemplos extras dos dados de treinamento). Neste caso, foram 747 exemplos de cada.
Então, 70% do conjunto de dados é usado para treinamento, 10% para validação e 20% para teste.
- O conjunto de validação é usado durante a fase de treinamento para ajustar os hiperparâmetros do modelo e tomar decisões sobre a arquitetura do modelo, ajudando efetivamente a prevenir o overfitting ao fornecer feedback sobre como o modelo se comporta em dados não vistos. Ele permite melhorias iterativas sem enviesar a avaliação final.
- Isso significa que, embora os dados incluídos neste conjunto de dados não sejam usados para o treinamento diretamente, eles são usados para ajustar os melhores hiperparâmetros, portanto, este conjunto não pode ser usado para avaliar o desempenho do modelo como o conjunto de teste.
- Em contraste, o conjunto de teste é usado apenas após o modelo ter sido totalmente treinado e todos os ajustes estarem completos; ele fornece uma avaliação imparcial da capacidade do modelo de generalizar para novos dados não vistos. Esta avaliação final no conjunto de teste dá uma indicação realista de como o modelo deve se comportar em aplicações do mundo real.
Comprimento das entradas
Como o exemplo de treinamento espera entradas (texto de e-mails, neste caso) do mesmo comprimento, decidiu-se fazer cada entrada tão grande quanto a maior, adicionando os ids de <|endoftext|>
como preenchimento.
Inicializar o modelo
Usando os pesos pré-treinados de código aberto, inicialize o modelo para treinar. Já fizemos isso antes e seguindo as instruções de https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb você pode facilmente fazê-lo.
Cabeça de Classificação
Neste exemplo específico (prevendo se um texto é spam ou não), não estamos interessados em ajustar de acordo com o vocabulário completo do GPT2, mas queremos que o novo modelo diga se o e-mail é spam (1) ou não (0). Portanto, vamos modificar a camada final que fornece as probabilidades por token do vocabulário para uma que apenas fornece as probabilidades de ser spam ou não (então, como um vocabulário de 2 palavras).
# This code modified the final layer with a Linear one with 2 outs
num_classes = 2
model.out_head = torch.nn.Linear(
in_features=BASE_CONFIG["emb_dim"],
out_features=num_classes
)
Parâmetros a ajustar
Para ajustar rapidamente, é mais fácil não ajustar todos os parâmetros, mas apenas alguns finais. Isso ocorre porque é sabido que as camadas inferiores geralmente capturam estruturas e semânticas básicas da linguagem aplicáveis. Portanto, apenas ajustar as últimas camadas geralmente é suficiente e mais rápido.
# This code makes all the parameters of the model unrtainable
for param in model.parameters():
param.requires_grad = False
# Allow to fine tune the last layer in the transformer block
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True
# Allow to fine tune the final layer norm
for param in model.final_norm.parameters():
param.requires_grad = True
Entradas a serem usadas para treinamento
Nas seções anteriores, o LLM foi treinado reduzindo a perda de cada token previsto, embora quase todos os tokens previstos estivessem na frase de entrada (apenas 1 no final foi realmente previsto) para que o modelo entendesse melhor a linguagem.
Neste caso, só nos importa que o modelo seja capaz de prever se o modelo é spam ou não, então só nos importa o último token previsto. Portanto, é necessário modificar nossas funções de perda de treinamento anteriores para levar em conta apenas esse token.
Isso é implementado em https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb como:
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
model.eval()
correct_predictions, num_examples = 0, 0
if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
with torch.no_grad():
logits = model(input_batch)[:, -1, :] # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)
num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples
def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :] # Logits of last output token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss
Note que para cada lote, estamos interessados apenas nos logits do último token previsto.
Código completo de classificação de fine-tune do GPT2
Você pode encontrar todo o código para ajustar o GPT2 para ser um classificador de spam em https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb