📕 LoRA: Low-Rank Adaptation of Large Language Models
Abstract
An important paradigm of natural language processing consists of large-scale pre-training on general domain data and adaptation to particular tasks or domains. As we pre-train larger models, full fine-tuning, which retrains all model parameters, becomes less feasible. Using GPT-3 175B as an example -- deploying independent instances of fine-tuned models, each with 175B parameters, is prohibitively expensive. We propose Low-Rank Adaptation, or LoRA, which freezes the pre-trained model weights and injects trainable rank decomposition matrices into each layer of the Transformer architecture, greatly reducing the number of trainable parameters for downstream tasks. Compared to GPT-3 175B fine-tuned with Adam, LoRA can reduce the number of trainable parameters by 10,000 times and the GPU memory requirement by 3 times. LoRA performs on-par or better than fine-tuning in model quality on RoBERTa, DeBERTa, GPT-2, and GPT-3, despite having fewer trainable parameters, a higher training throughput, and, unlike adapters, no additional inference latency. We also provide an empirical investigation into rank-deficiency in language model adaptation, which sheds light on the efficacy of LoRA. We release a package that facilitates the integration of LoRA with PyTorch models and provide our implementations and model checkpoints for RoBERTa, DeBERTa, and GPT-2 at https://github.com/microsoft/LoRA.
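Written out in the paper's notation, the reparametrization and its parameter savings look roughly like this. The α/r scaling factor also comes from the paper; the code in these notes leaves it out, which amounts to α = r. The second line is only an illustrative 256 × 256 example, not a number reported in the paper:

$$
h = W_0 x + \Delta W x = W_0 x + \frac{\alpha}{r} B A x,
\qquad W_0 \in \mathbb{R}^{d \times k},\ B \in \mathbb{R}^{d \times r},\ A \in \mathbb{R}^{r \times k},\ r \ll \min(d, k)
$$

$$
\#\text{trainable} = r(d + k) \ \text{instead of}\ dk; \qquad d = k = 256,\ r = 8:\ 8 \cdot 512 = 4096 \ \text{vs.}\ 65536 \ (6.25\%)
$$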
Notes
- At the start of training, A is initialized from a Gaussian $N(0, \sigma^2)$ and B to zero, so $\Delta W = BA = 0$ and the first forward pass gives exactly the same output as the original model.
- A projects the input down to rank r and B projects it back up to the original dimension.
- In most experiments they apply LoRA only to Wq and Wv "for simplicity" (Section 4.2); later work commonly adapts other weight matrices as well.
- During fine-tuning, the pre-trained model is frozen and only the LoRA weights, a small fraction of the original parameters, are trained.
- The formulation makes LoRA plug-and-play: because W = W0 + BA, an adapter can be added to or subtracted from the base weights as needed, e.g. to switch between tasks.
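The points above (Gaussian A, zero B, frozen base weights, plug-and-play merging) can be sketched as a wrapper around a single `nn.Linear`. This is a minimal sketch, not the `loralib` package from the paper's repository: the class name `LoRALinear`, its `merge`/`unmerge` methods, and the defaults `r=8`, `alpha=16`, `sigma=0.02` are hypothetical choices for illustration. The `alpha / r` scaling follows the paper.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Hypothetical wrapper: a frozen nn.Linear (W0) plus a trainable update (alpha / r) * B A."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 16.0, sigma: float = 0.02):
        super().__init__()
        self.base = base
        for p in self.base.parameters():  # freeze the pre-trained weight (and bias, if any)
            p.requires_grad_(False)
        self.scale = alpha / r
        # A ~ N(0, sigma^2), B = 0  ->  B A = 0, so the initial output equals W0 x
        self.A = nn.Parameter(torch.randn(r, base.in_features) * sigma)  # downscale to rank r
        self.B = nn.Parameter(torch.zeros(base.out_features, r))  # upscale back to out_features
        self.merged = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.base(x)
        if not self.merged:
            # for row-vector inputs, x A^T B^T is the row form of B A x
            out = out + self.scale * (x @ self.A.T @ self.B.T)
        return out

    @torch.no_grad()
    def merge(self) -> None:
        """Fold the update into the weight (W = W0 + scale * B A); inference is then one matmul."""
        if not self.merged:
            self.base.weight += self.scale * (self.B @ self.A)
            self.merged = True

    @torch.no_grad()
    def unmerge(self) -> None:
        """Subtract the update to recover W0, e.g. before loading another task's adapter."""
        if self.merged:
            self.base.weight -= self.scale * (self.B @ self.A)
            self.merged = False


torch.manual_seed(0)
layer = LoRALinear(nn.Linear(256, 256), r=8)
x = torch.randn(2, 20, 256)
print(torch.allclose(layer(x), layer.base(x)))  # True: B = 0, so nothing changes at first
print([n for n, p in layer.named_parameters() if p.requires_grad])  # ['A', 'B']

# one training step: only A and B are handed to the optimizer
opt = torch.optim.SGD([p for p in layer.parameters() if p.requires_grad], lr=0.1)
layer(x).pow(2).mean().backward()
opt.step()

before = layer(x)
layer.merge()
print(torch.allclose(layer(x), before, atol=1e-5))  # True: merged layer gives the same output
```

Merging is what removes the extra inference cost: once B A is folded into W0 the layer is an ordinary linear layer, and `unmerge()` restores W0 so a different adapter can be applied.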
Common values for the rank r:
- 8
- 16
- 32
- 64
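To see what these ranks cost, the count r(d + k) for one d × k weight matrix can be tabulated directly. A small sketch, assuming a 256 × 256 projection (the size used in the code below) and bias-free A and B as in the paper; `lora_params` is a hypothetical helper:

```python
def lora_params(d: int, k: int, r: int) -> int:
    """Trainable parameters of B (d x r) plus A (r x k), without biases."""
    return r * (d + k)


d = k = 256
full = d * k  # weights in the frozen matrix W0
for r in (8, 16, 32, 64):
    n = lora_params(d, k, r)
    print(f"r={r:>2}: {n:>6} trainable params ({100 * n / full:.2f}% of {full})")
```

For a matrix this small, r = 64 already costs half the original parameters; the relative savings grow with d and k. For example, with d = k = 12288 (GPT-3's hidden size), r = 8 gives 196,608 trainable parameters against roughly 151 million in W0.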
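Parameter counts for rank 8 on a 256 × 256 key projection (output of the snippet below; the counts include the default `nn.Linear` biases):
- A_k: 2056
- B_k: 2304
- Original k matrix: 65536
- Trainable parameters per layer: 4360, or about 6.65 percent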
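```python
# h = W0 x + ΔW x = W0 x + B A x
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttentionLORAHead(nn.Module):
    """Encoder self-attention head with LoRA on W_q and W_v, as in most of the paper's experiments."""
    def __init__(self, input_embedding_size, output_size, rank=8, dropout_p=0.1):
        super().__init__()
        # pre-trained projections W0, frozen during fine-tuning
        self.key = nn.Linear(input_embedding_size, output_size, bias=False)
        self.query = nn.Linear(input_embedding_size, output_size, bias=False)
        self.value = nn.Linear(input_embedding_size, output_size, bias=False)
        for proj in (self.key, self.query, self.value):
            proj.weight.requires_grad_(False)
        # trainable low-rank factors: A downscales to `rank`, B upscales back to output_size
        self.A_q = nn.Linear(input_embedding_size, rank, bias=False)
        self.B_q = nn.Linear(rank, output_size, bias=False)
        self.A_v = nn.Linear(input_embedding_size, rank, bias=False)
        self.B_v = nn.Linear(rank, output_size, bias=False)
        for down, up in ((self.A_q, self.B_q), (self.A_v, self.B_v)):
            nn.init.normal_(down.weight, std=0.02)  # A ~ N(0, sigma^2); sigma = 0.02 is an arbitrary choice
            nn.init.zeros_(up.weight)  # B = 0, so B A x = 0 and the head starts out identical to W0
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        # project x to query, key and value; only W_q and W_v get a low-rank update B A x
        q = self.query(x) + self.B_q(self.A_q(x))  # (B, T, H)
        k = self.key(x)  # (B, T, H)
        v = self.value(x) + self.B_v(self.A_v(x))  # (B, T, H)
        # compute attention scores, scaled by 1/sqrt(head size)
        scores = q @ k.transpose(1, 2)  # (B, T, H) @ (B, H, T) -> (B, T, T)
        scores = scores * k.shape[-1] ** -0.5
        # in an encoder head all tokens can attend to all tokens, so no masking is done
        weights = F.softmax(scores, dim=-1)  # (B, T, T)
        weights = self.dropout(weights)
        # weighted aggregation of the values
        return weights @ v  # (B, T, T) @ (B, T, H) -> (B, T, H)
```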
```python
import torch
import torch.nn as nn

input_dim = 256
T = 20
r = 8  # LoRA rank
x = torch.rand(1, T, input_dim)  # batch 1, length 20, 256 dim

key = nn.Linear(input_dim, input_dim, bias=False)  # frozen pre-trained W0 for the key
# nn.Linear adds a bias by default, so A_k and B_k below carry one; the paper's A and B
# have none (bias=False would give 2048 + 2048 = 4096 parameters, 6.25 percent)
A_k = nn.Linear(input_dim, r)  # downscale from 256 to 8
B_k = nn.Linear(r, input_dim)  # upscale from 8 to 256

delta_k = B_k(A_k(x))  # B_k A_k x
k = key(x)
k_star = k + delta_k  # W0 x + B A x

A_k_size = sum(p.numel() for p in A_k.parameters())
B_k_size = sum(p.numel() for p in B_k.parameters())
k_size = sum(p.numel() for p in key.parameters())
print("A_k", A_k_size)
print("B_k", B_k_size)
print("Original k matrix", k_size)
print("trainable parameters per layer:", B_k_size + A_k_size, "or", 100 * (B_k_size + A_k_size) / k_size, "percent")
```
