7.0. LoRA Improvements in fine-tuning


LoRA Improvements


Using LoRA greatly reduces the computation needed to fine-tune already trained models.

LoRA makes it possible to fine-tune large models efficiently by changing only a small part of the model. It reduces the number of parameters you need to train, which saves memory and computational resources. This is because:

  1. Reduces the Number of Trainable Parameters: Instead of updating the entire weight matrix of a layer, LoRA learns the update to that matrix as the product of two much smaller matrices (called A and B). This makes training faster and requires less memory because far fewer parameters need to be updated.

    1. Instead of computing the complete weight update ΔW of a layer (a d_in × d_out matrix), LoRA approximates it as the product of two low-rank matrices, which reduces the number of values to learn from d_in · d_out to r(d_in + d_out):

    $$\Delta W \approx A B, \qquad A \in \mathbb{R}^{d_{in} \times r},\; B \in \mathbb{R}^{r \times d_{out}},\; r \ll \min(d_{in}, d_{out})$$

  2. Keeps Original Model Weights Unchanged: LoRA allows you to keep the original model weights the same, and only updates the new small matrices (A and B). This is helpful because it means the model’s original knowledge is preserved, and you only tweak what's necessary.

  3. Efficient Task-Specific Fine-Tuning: When you want to adapt the model to a new task, you can just train the small LoRA matrices (A and B) while leaving the rest of the model as it is. This is much more efficient than retraining the entire model.

  4. Storage Efficiency: After fine-tuning, instead of saving a whole new model for each task, you only need to store the LoRA matrices, which are very small compared to the entire model. This makes it easier to adapt the model to many tasks without using too much storage.
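As a rough illustration of points 1 and 4, the sketch below counts the values that must be trained (and stored per task) for a single weight matrix: a full update has d_in × d_out entries, while the LoRA factors A (d_in × r) and B (r × d_out) have r(d_in + d_out). The layer size and rank are hypothetical, chosen only to make the arithmetic concrete.

```python
def full_update_params(d_in: int, d_out: int) -> int:
    """Number of values in a full weight update of shape (d_in, d_out)."""
    return d_in * d_out


def lora_update_params(d_in: int, d_out: int, rank: int) -> int:
    """Number of values in the LoRA factors A (d_in x rank) and B (rank x d_out)."""
    return d_in * rank + rank * d_out


# Hypothetical layer size and rank, chosen only for the arithmetic
d_in, d_out, rank = 768, 768, 8

full = full_update_params(d_in, d_out)
lora = lora_update_params(d_in, d_out, rank)
print(f"full update: {full:,} values")                             # → full update: 589,824 values
print(f"LoRA update: {lora:,} values ({100 * lora / full:.2f}%)")  # → LoRA update: 12,288 values (2.08%)
```

Because the LoRA count grows linearly with r instead of with d_in × d_out, the savings are largest for wide layers and small ranks.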

To use LoRA layers instead of Linear ones during fine-tuning, the following code is proposed in https://github.com/rasbt/LLMs-from-scratch/blob/main/appendix-E/01_main-chapter-code/appendix-E.ipynb:

```python
import math

import torch


# Create the LoRA layer with the 2 matrices and the alpha
class LoRALayer(torch.nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        # A: (in_dim, rank), initialized like a standard weight matrix
        self.A = torch.nn.Parameter(torch.empty(in_dim, rank))
        torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
        # B: (rank, out_dim), zero-initialized so the LoRA update starts at 0
        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha  # scaling factor applied to the low-rank update

    def forward(self, x):
        # Low-rank update: alpha * (x @ A @ B)
        return self.alpha * (x @ self.A @ self.B)


# Combine it with the linear layer
class LinearWithLoRA(torch.nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear  # original (pre-trained) layer
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )

    def forward(self, x):
        # Original output plus the trainable low-rank correction
        return self.linear(x) + self.lora(x)


# Replace linear layers with LoRA ones
# (freeze the original parameters beforehand, e.g. by setting
#  requires_grad = False on model.parameters(), so only A and B are trained)
def replace_linear_with_lora(model, rank, alpha):
    for name, module in model.named_children():
        if isinstance(module, torch.nn.Linear):
            # Replace the Linear layer with LinearWithLoRA
            setattr(model, name, LinearWithLoRA(module, rank, alpha))
        else:
            # Recursively apply the same function to child modules
            replace_linear_with_lora(module, rank, alpha)
```
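To see why `B` is initialized with zeros, here is a dependency-free sketch of the same forward pass on plain Python lists. The helpers `matmul` and `lora_linear_forward` are hypothetical, the base weight `W` is stored as (d_in, d_out) so the base layer is `x @ W` (`torch.nn.Linear` stores the transpose), and the bias is ignored. With `B = 0` the term `alpha * (x @ A @ B)` vanishes, so right after the replacement the layer behaves exactly like the original one; only training `A` and `B` moves it away from that behavior.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Naive matrix product of a (n x k) and b (k x m)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def lora_linear_forward(x: Matrix, W: Matrix, A: Matrix, B: Matrix,
                        alpha: float) -> Matrix:
    """x @ W + alpha * (x @ A @ B), mirroring LinearWithLoRA without a bias."""
    base = matmul(x, W)
    update = matmul(matmul(x, A), B)
    return [[base[i][j] + alpha * update[i][j] for j in range(len(base[0]))]
            for i in range(len(base))]


# Hypothetical toy sizes: d_in = 3, d_out = 2, rank = 1
x = [[1.0, 2.0, 3.0]]
W = [[0.5, -1.0], [0.0, 2.0], [1.0, 1.0]]  # frozen base weights (d_in x d_out)
A = [[0.1], [0.2], [0.3]]                  # trainable, shape (d_in x rank)
B_init = [[0.0, 0.0]]                      # trainable, zero-initialized (rank x d_out)

print(lora_linear_forward(x, W, A, B_init, alpha=2.0))  # → [[3.5, 6.0]]
print(matmul(x, W))                                     # → [[3.5, 6.0]]

# After some (hypothetical) training, B is no longer zero
B_trained = [[1.0, -1.0]]
print(lora_linear_forward(x, W, A, B_trained, alpha=2.0))
```

Starting from a zero update means fine-tuning begins from the pre-trained model's exact outputs, which matches the goal of keeping the original weights and knowledge intact.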

References