7.0. LoRAのファインチューニングにおける改善

Reading time: 4 minutes

LoRAの改善

tip

LoRAを使用することで、ファインチューニングに必要な計算が大幅に削減されます

LoRAは、モデルの小さな部分のみを変更することで、大規模モデルを効率的にファインチューニングすることを可能にします。これにより、トレーニングに必要なパラメータの数が減り、メモリ計算リソースが節約されます。これは以下の理由によります:

  1. トレーニング可能なパラメータの数を削減: モデル内の全体の重み行列を更新する代わりに、LoRAは重み行列を2つの小さな行列(AB)に分割します。これにより、トレーニングが速くなり、更新する必要のあるパラメータが少ないため、メモリ少なくて済みます。

  2. これは、レイヤー(行列)の完全な重み更新を計算する代わりに、2つの小さな行列の積に近似するため、計算する更新が減少するからです:\

  1. 元のモデルの重みを変更しない: LoRAを使用すると、元のモデルの重みをそのままにしておき、新しい小さな行列(AとB)だけを更新できます。これは、モデルの元の知識が保持され、必要な部分だけを調整することを意味するため、便利です。
  2. 効率的なタスク特化型ファインチューニング: モデルを新しいタスクに適応させたい場合、モデルの残りの部分をそのままにしておき、小さなLoRA行列(AとB)だけをトレーニングすればよいです。これは、モデル全体を再トレーニングするよりもはるかに効率的です。
  3. ストレージ効率: ファインチューニング後、各タスクのために新しいモデル全体を保存する代わりに、LoRA行列だけを保存すればよく、これはモデル全体に比べて非常に小さいです。これにより、多くのタスクにモデルを適応させる際に、過剰なストレージを使用せずに済みます。

ファインチューニング中にLinearの代わりにLoraLayersを実装するために、ここで提案されているコードがあります https://github.com/rasbt/LLMs-from-scratch/blob/main/appendix-E/01_main-chapter-code/appendix-E.ipynb:

python
import math

# Create the LoRA layer with the 2 matrices and the alpha
class LoRALayer(torch.nn.Module):
def __init__(self, in_dim, out_dim, rank, alpha):
super().__init__()
self.A = torch.nn.Parameter(torch.empty(in_dim, rank))
torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))  # similar to standard weight initialization
self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
self.alpha = alpha

def forward(self, x):
x = self.alpha * (x @ self.A @ self.B)
return x

# Combine it with the linear layer
class LinearWithLoRA(torch.nn.Module):
def __init__(self, linear, rank, alpha):
super().__init__()
self.linear = linear
self.lora = LoRALayer(
linear.in_features, linear.out_features, rank, alpha
)

def forward(self, x):
return self.linear(x) + self.lora(x)

# Replace linear layers with LoRA ones
def replace_linear_with_lora(model, rank, alpha):
for name, module in model.named_children():
if isinstance(module, torch.nn.Linear):
# Replace the Linear layer with LinearWithLoRA
setattr(model, name, LinearWithLoRA(module, rank, alpha))
else:
# Recursively apply the same function to child modules
replace_linear_with_lora(module, rank, alpha)

参考文献