7.1. Ajuste fino para clasificación
Reading time: 6 minutes
Qué es
El ajuste fino es el proceso de tomar un modelo preentrenado que ha aprendido patrones de lenguaje generales a partir de grandes cantidades de datos y adaptarlo para realizar una tarea específica o para entender el lenguaje específico de un dominio. Esto se logra continuando el entrenamiento del modelo en un conjunto de datos más pequeño y específico para la tarea, lo que le permite ajustar sus parámetros para adaptarse mejor a las sutilezas de los nuevos datos mientras aprovecha el amplio conocimiento que ya ha adquirido. El ajuste fino permite que el modelo ofrezca resultados más precisos y relevantes en aplicaciones especializadas sin la necesidad de entrenar un nuevo modelo desde cero.
tip
Dado que preentrenar un LLM que "entiende" el texto es bastante costoso, generalmente es más fácil y económico ajustar modelos preentrenados de código abierto para realizar una tarea específica que queremos que realice.
tip
El objetivo de esta sección es mostrar cómo ajustar finamente un modelo que ya ha sido preentrenado, de modo que en lugar de generar nuevo texto, el LLM seleccionará y dará las probabilidades de que el texto dado sea categorizado en cada una de las categorías dadas (como si un texto es spam o no).
Preparando el conjunto de datos
Tamaño del conjunto de datos
Por supuesto, para ajustar un modelo, necesitas algunos datos estructurados para especializar tu LLM. En el ejemplo propuesto en https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb, GPT2 se ajusta para detectar si un correo electrónico es spam o no utilizando los datos de https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip.
Este conjunto de datos contiene muchos más ejemplos de "no spam" que de "spam", por lo tanto, el libro sugiere usar solo tantos ejemplos de "no spam" como de "spam" (por lo tanto, eliminando del conjunto de entrenamiento todos los ejemplos adicionales). En este caso, fueron 747 ejemplos de cada uno.
Luego, se utiliza el 70% del conjunto de datos para entrenamiento, 10% para validación y 20% para pruebas.
- El conjunto de validación se utiliza durante la fase de entrenamiento para ajustar los hiperparámetros del modelo y tomar decisiones sobre la arquitectura del modelo, ayudando efectivamente a prevenir el sobreajuste al proporcionar retroalimentación sobre cómo se desempeña el modelo en datos no vistos. Permite mejoras iterativas sin sesgar la evaluación final.
- Esto significa que, aunque los datos incluidos en este conjunto de datos no se utilizan directamente para el entrenamiento, se utilizan para ajustar los mejores hiperparámetros, por lo que este conjunto no puede ser utilizado para evaluar el rendimiento del modelo como el de pruebas.
- En contraste, el conjunto de pruebas se utiliza solo después de que el modelo ha sido completamente entrenado y todos los ajustes están completos; proporciona una evaluación imparcial de la capacidad del modelo para generalizar a nuevos datos no vistos. Esta evaluación final en el conjunto de pruebas da una indicación realista de cómo se espera que el modelo se desempeñe en aplicaciones del mundo real.
Longitud de las entradas
Dado que el ejemplo de entrenamiento espera entradas (texto de correos electrónicos en este caso) de la misma longitud, se decidió hacer que cada entrada sea tan grande como la más grande añadiendo los ids de <|endoftext|>
como relleno.
Inicializar el modelo
Usando los pesos preentrenados de código abierto, inicializa el modelo para entrenar. Ya hemos hecho esto antes y siguiendo las instrucciones de https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb puedes hacerlo fácilmente.
Cabeza de clasificación
En este ejemplo específico (prediciendo si un texto es spam o no), no estamos interesados en ajustar según el vocabulario completo de GPT2, sino que solo queremos que el nuevo modelo diga si el correo electrónico es spam (1) o no (0). Por lo tanto, vamos a modificar la capa final que da las probabilidades por token del vocabulario por una que solo da las probabilidades de ser spam o no (así que como un vocabulario de 2 palabras).
# This code modified the final layer with a Linear one with 2 outs
num_classes = 2
model.out_head = torch.nn.Linear(
in_features=BASE_CONFIG["emb_dim"],
out_features=num_classes
)
Parámetros a ajustar
Para ajustar rápidamente, es más fácil no ajustar todos los parámetros, sino solo algunos finales. Esto se debe a que se sabe que las capas inferiores generalmente capturan estructuras y semánticas básicas del lenguaje aplicables. Por lo tanto, ajustar solo las últimas capas suele ser suficiente y más rápido.
# This code makes all the parameters of the model unrtainable
for param in model.parameters():
param.requires_grad = False
# Allow to fine tune the last layer in the transformer block
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True
# Allow to fine tune the final layer norm
for param in model.final_norm.parameters():
param.requires_grad = True
Entradas para usar en el entrenamiento
En secciones anteriores, el LLM fue entrenado reduciendo la pérdida de cada token predicho, aunque casi todos los tokens predichos estaban en la oración de entrada (solo 1 al final fue realmente predicho) para que el modelo entendiera mejor el lenguaje.
En este caso, solo nos importa que el modelo sea capaz de predecir si el modelo es spam o no, por lo que solo nos importa el último token predicho. Por lo tanto, es necesario modificar nuestras funciones de pérdida de entrenamiento anteriores para tener en cuenta solo ese token.
Esto se implementa en https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb como:
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
model.eval()
correct_predictions, num_examples = 0, 0
if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
with torch.no_grad():
logits = model(input_batch)[:, -1, :] # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)
num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples
def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :] # Logits of last output token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss
Note cómo para cada lote solo estamos interesados en los logits del último token predicho.
Código completo de clasificación de ajuste fino de GPT2
Puedes encontrar todo el código para ajustar GPT2 como un clasificador de spam en https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb