4. Attention Mechanisms

Reading time: 18 minutes

tip

AWSハッキングを学び、実践する:HackTricks Training AWS Red Team Expert (ARTE)
GCPハッキングを学び、実践する:HackTricks Training GCP Red Team Expert (GRTE) Azureハッキングを学び、実践する:HackTricks Training Azure Red Team Expert (AzRTE)

HackTricksをサポートする

Attention Mechanisms and Self-Attention in Neural Networks

Attention mechanisms allow neural networks to focus on specific parts of the input when generating each part of the output. They assign different weights to different inputs, helping the model decide which inputs are most relevant to the task at hand. This is crucial in tasks like machine translation, where understanding the context of the entire sentence is necessary for accurate translation.

tip

この第4段階の目標は非常にシンプルです: いくつかの注意メカニズムを適用することです。これらは、語彙内の単語と現在の文の隣接語との関係を捉えるための多くの 繰り返し層になります。
これには多くの層が使用されるため、多くの学習可能なパラメータがこの情報を捉えることになります。

Understanding Attention Mechanisms

In traditional sequence-to-sequence models used for language translation, the model encodes an input sequence into a fixed-size context vector. However, this approach struggles with long sentences because the fixed-size context vector may not capture all necessary information. Attention mechanisms address this limitation by allowing the model to consider all input tokens when generating each output token.

Example: Machine Translation

Consider translating the German sentence "Kannst du mir helfen diesen Satz zu übersetzen" into English. A word-by-word translation would not produce a grammatically correct English sentence due to differences in grammatical structures between languages. An attention mechanism enables the model to focus on relevant parts of the input sentence when generating each word of the output sentence, leading to a more accurate and coherent translation.

Introduction to Self-Attention

Self-attention, or intra-attention, is a mechanism where attention is applied within a single sequence to compute a representation of that sequence. It allows each token in the sequence to attend to all other tokens, helping the model capture dependencies between tokens regardless of their distance in the sequence.

Key Concepts

  • Tokens: 入力シーケンスの個々の要素(例: 文中の単語)。
  • Embeddings: トークンのベクトル表現で、意味情報を捉えます。
  • Attention Weights: 他のトークンに対する各トークンの重要性を決定する値。

Calculating Attention Weights: A Step-by-Step Example

Let's consider the sentence "Hello shiny sun!" and represent each word with a 3-dimensional embedding:

  • Hello: [0.34, 0.22, 0.54]
  • shiny: [0.53, 0.34, 0.98]
  • sun: [0.29, 0.54, 0.93]

Our goal is to compute the context vector for the word "shiny" using self-attention.

Step 1: Compute Attention Scores

tip

各次元のクエリの値を各トークンの関連する値と掛け算し、結果を加算します。トークンのペアごとに1つの値が得られます。

For each word in the sentence, compute the attention score with respect to "shiny" by calculating the dot product of their embeddings.

Attention Score between "Hello" and "shiny"

Attention Score between "shiny" and "shiny"

Attention Score between "sun" and "shiny"

Step 2: Normalize Attention Scores to Obtain Attention Weights

tip

数学用語に迷わないでください。この関数の目標はシンプルです。すべての重みを正規化して合計が1になるようにします
さらに、softmax関数が使用されるのは、指数部分によって違いを強調し、有用な値を検出しやすくするためです。

Apply the softmax function to the attention scores to convert them into attention weights that sum to 1.

Calculating the exponentials:

Calculating the sum:

Calculating attention weights:

Step 3: Compute the Context Vector

tip

各注意重みを関連するトークンの次元に掛け算し、すべての次元を合計して1つのベクトル(コンテキストベクトル)を得ます。

The context vector is computed as the weighted sum of the embeddings of all words, using the attention weights.

Calculating each component:

  • Weighted Embedding of "Hello":
  • Weighted Embedding of "shiny":
  • Weighted Embedding of "sun":

Summing the weighted embeddings:

context vector=[0.0779+0.2156+0.1057, 0.0504+0.1382+0.1972, 0.1237+0.3983+0.3390]=[0.3992,0.3858,0.8610]

このコンテキストベクトルは、文中のすべての単語からの情報を取り入れた「shiny」という単語の強化された埋め込みを表します。

Summary of the Process

  1. Compute Attention Scores: Use the dot product between the embedding of the target word and the embeddings of all words in the sequence.
  2. Normalize Scores to Get Attention Weights: Apply the softmax function to the attention scores to obtain weights that sum to 1.
  3. Compute Context Vector: Multiply each word's embedding by its attention weight and sum the results.

Self-Attention with Trainable Weights

In practice, self-attention mechanisms use trainable weights to learn the best representations for queries, keys, and values. This involves introducing three weight matrices:

The query is the data to use like before, while the keys and values matrices are just random-trainable matrices.

Step 1: Compute Queries, Keys, and Values

Each token will have its own query, key and value matrix by multiplying its dimension values by the defined matrices:

These matrices transform the original embeddings into a new space suitable for computing attention.

Example

Assuming:

  • Input dimension din=3 (embedding size)
  • Output dimension dout=2 (desired dimension for queries, keys, and values)

Initialize the weight matrices:

python
import torch.nn as nn

d_in = 3
d_out = 2

W_query = nn.Parameter(torch.rand(d_in, d_out))
W_key = nn.Parameter(torch.rand(d_in, d_out))
W_value = nn.Parameter(torch.rand(d_in, d_out))

クエリ、キー、値を計算する:

python
queries = torch.matmul(inputs, W_query)
keys = torch.matmul(inputs, W_key)
values = torch.matmul(inputs, W_value)

ステップ 2: スケールドドットプロダクトアテンションの計算

アテンションスコアの計算

以前の例と似ていますが、今回はトークンの次元の値を使用する代わりに、トークンのキー行列を使用します(すでに次元を使用して計算されています)。したがって、各クエリ qi​ とキー kj​ に対して:

スコアのスケーリング

ドット積が大きくなりすぎないように、キー次元 dk​ の平方根でスケーリングします:

tip

スコアは次元の平方根で割られます。なぜなら、ドット積が非常に大きくなる可能性があり、これがそれを調整するのに役立つからです。

ソフトマックスを適用してアテンションウェイトを取得: 初期の例と同様に、すべての値を正規化して合計が1になるようにします。

ステップ 3: コンテキストベクトルの計算

初期の例と同様に、すべての値行列をそのアテンションウェイトで掛けて合計します:

コード例

https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb からの例を取り上げると、私たちが話した自己注意機能を実装するこのクラスを確認できます:

python
import torch

inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your     (x^1)
[0.55, 0.87, 0.66], # journey  (x^2)
[0.57, 0.85, 0.64], # starts   (x^3)
[0.22, 0.58, 0.33], # with     (x^4)
[0.77, 0.25, 0.10], # one      (x^5)
[0.05, 0.80, 0.55]] # step     (x^6)
)

import torch.nn as nn
class SelfAttention_v2(nn.Module):

def __init__(self, d_in, d_out, qkv_bias=False):
super().__init__()
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

def forward(self, x):
keys = self.W_key(x)
queries = self.W_query(x)
values = self.W_value(x)

attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

context_vec = attn_weights @ values
return context_vec

d_in=3
d_out=2
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tip

注意:行列をランダムな値で初期化する代わりに、nn.Linearを使用してすべての重みをトレーニングするパラメータとしてマークします。

因果注意:未来の単語を隠す

LLMでは、モデルが現在の位置の前に出現するトークンのみを考慮して次のトークンを予測することを望みます。因果注意、またはマスク付き注意は、注意メカニズムを修正して未来のトークンへのアクセスを防ぐことによってこれを実現します。

因果注意マスクの適用

因果注意を実装するために、ソフトマックス操作の前に注意スコアにマスクを適用します。これにより、残りのスコアは合計1になります。このマスクは、未来のトークンの注意スコアを負の無限大に設定し、ソフトマックスの後にその注意重みがゼロになることを保証します。

手順

  1. 注意スコアの計算:以前と同様。
  2. マスクの適用:対角線の上に負の無限大で満たされた上三角行列を使用します。
python
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * float('-inf')
masked_scores = attention_scores + mask
  1. ソフトマックスの適用:マスクされたスコアを使用して注意重みを計算します。
python
attention_weights = torch.softmax(masked_scores, dim=-1)

ドロップアウトによる追加の注意重みのマスキング

過学習を防ぐために、ソフトマックス操作の後に注意重みにドロップアウトを適用できます。ドロップアウトは、トレーニング中に注意重みの一部をランダムにゼロにします

python
dropout = nn.Dropout(p=0.5)
attention_weights = dropout(attention_weights)

通常のドロップアウトは約10-20%です。

コード例

コード例はhttps://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynbからです:

python
import torch
import torch.nn as nn

inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your     (x^1)
[0.55, 0.87, 0.66], # journey  (x^2)
[0.57, 0.85, 0.64], # starts   (x^3)
[0.22, 0.58, 0.33], # with     (x^4)
[0.77, 0.25, 0.10], # one      (x^5)
[0.05, 0.80, 0.55]] # step     (x^6)
)

batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

class CausalAttention(nn.Module):

def __init__(self, d_in, d_out, context_length,
dropout, qkv_bias=False):
super().__init__()
self.d_out = d_out
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New

def forward(self, x):
b, num_tokens, d_in = x.shape
# b is the num of batches
# num_tokens is the number of tokens per batch
# d_in is the dimensions er token

keys = self.W_key(x) # This generates the keys of the tokens
queries = self.W_query(x)
values = self.W_value(x)

attn_scores = queries @ keys.transpose(1, 2) # Moves the third dimension to the second one and the second one to the third one to be able to multiply
attn_scores.masked_fill_(  # New, _ ops are in-place
self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)  # `:num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_size
attn_weights = torch.softmax(
attn_scores / keys.shape[-1]**0.5, dim=-1
)
attn_weights = self.dropout(attn_weights)

context_vec = attn_weights @ values
return context_vec

torch.manual_seed(123)

context_length = batch.shape[1]
d_in = 3
d_out = 2
ca = CausalAttention(d_in, d_out, context_length, 0.0)

context_vecs = ca(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

シングルヘッドアテンションをマルチヘッドアテンションに拡張する

マルチヘッドアテンションは、実際には複数のインスタンスの自己注意機能を実行し、それぞれが独自の重みを持つことで、異なる最終ベクトルが計算されることを意味します。

コード例

前のコードを再利用し、ラッパーを追加して何度も実行することも可能ですが、これはすべてのヘッドを同時に処理する最適化されたバージョンです(高価なforループの数を減らします)。コードに示されているように、各トークンの次元はヘッドの数に応じて異なる次元に分割されます。このように、トークンが8次元を持ち、3つのヘッドを使用したい場合、次元は4次元の2つの配列に分割され、各ヘッドはそのうちの1つを使用します:

python
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert (d_out % num_heads == 0), \
"d_out must be divisible by num_heads"

self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout)
self.register_buffer(
"mask",
torch.triu(torch.ones(context_length, context_length),
diagonal=1)
)

def forward(self, x):
b, num_tokens, d_in = x.shape
# b is the num of batches
# num_tokens is the number of tokens per batch
# d_in is the dimensions er token

keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
queries = self.W_query(x)
values = self.W_value(x)

# We implicitly split the matrix by adding a `num_heads` dimension
# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)

# Compute scaled dot-product attention (aka self-attention) with a causal mask
attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

# Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

# Use the mask to fill attention scores
attn_scores.masked_fill_(mask_bool, -torch.inf)

attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)

# Shape: (b, num_tokens, num_heads, head_dim)
context_vec = (attn_weights @ values).transpose(1, 2)

# Combine heads, where self.d_out = self.num_heads * self.head_dim
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec) # optional projection

return context_vec

torch.manual_seed(123)

batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

別のコンパクトで効率的な実装には、PyTorchのtorch.nn.MultiheadAttentionクラスを使用できます。

tip

ChatGPTの短い回答:なぜトークンの次元をヘッド間で分割する方が良いのか、各ヘッドがすべてのトークンのすべての次元をチェックするのではなく:

各ヘッドがすべての埋め込み次元を処理できるようにすることは、各ヘッドが完全な情報にアクセスできるため有利に思えるかもしれませんが、標準的な実践は埋め込み次元をヘッド間で分割することです。このアプローチは、計算効率とモデルのパフォーマンスのバランスを取り、各ヘッドが多様な表現を学ぶことを促します。したがって、埋め込み次元を分割することは、各ヘッドがすべての次元をチェックするよりも一般的に好まれます。

References

tip

AWSハッキングを学び、実践する:HackTricks Training AWS Red Team Expert (ARTE)
GCPハッキングを学び、実践する:HackTricks Training GCP Red Team Expert (GRTE) Azureハッキングを学び、実践する:HackTricks Training Azure Red Team Expert (AzRTE)

HackTricksをサポートする