4. 注意机制
Reading time: 17 minutes
神经网络中的注意机制和自注意力
注意机制允许神经网络在生成输出的每个部分时专注于输入的特定部分。它们为不同的输入分配不同的权重,帮助模型决定哪些输入与当前任务最相关。这在机器翻译等任务中至关重要,因为理解整个句子的上下文对于准确翻译是必要的。
tip
这一阶段的目标非常简单:应用一些注意机制。这些将是许多重复的层,将捕捉词汇中一个词与当前用于训练LLM的句子中其邻居的关系。
为此使用了很多层,因此将有很多可训练的参数来捕捉这些信息。
理解注意机制
在传统的序列到序列模型中,模型将输入序列编码为固定大小的上下文向量。然而,这种方法在处理长句子时会遇到困难,因为固定大小的上下文向量可能无法捕捉所有必要的信息。注意机制通过允许模型在生成每个输出标记时考虑所有输入标记来解决这一限制。
示例:机器翻译
考虑将德语句子 "Kannst du mir helfen diesen Satz zu übersetzen" 翻译成英语。逐字翻译不会产生语法正确的英语句子,因为不同语言之间的语法结构存在差异。注意机制使模型在生成输出句子的每个单词时能够专注于输入句子的相关部分,从而导致更准确和连贯的翻译。
自注意力介绍
自注意力或内部注意力是一种机制,其中注意力在单个序列内应用,以计算该序列的表示。它允许序列中的每个标记关注所有其他标记,帮助模型捕捉标记之间的依赖关系,而不管它们在序列中的距离。
关键概念
- 标记:输入序列的单个元素(例如,句子中的单词)。
- 嵌入:标记的向量表示,捕捉语义信息。
- 注意权重:确定每个标记相对于其他标记重要性的值。
计算注意权重:逐步示例
让我们考虑句子 "Hello shiny sun!" 并用3维嵌入表示每个单词:
- Hello:
[0.34, 0.22, 0.54]
- shiny:
[0.53, 0.34, 0.98]
- sun:
[0.29, 0.54, 0.93]
我们的目标是使用自注意力计算**"shiny"的上下文向量**。
步骤1:计算注意分数
tip
只需将查询的每个维度值与每个标记的相关维度相乘并加上结果。你将为每对标记获得1个值。
对于句子中的每个单词,通过计算它们嵌入的点积来计算与 "shiny" 的注意分数。
"Hello" 和 "shiny" 之间的注意分数
"shiny" 和 "shiny" 之间的注意分数
"sun" 和 "shiny" 之间的注意分数
步骤2:归一化注意分数以获得注意权重
tip
不要迷失在数学术语中,这个函数的目标很简单,归一化所有权重,使它们的总和为1。
此外,softmax 函数被使用,因为它通过指数部分强调差异,使得更容易检测有用的值。
应用softmax函数将注意分数转换为总和为1的注意权重。
计算指数:
计算总和:
计算注意权重:
步骤3:计算上下文向量
tip
只需获取每个注意权重并将其乘以相关标记的维度,然后将所有维度相加以获得一个向量(上下文向量)
上下文向量是通过使用注意权重对所有单词的嵌入进行加权求和计算得出的。
计算每个分量:
- "Hello" 的加权嵌入:
- "shiny" 的加权嵌入:
- "sun" 的加权嵌入:
加权嵌入的总和:
context vector=[0.0779+0.2156+0.1057, 0.0504+0.1382+0.1972, 0.1237+0.3983+0.3390]=[0.3992,0.3858,0.8610]
这个上下文向量表示了“shiny”的丰富嵌入,结合了句子中所有单词的信息。
过程总结
- 计算注意分数:使用目标单词的嵌入与序列中所有单词的嵌入之间的点积。
- 归一化分数以获得注意权重:对注意分数应用softmax函数以获得总和为1的权重。
- 计算上下文向量:将每个单词的嵌入乘以其注意权重并求和结果。
带可训练权重的自注意力
在实践中,自注意力机制使用可训练权重来学习查询、键和值的最佳表示。这涉及引入三个权重矩阵:
查询是像以前一样使用的数据,而键和值矩阵只是随机可训练的矩阵。
步骤1:计算查询、键和值
每个标记将通过将其维度值与定义的矩阵相乘来拥有自己的查询、键和值矩阵:
这些矩阵将原始嵌入转换为适合计算注意力的新空间。
示例
假设:
- 输入维度
din=3
(嵌入大小) - 输出维度
dout=2
(查询、键和值的期望维度)
初始化权重矩阵:
import torch.nn as nn
d_in = 3
d_out = 2
W_query = nn.Parameter(torch.rand(d_in, d_out))
W_key = nn.Parameter(torch.rand(d_in, d_out))
W_value = nn.Parameter(torch.rand(d_in, d_out))
计算查询、键和值:
queries = torch.matmul(inputs, W_query)
keys = torch.matmul(inputs, W_key)
values = torch.matmul(inputs, W_value)
第2步:计算缩放点积注意力
计算注意力分数
与之前的示例类似,但这次我们使用的是令牌的键矩阵(已经使用维度计算得出),而不是令牌的维度值。因此,对于每个查询 qi
和键 kj
:
缩放分数
为了防止点积变得过大,通过键维度 dk
的平方根来缩放它们:
tip
分数除以维度的平方根是因为点积可能变得非常大,这有助于调节它们。
应用Softmax以获得注意力权重: 如最初示例中所示,规范化所有值,使它们的总和为1。
第3步:计算上下文向量
与最初的示例一样,只需将所有值矩阵相加,每个值乘以其注意力权重:
代码示例
从 https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb 获取一个示例,您可以查看这个实现我们所讨论的自注意力功能的类:
import torch
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your (x^1)
[0.55, 0.87, 0.66], # journey (x^2)
[0.57, 0.85, 0.64], # starts (x^3)
[0.22, 0.58, 0.33], # with (x^4)
[0.77, 0.25, 0.10], # one (x^5)
[0.05, 0.80, 0.55]] # step (x^6)
)
import torch.nn as nn
class SelfAttention_v2(nn.Module):
def __init__(self, d_in, d_out, qkv_bias=False):
super().__init__()
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
def forward(self, x):
keys = self.W_key(x)
queries = self.W_query(x)
values = self.W_value(x)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
context_vec = attn_weights @ values
return context_vec
d_in=3
d_out=2
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))
note
请注意,nn.Linear
用于将所有权重标记为可训练参数,而不是用随机值初始化矩阵。
因果注意力:隐藏未来词汇
对于 LLM,我们希望模型仅考虑当前位之前出现的标记,以便 预测下一个标记。因果注意力,也称为 掩蔽注意力,通过修改注意力机制来防止访问未来标记,从而实现这一点。
应用因果注意力掩码
为了实现因果注意力,我们在 softmax 操作之前 对注意力分数应用掩码,以便剩余的分数仍然相加为 1。该掩码将未来标记的注意力分数设置为负无穷,确保在 softmax 之后,它们的注意力权重为零。
步骤
- 计算注意力分数:与之前相同。
- 应用掩码:使用一个上三角矩阵,在对角线以上填充负无穷。
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * float('-inf')
masked_scores = attention_scores + mask
- 应用 Softmax:使用掩蔽分数计算注意力权重。
attention_weights = torch.softmax(masked_scores, dim=-1)
使用 Dropout 掩蔽额外的注意力权重
为了 防止过拟合,我们可以在 softmax 操作后对注意力权重应用 dropout。Dropout 在训练期间随机将一些注意力权重置为零。
dropout = nn.Dropout(p=0.5)
attention_weights = dropout(attention_weights)
常规的 dropout 率约为 10-20%。
代码示例
来自 https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb 的代码示例:
import torch
import torch.nn as nn
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your (x^1)
[0.55, 0.87, 0.66], # journey (x^2)
[0.57, 0.85, 0.64], # starts (x^3)
[0.22, 0.58, 0.33], # with (x^4)
[0.77, 0.25, 0.10], # one (x^5)
[0.05, 0.80, 0.55]] # step (x^6)
)
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)
class CausalAttention(nn.Module):
def __init__(self, d_in, d_out, context_length,
dropout, qkv_bias=False):
super().__init__()
self.d_out = d_out
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New
def forward(self, x):
b, num_tokens, d_in = x.shape
# b is the num of batches
# num_tokens is the number of tokens per batch
# d_in is the dimensions er token
keys = self.W_key(x) # This generates the keys of the tokens
queries = self.W_query(x)
values = self.W_value(x)
attn_scores = queries @ keys.transpose(1, 2) # Moves the third dimension to the second one and the second one to the third one to be able to multiply
attn_scores.masked_fill_( # New, _ ops are in-place
self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) # `:num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_size
attn_weights = torch.softmax(
attn_scores / keys.shape[-1]**0.5, dim=-1
)
attn_weights = self.dropout(attn_weights)
context_vec = attn_weights @ values
return context_vec
torch.manual_seed(123)
context_length = batch.shape[1]
d_in = 3
d_out = 2
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)
扩展单头注意力到多头注意力
多头注意力 在实际操作中是执行 多个实例 的自注意力函数,每个实例都有 自己的权重,因此计算出不同的最终向量。
代码示例
可以重用之前的代码,只需添加一个包装器来多次启动它,但这是一个更优化的版本,来自 https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb,它同时处理所有头(减少了昂贵的 for 循环数量)。正如您在代码中看到的,每个标记的维度根据头的数量被划分为不同的维度。这样,如果标记有 8 个维度,而我们想使用 3 个头,维度将被划分为 2 个 4 维的数组,每个头将使用其中一个:
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert (d_out % num_heads == 0), \
"d_out must be divisible by num_heads"
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout)
self.register_buffer(
"mask",
torch.triu(torch.ones(context_length, context_length),
diagonal=1)
)
def forward(self, x):
b, num_tokens, d_in = x.shape
# b is the num of batches
# num_tokens is the number of tokens per batch
# d_in is the dimensions er token
keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
queries = self.W_query(x)
values = self.W_value(x)
# We implicitly split the matrix by adding a `num_heads` dimension
# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)
# Compute scaled dot-product attention (aka self-attention) with a causal mask
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
# Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
# Use the mask to fill attention scores
attn_scores.masked_fill_(mask_bool, -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)
# Shape: (b, num_tokens, num_heads, head_dim)
context_vec = (attn_weights @ values).transpose(1, 2)
# Combine heads, where self.d_out = self.num_heads * self.head_dim
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec) # optional projection
return context_vec
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)
对于另一个紧凑且高效的实现,您可以在 PyTorch 中使用 torch.nn.MultiheadAttention
类。
tip
ChatGPT 关于为什么将令牌的维度分配给各个头而不是让每个头检查所有令牌的所有维度的简短回答:
尽管允许每个头处理所有嵌入维度似乎是有利的,因为每个头将能够访问完整的信息,但标准做法是 将嵌入维度分配给各个头。这种方法在计算效率与模型性能之间取得了平衡,并鼓励每个头学习多样化的表示。因此,通常更倾向于分割嵌入维度,而不是让每个头检查所有维度。