4. 注意メカニズム
Reading time: 19 minutes
ニューラルネットワークにおける注意メカニズムと自己注意
注意メカニズムは、ニューラルネットワークが出力の各部分を生成する際に入力の特定の部分に焦点を当てることを可能にします。これにより、異なる入力に異なる重みが割り当てられ、モデルがタスクに最も関連する入力を決定するのに役立ちます。これは、文全体の文脈を理解することが正確な翻訳に必要な機械翻訳のようなタスクにおいて重要です。
tip
この第4段階の目標は非常にシンプルです:いくつかの注意メカニズムを適用することです。これらは、語彙内の単語と、LLMのトレーニングに使用される現在の文の隣接単語との関係を捉えるための多くの繰り返し層**になります。
これには多くの層が使用されるため、多くの学習可能なパラメータがこの情報を捉えることになります。
注意メカニズムの理解
言語翻訳に使用される従来のシーケンス・ツー・シーケンスモデルでは、モデルは入力シーケンスを固定サイズのコンテキストベクターにエンコードします。しかし、このアプローチは長い文に対しては苦労します。なぜなら、固定サイズのコンテキストベクターが必要な情報をすべて捉えられない可能性があるからです。注意メカニズムは、出力トークンを生成する際にモデルがすべての入力トークンを考慮できるようにすることで、この制限に対処します。
例:機械翻訳
ドイツ語の文「Kannst du mir helfen diesen Satz zu übersetzen」を英語に翻訳することを考えてみましょう。単語ごとの翻訳では、言語間の文法構造の違いにより、文法的に正しい英語の文は生成されません。注意メカニズムは、出力文の各単語を生成する際に、入力文の関連部分に焦点を当てることを可能にし、より正確で一貫した翻訳を実現します。
自己注意の紹介
自己注意、または内部注意は、注意が単一のシーケンス内で適用され、そのシーケンスの表現を計算するメカニズムです。これにより、シーケンス内の各トークンが他のすべてのトークンに注意を向けることができ、トークン間の依存関係を距離に関係なく捉えるのに役立ちます。
重要な概念
- トークン:入力シーケンスの個々の要素(例:文中の単語)。
- 埋め込み:トークンのベクトル表現で、意味情報を捉えます。
- 注意重み:他のトークンに対する各トークンの重要性を決定する値。
注意重みの計算:ステップバイステップの例
文**「Hello shiny sun!」**を考え、各単語を3次元の埋め込みで表現します:
- Hello:
[0.34, 0.22, 0.54]
- shiny:
[0.53, 0.34, 0.98]
- sun:
[0.29, 0.54, 0.93]
私たちの目標は、自己注意を使用して**「shiny」のコンテキストベクター**を計算することです。
ステップ1:注意スコアの計算
tip
各次元のクエリの値を各トークンの関連する値と掛け算し、結果を加算します。トークンのペアごとに1つの値が得られます。
文中の各単語について、「shiny」に対する注意スコアを、その埋め込みのドット積を計算することで求めます。
「Hello」と「shiny」の注意スコア
 (1) (1).png)
「shiny」と「shiny」の注意スコア
 (1) (1) (1) (1) (1) (1) (1).png)
「sun」と「shiny」の注意スコア
 (1) (1) (1) (1).png)
ステップ2:注意スコアを正規化して注意重みを取得
tip
数学用語に迷わないでください。この関数の目標はシンプルです。すべての重みを正規化して合計が1になるようにします。
さらに、softmax関数が使用されるのは、指数部分によって違いを強調し、有用な値を検出しやすくするためです。
注意スコアにsoftmax関数を適用して、合計が1になる注意重みに変換します。
 (1) (1) (1) (1).png)
指数の計算:
 (1) (1).png)
合計の計算:
 (1) (1).png)
注意重みの計算:
 (1) (1).png)
ステップ3:コンテキストベクターの計算
tip
各注意重みを関連するトークンの次元に掛け算し、すべての次元を合計して1つのベクトル(コンテキストベクター)を得ます。
コンテキストベクターは、すべての単語の埋め込みの重み付き合計として計算され、注意重みを使用します。
.png)
各成分の計算:
- 「Hello」の重み付き埋め込み:
 (1) (1).png)
- 「shiny」の重み付き埋め込み:
 (1) (1).png)
- 「sun」の重み付き埋め込み:
 (1) (1).png)
重み付き埋め込みの合計:
context vector=[0.0779+0.2156+0.1057, 0.0504+0.1382+0.1972, 0.1237+0.3983+0.3390]=[0.3992,0.3858,0.8610]
このコンテキストベクターは、「shiny」という単語の強化された埋め込みを表し、文中のすべての単語からの情報を組み込んでいます。
プロセスの要約
- 注意スコアの計算:ターゲット単語の埋め込みとシーケンス内のすべての単語の埋め込みとの間のドット積を使用します。
- スコアを正規化して注意重みを取得:注意スコアにsoftmax関数を適用して、合計が1になる重みを取得します。
- コンテキストベクターの計算:各単語の埋め込みをその注意重みで掛け算し、結果を合計します。
学習可能な重みを持つ自己注意
実際には、自己注意メカニズムは学習可能な重みを使用して、クエリ、キー、および値の最適な表現を学習します。これには、3つの重み行列を導入します:
 (1) (1).png)
クエリは以前と同様に使用するデータであり、キーと値の行列は単なるランダムな学習可能な行列です。
ステップ1:クエリ、キー、および値の計算
各トークンは、定義された行列で次元値を掛け算することによって、独自のクエリ、キー、および値の行列を持ちます:
.png)
これらの行列は、元の埋め込みを注意計算に適した新しい空間に変換します。
例
次のように仮定します:
- 入力次元
din=3
(埋め込みサイズ) - 出力次元
dout=2
(クエリ、キー、および値のための希望する次元)
重み行列を初期化します:
import torch.nn as nn
d_in = 3
d_out = 2
W_query = nn.Parameter(torch.rand(d_in, d_out))
W_key = nn.Parameter(torch.rand(d_in, d_out))
W_value = nn.Parameter(torch.rand(d_in, d_out))
クエリ、キー、値を計算する:
queries = torch.matmul(inputs, W_query)
keys = torch.matmul(inputs, W_key)
values = torch.matmul(inputs, W_value)
ステップ 2: スケールドドットプロダクトアテンションの計算
アテンションスコアの計算
以前の例と似ていますが、今回はトークンの次元の値を使用する代わりに、トークンのキー行列を使用します(すでに次元を使用して計算されています)。したがって、各クエリ qi
とキー kj
に対して:
.png)
スコアのスケーリング
ドット積が大きくなりすぎないように、キー次元 dk
の平方根でスケーリングします:
.png)
tip
スコアは次元の平方根で割られます。なぜなら、ドット積が非常に大きくなる可能性があり、これがそれらを調整するのに役立つからです。
ソフトマックスを適用してアテンションウェイトを取得: 初期の例と同様に、すべての値を正規化して合計が1になるようにします。
.png)
ステップ 3: コンテキストベクトルの計算
初期の例と同様に、すべての値行列をそのアテンションウェイトで掛けて合計します:
.png)
コード例
https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb から例を取得すると、私たちが話した自己注意機能を実装するこのクラスを確認できます:
import torch
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your (x^1)
[0.55, 0.87, 0.66], # journey (x^2)
[0.57, 0.85, 0.64], # starts (x^3)
[0.22, 0.58, 0.33], # with (x^4)
[0.77, 0.25, 0.10], # one (x^5)
[0.05, 0.80, 0.55]] # step (x^6)
)
import torch.nn as nn
class SelfAttention_v2(nn.Module):
def __init__(self, d_in, d_out, qkv_bias=False):
super().__init__()
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
def forward(self, x):
keys = self.W_key(x)
queries = self.W_query(x)
values = self.W_value(x)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
context_vec = attn_weights @ values
return context_vec
d_in=3
d_out=2
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))
tip
注意:行列をランダムな値で初期化する代わりに、nn.Linear
を使用してすべての重みをトレーニングするパラメータとしてマークします。
因果注意:未来の単語を隠す
LLMでは、モデルが現在の位置の前に出現するトークンのみを考慮して次のトークンを予測することを望みます。因果注意、またはマスク付き注意は、注意メカニズムを修正して未来のトークンへのアクセスを防ぐことによってこれを実現します。
因果注意マスクの適用
因果注意を実装するために、ソフトマックス操作の前に注意スコアにマスクを適用します。これにより、残りのスコアは合計1になります。このマスクは、未来のトークンの注意スコアを負の無限大に設定し、ソフトマックスの後にその注意重みがゼロになることを保証します。
手順
- 注意スコアの計算:以前と同様。
- マスクの適用:対角線の上に負の無限大で満たされた上三角行列を使用します。
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * float('-inf')
masked_scores = attention_scores + mask
- ソフトマックスの適用:マスクされたスコアを使用して注意重みを計算します。
attention_weights = torch.softmax(masked_scores, dim=-1)
ドロップアウトによる追加の注意重みのマスキング
過学習を防ぐために、ソフトマックス操作の後に注意重みにドロップアウトを適用できます。ドロップアウトは、トレーニング中に注意重みの一部をランダムにゼロにします。
dropout = nn.Dropout(p=0.5)
attention_weights = dropout(attention_weights)
通常のドロップアウトは約10-20%です。
コード例
コード例はhttps://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynbからです:
import torch
import torch.nn as nn
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your (x^1)
[0.55, 0.87, 0.66], # journey (x^2)
[0.57, 0.85, 0.64], # starts (x^3)
[0.22, 0.58, 0.33], # with (x^4)
[0.77, 0.25, 0.10], # one (x^5)
[0.05, 0.80, 0.55]] # step (x^6)
)
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)
class CausalAttention(nn.Module):
def __init__(self, d_in, d_out, context_length,
dropout, qkv_bias=False):
super().__init__()
self.d_out = d_out
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New
def forward(self, x):
b, num_tokens, d_in = x.shape
# b is the num of batches
# num_tokens is the number of tokens per batch
# d_in is the dimensions er token
keys = self.W_key(x) # This generates the keys of the tokens
queries = self.W_query(x)
values = self.W_value(x)
attn_scores = queries @ keys.transpose(1, 2) # Moves the third dimension to the second one and the second one to the third one to be able to multiply
attn_scores.masked_fill_( # New, _ ops are in-place
self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) # `:num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_size
attn_weights = torch.softmax(
attn_scores / keys.shape[-1]**0.5, dim=-1
)
attn_weights = self.dropout(attn_weights)
context_vec = attn_weights @ values
return context_vec
torch.manual_seed(123)
context_length = batch.shape[1]
d_in = 3
d_out = 2
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)
シングルヘッドアテンションをマルチヘッドアテンションに拡張する
マルチヘッドアテンションは、実際には複数のインスタンスの自己アテンション関数を実行し、それぞれが独自の重みを持つことで、異なる最終ベクトルが計算されることを意味します。
コード例
前のコードを再利用し、ラッパーを追加して何度も実行することも可能ですが、これはすべてのヘッドを同時に処理する最適化されたバージョンです(高価なforループの数を減らします)。コードに示されているように、各トークンの次元はヘッドの数に応じて異なる次元に分割されます。このように、トークンが8次元を持ち、3つのヘッドを使用したい場合、次元は4次元の2つの配列に分割され、各ヘッドはそのうちの1つを使用します。
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert (d_out % num_heads == 0), \
"d_out must be divisible by num_heads"
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout)
self.register_buffer(
"mask",
torch.triu(torch.ones(context_length, context_length),
diagonal=1)
)
def forward(self, x):
b, num_tokens, d_in = x.shape
# b is the num of batches
# num_tokens is the number of tokens per batch
# d_in is the dimensions er token
keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
queries = self.W_query(x)
values = self.W_value(x)
# We implicitly split the matrix by adding a `num_heads` dimension
# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)
# Compute scaled dot-product attention (aka self-attention) with a causal mask
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
# Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
# Use the mask to fill attention scores
attn_scores.masked_fill_(mask_bool, -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)
# Shape: (b, num_tokens, num_heads, head_dim)
context_vec = (attn_weights @ values).transpose(1, 2)
# Combine heads, where self.d_out = self.num_heads * self.head_dim
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec) # optional projection
return context_vec
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)
別のコンパクトで効率的な実装には、PyTorchのtorch.nn.MultiheadAttention
クラスを使用できます。
tip
ChatGPTの短い回答:なぜトークンの次元をヘッド間で分割する方が良いのか、各ヘッドがすべてのトークンのすべての次元をチェックするのではなく:
各ヘッドがすべての埋め込み次元を処理できるようにすることは、各ヘッドが完全な情報にアクセスできるため有利に思えるかもしれませんが、標準的な実践は埋め込み次元をヘッド間で分割することです。このアプローチは、計算効率とモデルのパフォーマンスのバランスを取り、各ヘッドが多様な表現を学ぶことを促します。したがって、埋め込み次元を分割することは、各ヘッドがすべての次元をチェックするよりも一般的に好まれます。