7.1. Fine-Tuning for Classification

Reading time: 7 minutes

What is

微调是将一个预训练模型的过程,该模型已经从大量数据中学习了通用语言模式,并将其调整以执行特定任务或理解特定领域的语言。这是通过在一个较小的、特定任务的数据集上继续训练模型来实现的,使其能够调整参数以更好地适应新数据的细微差别,同时利用其已经获得的广泛知识。微调使模型能够在专业应用中提供更准确和相关的结果,而无需从头开始训练一个新模型。

tip

由于预训练一个“理解”文本的LLM相当昂贵,因此通常更容易和便宜地微调开源的预训练模型以执行我们希望其执行的特定任务。

tip

本节的目标是展示如何微调一个已经预训练的模型,因此LLM将选择给出给定文本被分类到每个给定类别的概率(例如,文本是否为垃圾邮件)。

Preparing the data set

Data set size

当然,为了微调模型,您需要一些结构化数据来专门化您的LLM。在https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb中提出的示例中,GPT2被微调以检测电子邮件是否为垃圾邮件,使用的数据来自https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip_。

该数据集包含的“非垃圾邮件”示例远多于“垃圾邮件”示例,因此本书建议仅使用与“垃圾邮件”相同数量的“非垃圾邮件”示例(因此,从训练数据中删除所有额外示例)。在这种情况下,每种情况有747个示例。

然后,70%的数据集用于训练10%用于验证20%用于测试

  • 验证集在训练阶段用于微调模型的超参数并做出关于模型架构的决策,有效地通过提供模型在未见数据上的表现反馈来帮助防止过拟合。它允许在不偏倚最终评估的情况下进行迭代改进。
  • 这意味着尽管此数据集中包含的数据不直接用于训练,但它用于调整最佳超参数,因此该集不能像测试集那样用于评估模型的性能。
  • 相比之下,测试集仅在模型完全训练并完成所有调整后使用;它提供了对模型在新未见数据上泛化能力的无偏评估。对测试集的最终评估提供了模型在实际应用中预期表现的现实指示。

Entries length

由于训练示例期望条目(在这种情况下为电子邮件文本)具有相同的长度,因此决定通过添加<|endoftext|>的ID作为填充,使每个条目与最大条目一样大。

Initialize the model

使用开源预训练权重初始化模型以进行训练。我们之前已经做过这个,并按照https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb的说明,您可以轻松做到这一点。

Classification head

在这个特定示例中(预测文本是否为垃圾邮件),我们并不关心根据GPT2的完整词汇进行微调,而只希望新模型能够判断电子邮件是否为垃圾邮件(1)或不是(0)。因此,我们将修改最终层,使其提供每个词汇的token的概率,改为仅提供是否为垃圾邮件的概率(就像一个包含2个单词的词汇)。

python
# This code modified the final layer with a Linear one with 2 outs
num_classes = 2
model.out_head = torch.nn.Linear(

in_features=BASE_CONFIG["emb_dim"],

out_features=num_classes
)

参数调整

为了快速微调,最好只微调一些最终参数,而不是所有参数。这是因为已知较低层通常捕捉基本的语言结构和适用的语义。因此,仅微调最后几层通常就足够且更快

python
# This code makes all the parameters of the model unrtainable
for param in model.parameters():
param.requires_grad = False

# Allow to fine tune the last layer in the transformer block
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True

# Allow to fine tune the final layer norm
for param in model.final_norm.parameters():

param.requires_grad = True

用于训练的条目

在之前的部分中,LLM通过减少每个预测标记的损失进行训练,尽管几乎所有预测的标记都在输入句子中(只有最后一个是真正预测的),以便模型更好地理解语言。

在这种情况下,我们只关心模型是否能够预测该模型是否为垃圾邮件,因此我们只关心最后一个预测的标记。因此,需要修改我们之前的训练损失函数,仅考虑该标记。

这在 https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynb 中实现为:

python
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
model.eval()
correct_predictions, num_examples = 0, 0

if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)

with torch.no_grad():
logits = model(input_batch)[:, -1, :]  # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)

num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples


def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :]  # Logits of last output token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss

注意,对于每个批次,我们只对最后一个预测的标记的logits感兴趣。

完整的GPT2微调分类代码

您可以在https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynb找到所有微调GPT2以成为垃圾邮件分类器的代码。

参考文献