7.1. Fine-Tuning for Classification

Reading time: 9 minutes

What is

ファインチューニングは、膨大なデータから一般的な言語パターンを学習した事前学習済みモデルを取り、それを特定のタスクを実行するためやドメイン特有の言語を理解するために適応させるプロセスです。これは、モデルのトレーニングを小さなタスク特化型データセットで続けることによって達成され、新しいデータのニュアンスにより適したパラメータに調整しながら、すでに取得した広範な知識を活用します。ファインチューニングにより、モデルは新しいモデルをゼロからトレーニングすることなく、専門的なアプリケーションでより正確で関連性のある結果を提供できるようになります。

tip

テキストを「理解する」LLMの事前トレーニングは非常に高価であるため、特定のタスクを実行するためにオープンソースの事前学習済みモデルをファインチューニングする方が通常は簡単で安価です。

tip

このセクションの目的は、すでに事前学習されたモデルをファインチューニングする方法を示すことです。したがって、新しいテキストを生成するのではなく、LLMは与えられたテキストが各カテゴリに分類される確率を選択します(例えば、テキストがスパムかどうか)。

Preparing the data set

Data set size

もちろん、モデルをファインチューニングするには、LLMを専門化するために使用する構造化データが必要です。https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynbで提案された例では、GPT2は、https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zipのデータを使用して、メールがスパムかどうかを検出するようにファインチューニングされています。

このデータセットには「スパムでない」例が「スパム」よりもはるかに多く含まれているため、本書では**「スパムでない」例を「スパム」と同じ数だけのみ使用する**ことを提案しています(したがって、トレーニングデータから余分な例を削除します)。この場合、各747例でした。

次に、70%のデータセットがトレーニングに、10%検証に、20%テストに使用されます。

  • 検証セットは、トレーニングフェーズ中にモデルのハイパーパラメータをファインチューニングし、モデルアーキテクチャに関する決定を行うために使用され、見えないデータに対するモデルのパフォーマンスに関するフィードバックを提供することで、オーバーフィッティングを防ぐのに役立ちます。これは、最終評価をバイアスせずに反復的な改善を可能にします。
  • これは、このデータセットに含まれるデータが直接トレーニングには使用されないが、最良のハイパーパラメータを調整するために使用されることを意味します。したがって、このセットはテストセットのようにモデルのパフォーマンスを評価するためには使用できません。
  • 対照的に、テストセットは、モデルが完全にトレーニングされ、すべての調整が完了した後にのみ使用されます。これは、モデルが新しい見えないデータに一般化する能力を偏りなく評価します。このテストセットでの最終評価は、モデルが実際のアプリケーションでどのように機能するかの現実的な指標を提供します。

Entries length

トレーニング例は同じ長さのエントリ(この場合はメールテキスト)を期待するため、すべてのエントリを最大のものと同じ大きさにすることに決定し、<|endoftext|>のIDをパディングとして追加しました。

Initialize the model

オープンソースの事前学習済みウェイトを使用してモデルを初期化し、トレーニングを行います。これを以前に行ったことがあり、https://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynbの指示に従えば、簡単に行えます。

Classification head

この特定の例(テキストがスパムかどうかを予測する)では、GPT2の完全な語彙に基づいてファインチューニングすることには興味がなく、新しいモデルにはメールがスパム(1)かどうか(0)を言わせたいだけです。したがって、語彙の各トークンの確率を提供する最終層を、スパムかどうかの確率のみを提供するものに変更します(つまり、2語の語彙のように)。

python
# This code modified the final layer with a Linear one with 2 outs
num_classes = 2
model.out_head = torch.nn.Linear(

in_features=BASE_CONFIG["emb_dim"],

out_features=num_classes
)

調整するパラメータ

迅速にファインチューニングを行うためには、すべてのパラメータを調整するのではなく、最終的なパラメータの一部だけを調整する方が簡単です。これは、下位層が一般的に基本的な言語構造と適用可能な意味論を捉えることが知られているためです。したがって、最後の層だけをファインチューニングすることが通常は十分であり、より速いです

python
# This code makes all the parameters of the model unrtainable
for param in model.parameters():
param.requires_grad = False

# Allow to fine tune the last layer in the transformer block
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True

# Allow to fine tune the final layer norm
for param in model.final_norm.parameters():

param.requires_grad = True

トレーニングに使用するエントリ

前のセクションでは、LLMは予測されたトークンの損失を減少させることでトレーニングされましたが、ほとんどすべての予測トークンは入力文に含まれていました(実際に予測されたのは最後の1つだけ)ので、モデルが言語をよりよく理解できるようにしました。

この場合、モデルがスパムかどうかを予測できることだけが重要であるため、最後に予測されたトークンのみを考慮します。したがって、以前のトレーニング損失関数を修正して、そのトークンのみを考慮する必要があります。

これはhttps://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/ch06.ipynbで次のように実装されています:

python
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
model.eval()
correct_predictions, num_examples = 0, 0

if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)

with torch.no_grad():
logits = model(input_batch)[:, -1, :]  # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)

num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples


def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :]  # Logits of last output token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss

各バッチについて、最後に予測されたトークンのロジットのみに興味があることに注意してください。

完全なGPT2ファインチューニング分類コード

GPT2をスパム分類器としてファインチューニングするためのすべてのコードはhttps://github.com/rasbt/LLMs-from-scratch/blob/main/ch06/01_main-chapter-code/load-finetuned-model.ipynbで見つけることができます。

参考文献