4. Attention Mechanisms


Attention Mechanisms and Self-Attention in Neural Networks

Attention mechanisms allow neural networks to focus on specific parts of the input when generating each part of the output. They assign different weights to different inputs, helping the model decide which inputs are most relevant to the task at hand. This is crucial in tasks like machine translation, where understanding the context of the entire sentence is necessary for accurate translation.

Tip

์ด ๋„ค ๋ฒˆ์งธ ๋‹จ๊ณ„์˜ ๋ชฉํ‘œ๋Š” ๋งค์šฐ ๊ฐ„๋‹จํ•ฉ๋‹ˆ๋‹ค: ์ผ๋ถ€ ์ฃผ์˜ ๋ฉ”์ปค๋‹ˆ์ฆ˜์„ ์ ์šฉํ•˜์„ธ์š”. ์ด๋Š” ์–ดํœ˜์˜ ๋‹จ์–ด์™€ ํ˜„์žฌ LLM ํ›ˆ๋ จ์— ์‚ฌ์šฉ๋˜๋Š” ๋ฌธ์žฅ์—์„œ์˜ ์ด์›ƒ ๊ฐ„์˜ ๊ด€๊ณ„๋ฅผ ํฌ์ฐฉํ•˜๋Š” ๋งŽ์€ ๋ฐ˜๋ณต ๋ ˆ์ด์–ด๊ฐ€ ๋  ๊ฒƒ์ž…๋‹ˆ๋‹ค.
์ด๋ฅผ ์œ„ํ•ด ๋งŽ์€ ๋ ˆ์ด์–ด๊ฐ€ ์‚ฌ์šฉ๋˜๋ฏ€๋กœ ๋งŽ์€ ํ•™์Šต ๊ฐ€๋Šฅํ•œ ๋งค๊ฐœ๋ณ€์ˆ˜๊ฐ€ ์ด ์ •๋ณด๋ฅผ ํฌ์ฐฉํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.

Understanding Attention Mechanisms

In traditional sequence-to-sequence models used for language translation, the model encodes an input sequence into a fixed-size context vector. However, this approach struggles with long sentences because the fixed-size context vector may not capture all necessary information. Attention mechanisms address this limitation by allowing the model to consider all input tokens when generating each output token.

Example: Machine Translation

Consider translating the German sentence โ€œKannst du mir helfen diesen Satz zu รผbersetzenโ€ into English. A word-by-word translation would not produce a grammatically correct English sentence due to differences in grammatical structures between languages. An attention mechanism enables the model to focus on relevant parts of the input sentence when generating each word of the output sentence, leading to a more accurate and coherent translation.

Introduction to Self-Attention

Self-attention, or intra-attention, is a mechanism where attention is applied within a single sequence to compute a representation of that sequence. It allows each token in the sequence to attend to all other tokens, helping the model capture dependencies between tokens regardless of their distance in the sequence.

Key Concepts

  • Tokens: Individual elements of the input sequence (e.g., words in a sentence).
  • Embeddings: Vector representations of tokens that capture semantic information.
  • Attention Weights: Values that determine the importance of each token relative to the others.

Calculating Attention Weights: A Step-by-Step Example

Letโ€™s consider the sentence โ€œHello shiny sun!โ€ and represent each word with a 3-dimensional embedding:

  • Hello: [0.34, 0.22, 0.54]
  • shiny: [0.53, 0.34, 0.98]
  • sun: [0.29, 0.54, 0.93]

Our goal is to compute the context vector for the word โ€œshinyโ€ using self-attention.

Step 1: Compute Attention Scores

Tip

๊ฐ ์ฐจ์› ๊ฐ’์„ ์ฟผ๋ฆฌ์™€ ๊ฐ ํ† ํฐ์˜ ๊ด€๋ จ ๊ฐ’๊ณผ ๊ณฑํ•˜๊ณ  ๊ฒฐ๊ณผ๋ฅผ ๋”ํ•˜์„ธ์š”. ๊ฐ ํ† ํฐ ์Œ์— ๋Œ€ํ•ด 1๊ฐœ์˜ ๊ฐ’์„ ์–ป์Šต๋‹ˆ๋‹ค.

For each word in the sentence, compute the attention score with respect to โ€œshinyโ€ by calculating the dot product of their embeddings.

Attention Score between “Hello” and “shiny”:

(0.34 × 0.53) + (0.22 × 0.34) + (0.54 × 0.98) = 0.1802 + 0.0748 + 0.5292 = 0.7842

Attention Score between “shiny” and “shiny”:

(0.53 × 0.53) + (0.34 × 0.34) + (0.98 × 0.98) = 0.2809 + 0.1156 + 0.9604 = 1.3569

Attention Score between “sun” and “shiny”:

(0.29 × 0.53) + (0.54 × 0.34) + (0.93 × 0.98) = 0.1537 + 0.1836 + 0.9114 = 1.2487
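
The three dot products above can be checked with a few lines of PyTorch, a minimal sketch using only the toy embeddings from this example:

```python
import torch

# Toy 3-dimensional embeddings from the example above
embeddings = torch.tensor([
    [0.34, 0.22, 0.54],  # Hello
    [0.53, 0.34, 0.98],  # shiny
    [0.29, 0.54, 0.93],  # sun
])

query = embeddings[1]  # "shiny" acts as the query

# Dot product of the query with every token embedding
attn_scores = embeddings @ query
print(attn_scores)  # tensor([0.7842, 1.3569, 1.2487])
```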

Step 2: Normalize Attention Scores to Obtain Attention Weights

Tip

์ˆ˜ํ•™์  ์šฉ์–ด์— ํœ˜๋ง๋ฆฌ์ง€ ๋งˆ์„ธ์š”, ์ด ํ•จ์ˆ˜์˜ ๋ชฉํ‘œ๋Š” ๊ฐ„๋‹จํ•ฉ๋‹ˆ๋‹ค. ๋ชจ๋“  ๊ฐ€์ค‘์น˜๋ฅผ ์ •๊ทœํ™”ํ•˜์—ฌ ์ดํ•ฉ์ด 1์ด ๋˜๋„๋ก ํ•˜์„ธ์š”.
๋˜ํ•œ, softmax ํ•จ์ˆ˜๋Š” ์ง€์ˆ˜ ๋ถ€๋ถ„์œผ๋กœ ์ธํ•ด ์ฐจ์ด๋ฅผ ๊ฐ•์กฐํ•˜๋ฏ€๋กœ ์œ ์šฉํ•œ ๊ฐ’์„ ๊ฐ์ง€ํ•˜๊ธฐ ์‰ฝ๊ฒŒ ๋งŒ๋“ญ๋‹ˆ๋‹ค.

Apply the softmax function to the attention scores to convert them into attention weights that sum to 1.

Calculating the exponentials:

e^0.7842 ≈ 2.1907,  e^1.3569 ≈ 3.8841,  e^1.2487 ≈ 3.4858

Calculating the sum:

2.1907 + 3.8841 + 3.4858 ≈ 9.5606

Calculating attention weights:

weight(Hello) = 2.1907 / 9.5606 ≈ 0.2291
weight(shiny) = 3.8841 / 9.5606 ≈ 0.4063
weight(sun) = 3.4858 / 9.5606 ≈ 0.3646
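
A minimal sketch of this normalization step, reusing the scores computed above:

```python
import torch

attn_scores = torch.tensor([0.7842, 1.3569, 1.2487])

# Softmax: exponentiate each score and divide by the sum of exponentials,
# so the resulting weights add up to 1
attn_weights = torch.softmax(attn_scores, dim=0)
print(attn_weights)  # tensor([0.2291, 0.4063, 0.3646])
print(attn_weights.sum())  # 1.0 (up to float rounding)
```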

Step 3: Compute the Context Vector

Tip

๊ฐ ์ฃผ์˜ ๊ฐ€์ค‘์น˜๋ฅผ ๊ฐ€์ ธ์™€ ๊ด€๋ จ๋œ ํ† ํฐ ์ฐจ์›์— ๊ณฑํ•œ ๋‹ค์Œ ๋ชจ๋“  ์ฐจ์›์„ ๋”ํ•˜์—ฌ ๋‹จ ํ•˜๋‚˜์˜ ๋ฒกํ„ฐ(์ปจํ…์ŠคํŠธ ๋ฒกํ„ฐ)๋ฅผ ์–ป์œผ์„ธ์š”.

The context vector is computed as the weighted sum of the embeddings of all words, using the attention weights.

Calculating each component:

  • Weighted Embedding of “Hello”: 0.2291 × [0.34, 0.22, 0.54] ≈ [0.0779, 0.0504, 0.1237]
  • Weighted Embedding of “shiny”: 0.4063 × [0.53, 0.34, 0.98] ≈ [0.2156, 0.1382, 0.3983]
  • Weighted Embedding of “sun”: 0.3646 × [0.29, 0.54, 0.93] ≈ [0.1057, 0.1972, 0.3390]

Summing the weighted embeddings:

context vector = [0.0779 + 0.2156 + 0.1057, 0.0504 + 0.1382 + 0.1972, 0.1237 + 0.3983 + 0.3390] ≈ [0.3992, 0.3858, 0.8610]

์ด ์ปจํ…์ŠคํŠธ ๋ฒกํ„ฐ๋Š” โ€œshinyโ€œ๋ผ๋Š” ๋‹จ์–ด์— ๋Œ€ํ•œ ํ’๋ถ€ํ•œ ์ž„๋ฒ ๋”ฉ์„ ๋‚˜ํƒ€๋‚ด๋ฉฐ, ๋ฌธ์žฅ์˜ ๋ชจ๋“  ๋‹จ์–ด๋กœ๋ถ€ํ„ฐ ์ •๋ณด๋ฅผ ํ†ตํ•ฉํ•ฉ๋‹ˆ๋‹ค.

Summary of the Process

  1. Compute Attention Scores: Use the dot product between the embedding of the target word and the embeddings of all words in the sequence.
  2. Normalize Scores to Get Attention Weights: Apply the softmax function to the attention scores to obtain weights that sum to 1.
  3. Compute Context Vector: Multiply each wordโ€™s embedding by its attention weight and sum the results.

Self-Attention with Trainable Weights

In practice, self-attention mechanisms use trainable weights to learn the best representations for queries, keys, and values. This involves introducing three weight matrices: W_query, W_key, and W_value.

The query plays the same role as the input data did before, while the key and value matrices are simply randomly initialized, trainable matrices.

Step 1: Compute Queries, Keys, and Values

Each token gets its own query, key, and value vector by multiplying its embedding values by the defined matrices:

These matrices transform the original embeddings into a new space suitable for computing attention.

Example

Assuming:

  • Input dimension d_in = 3 (embedding size)
  • Output dimension d_out = 2 (desired dimension for queries, keys, and values)

Initialize the weight matrices:

import torch
import torch.nn as nn

d_in = 3
d_out = 2

W_query = nn.Parameter(torch.rand(d_in, d_out))
W_key = nn.Parameter(torch.rand(d_in, d_out))
W_value = nn.Parameter(torch.rand(d_in, d_out))

์ฟผ๋ฆฌ, ํ‚ค, ๊ฐ’ ๊ณ„์‚ฐ:

queries = torch.matmul(inputs, W_query)
keys = torch.matmul(inputs, W_key)
values = torch.matmul(inputs, W_value)

Step 2: Compute Scaled Dot-Product Attention

Compute Attention Scores

์ด์ „ ์˜ˆ์ œ์™€ ์œ ์‚ฌํ•˜์ง€๋งŒ, ์ด๋ฒˆ์—๋Š” ํ† ํฐ์˜ ์ฐจ์› ๊ฐ’ ๋Œ€์‹  ํ† ํฐ์˜ ํ‚ค ํ–‰๋ ฌ์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค(์ด๋ฏธ ์ฐจ์›์„ ์‚ฌ์šฉํ•˜์—ฌ ๊ณ„์‚ฐ๋จ). ๋”ฐ๋ผ์„œ ๊ฐ ์ฟผ๋ฆฌ qiโ€‹์™€ ํ‚ค kjโ€‹์— ๋Œ€ํ•ด:

Scale the Scores

๋‚ด์ ์ด ๋„ˆ๋ฌด ์ปค์ง€๋Š” ๊ฒƒ์„ ๋ฐฉ์ง€ํ•˜๊ธฐ ์œ„ํ•ด, ํ‚ค ์ฐจ์› dkโ€‹์˜ ์ œ๊ณฑ๊ทผ์œผ๋กœ ์ ์ˆ˜๋ฅผ ์กฐ์ •ํ•ฉ๋‹ˆ๋‹ค:

Tip

์ ์ˆ˜๋Š” ์ฐจ์›์˜ ์ œ๊ณฑ๊ทผ์œผ๋กœ ๋‚˜๋ˆ„์–ด์ง€๋Š”๋ฐ, ์ด๋Š” ๋‚ด์ ์ด ๋งค์šฐ ์ปค์งˆ ์ˆ˜ ์žˆ๊ธฐ ๋•Œ๋ฌธ์— ์ด๋ฅผ ์กฐ์ ˆํ•˜๋Š” ๋ฐ ๋„์›€์ด ๋ฉ๋‹ˆ๋‹ค.

Apply Softmax to Obtain Attention Weights: As in the initial example, normalize all the values so they sum to 1.

Step 3: Compute Context Vectors

์ดˆ๊ธฐ ์˜ˆ์ œ์™€ ๊ฐ™์ด, ๊ฐ ๊ฐ’์„ ์ฃผ์˜ ๊ฐ€์ค‘์น˜๋กœ ๊ณฑํ•˜์—ฌ ๋ชจ๋“  ๊ฐ’ ํ–‰๋ ฌ์„ ํ•ฉ์‚ฐํ•ฉ๋‹ˆ๋‹ค:

Code Example

You can grab an example from https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb and check this class, which implements the self-attention functionality we discussed:

import torch
import torch.nn as nn

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)

class SelfAttention_v2(nn.Module):

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

d_in = 3
d_out = 2
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

Tip

๋งคํŠธ๋ฆญ์Šค๋ฅผ ์ž„์˜์˜ ๊ฐ’์œผ๋กœ ์ดˆ๊ธฐํ™”ํ•˜๋Š” ๋Œ€์‹ , nn.Linear๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋“  ๊ฐ€์ค‘์น˜๋ฅผ ํ•™์Šตํ•  ๋งค๊ฐœ๋ณ€์ˆ˜๋กœ ํ‘œ์‹œํ•ฉ๋‹ˆ๋‹ค.

์ธ๊ณผ์  ์ฃผ์˜: ๋ฏธ๋ž˜ ๋‹จ์–ด ์ˆจ๊ธฐ๊ธฐ

LLM์—์„œ๋Š” ๋ชจ๋ธ์ด ํ˜„์žฌ ์œ„์น˜ ์ด์ „์— ๋‚˜ํƒ€๋‚˜๋Š” ํ† ํฐ๋งŒ ๊ณ ๋ คํ•˜์—ฌ ๋‹ค์Œ ํ† ํฐ์„ ์˜ˆ์ธกํ•˜๋„๋ก ํ•˜๊ธฐ๋ฅผ ์›ํ•ฉ๋‹ˆ๋‹ค. ์ธ๊ณผ์  ์ฃผ์˜๋Š” ๋งˆ์Šคํ‚น๋œ ์ฃผ์˜๋ผ๊ณ ๋„ ํ•˜๋ฉฐ, ์ฃผ์˜ ๋ฉ”์ปค๋‹ˆ์ฆ˜์„ ์ˆ˜์ •ํ•˜์—ฌ ๋ฏธ๋ž˜ ํ† ํฐ์— ๋Œ€ํ•œ ์ ‘๊ทผ์„ ๋ฐฉ์ง€ํ•จ์œผ๋กœ์จ ์ด๋ฅผ ๋‹ฌ์„ฑํ•ฉ๋‹ˆ๋‹ค.

์ธ๊ณผ์  ์ฃผ์˜ ๋งˆ์Šคํฌ ์ ์šฉ

์ธ๊ณผ์  ์ฃผ์˜๋ฅผ ๊ตฌํ˜„ํ•˜๊ธฐ ์œ„ํ•ด, ์šฐ๋ฆฌ๋Š” ์†Œํ”„ํŠธ๋งฅ์Šค ์—ฐ์‚ฐ ์ด์ „์— ์ฃผ์˜ ์ ์ˆ˜์— ๋งˆ์Šคํฌ๋ฅผ ์ ์šฉํ•˜์—ฌ ๋‚˜๋จธ์ง€ ์ ์ˆ˜๊ฐ€ ์—ฌ์ „ํžˆ 1์ด ๋˜๋„๋ก ํ•ฉ๋‹ˆ๋‹ค. ์ด ๋งˆ์Šคํฌ๋Š” ๋ฏธ๋ž˜ ํ† ํฐ์˜ ์ฃผ์˜ ์ ์ˆ˜๋ฅผ ์Œ์˜ ๋ฌดํ•œ๋Œ€๋กœ ์„ค์ •ํ•˜์—ฌ ์†Œํ”„ํŠธ๋งฅ์Šค ์ดํ›„์— ๊ทธ๋“ค์˜ ์ฃผ์˜ ๊ฐ€์ค‘์น˜๊ฐ€ 0์ด ๋˜๋„๋ก ๋ณด์žฅํ•ฉ๋‹ˆ๋‹ค.

๋‹จ๊ณ„

  1. ์ฃผ์˜ ์ ์ˆ˜ ๊ณ„์‚ฐ: ์ด์ „๊ณผ ๋™์ผํ•ฉ๋‹ˆ๋‹ค.
  2. ๋งˆ์Šคํฌ ์ ์šฉ: ๋Œ€๊ฐ์„  ์œ„์— ์Œ์˜ ๋ฌดํ•œ๋Œ€๋กœ ์ฑ„์›Œ์ง„ ์ƒ์‚ผ๊ฐ ํ–‰๋ ฌ์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * float('-inf')
masked_scores = attention_scores + mask
  1. ์†Œํ”„ํŠธ๋งฅ์Šค ์ ์šฉ: ๋งˆ์Šคํ‚น๋œ ์ ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ฃผ์˜ ๊ฐ€์ค‘์น˜๋ฅผ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค.
attention_weights = torch.softmax(masked_scores, dim=-1)

๋“œ๋กญ์•„์›ƒ์œผ๋กœ ์ถ”๊ฐ€ ์ฃผ์˜ ๊ฐ€์ค‘์น˜ ๋งˆ์Šคํ‚น

๊ณผ์ ํ•ฉ์„ ๋ฐฉ์ง€ํ•˜๊ธฐ ์œ„ํ•ด, ์†Œํ”„ํŠธ๋งฅ์Šค ์—ฐ์‚ฐ ํ›„ ์ฃผ์˜ ๊ฐ€์ค‘์น˜์— ๋“œ๋กญ์•„์›ƒ์„ ์ ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋“œ๋กญ์•„์›ƒ์€ ํ•™์Šต ์ค‘์— ์ผ๋ถ€ ์ฃผ์˜ ๊ฐ€์ค‘์น˜๋ฅผ ๋ฌด์ž‘์œ„๋กœ 0์œผ๋กœ ๋งŒ๋“ญ๋‹ˆ๋‹ค.

dropout = nn.Dropout(p=0.5)
attention_weights = dropout(attention_weights)

์ •์ƒ์ ์ธ ๋“œ๋กญ์•„์›ƒ ๋น„์œจ์€ ์•ฝ 10-20%์ž…๋‹ˆ๋‹ค.

์ฝ”๋“œ ์˜ˆ์ œ

์ฝ”๋“œ ์˜ˆ์ œ๋Š” https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb์—์„œ ๊ฐ€์ ธ์™”์Šต๋‹ˆ๋‹ค:

import torch
import torch.nn as nn

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)

batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

class CausalAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))  # New

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # b is the number of batches
        # num_tokens is the number of tokens per batch
        # d_in is the number of dimensions per token

        keys = self.W_key(x)  # This generates the keys of the tokens
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)  # Swap the token and dimension axes so the matrices can be multiplied
        attn_scores.masked_fill_(  # New, _ ops are in-place
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)  # `:num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_length
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec

torch.manual_seed(123)

context_length = batch.shape[1]
d_in = 3
d_out = 2
ca = CausalAttention(d_in, d_out, context_length, 0.0)

context_vecs = ca(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

Single-Head Attention์„ Multi-Head Attention์œผ๋กœ ํ™•์žฅํ•˜๊ธฐ

Multi-head attention์€ ์‹ค์งˆ์ ์œผ๋กœ ์ž๊ธฐ ์ฃผ์˜ ํ•จ์ˆ˜์˜ ์—ฌ๋Ÿฌ ์ธ์Šคํ„ด์Šค๋ฅผ ์‹คํ–‰ํ•˜๋Š” ๊ฒƒ์œผ๋กœ ๊ตฌ์„ฑ๋˜๋ฉฐ, ๊ฐ ์ธ์Šคํ„ด์Šค๋Š” ์ž์‹ ์˜ ๊ฐ€์ค‘์น˜๋ฅผ ๊ฐ€์ง€๊ณ  ์žˆ์–ด ์„œ๋กœ ๋‹ค๋ฅธ ์ตœ์ข… ๋ฒกํ„ฐ๊ฐ€ ๊ณ„์‚ฐ๋ฉ๋‹ˆ๋‹ค.

์ฝ”๋“œ ์˜ˆ์ œ

์ด์ „ ์ฝ”๋“œ๋ฅผ ์žฌ์‚ฌ์šฉํ•˜๊ณ  ์—ฌ๋Ÿฌ ๋ฒˆ ์‹คํ–‰ํ•˜๋Š” ๋ž˜ํผ๋ฅผ ์ถ”๊ฐ€ํ•˜๋Š” ๊ฒƒ์ด ๊ฐ€๋Šฅํ•  ์ˆ˜ ์žˆ์ง€๋งŒ, ์ด๋Š” ๋ชจ๋“  ํ—ค๋“œ๋ฅผ ๋™์‹œ์— ์ฒ˜๋ฆฌํ•˜๋Š” https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb์—์„œ ๋” ์ตœ์ ํ™”๋œ ๋ฒ„์ „์ž…๋‹ˆ๋‹ค (๋น„์šฉ์ด ๋งŽ์ด ๋“œ๋Š” for ๋ฃจํ”„์˜ ์ˆ˜๋ฅผ ์ค„์ž„). ์ฝ”๋“œ์—์„œ ๋ณผ ์ˆ˜ ์žˆ๋“ฏ์ด ๊ฐ ํ† ํฐ์˜ ์ฐจ์›์€ ํ—ค๋“œ ์ˆ˜์— ๋”ฐ๋ผ ์„œ๋กœ ๋‹ค๋ฅธ ์ฐจ์›์œผ๋กœ ๋‚˜๋‰ฉ๋‹ˆ๋‹ค. ์ด๋ ‡๊ฒŒ ํ•˜๋ฉด ํ† ํฐ์ด 8์ฐจ์›์„ ๊ฐ€์ง€๊ณ  ์žˆ๊ณ  3๊ฐœ์˜ ํ—ค๋“œ๋ฅผ ์‚ฌ์šฉํ•˜๊ณ ์ž ํ•  ๊ฒฝ์šฐ, ์ฐจ์›์€ 4์ฐจ์›์˜ 2๊ฐœ์˜ ๋ฐฐ์—ด๋กœ ๋‚˜๋‰˜๊ณ  ๊ฐ ํ—ค๋“œ๋Š” ๊ทธ ์ค‘ ํ•˜๋‚˜๋ฅผ ์‚ฌ์šฉํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค:

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # b is the number of batches
        # num_tokens is the number of tokens per batch
        # d_in is the number of dimensions per token

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec

torch.manual_seed(123)

batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

๋‹ค๋ฅธ ๊ฐ„๊ฒฐํ•˜๊ณ  ํšจ์œจ์ ์ธ ๊ตฌํ˜„์„ ์œ„ํ•ด PyTorch์˜ torch.nn.MultiheadAttention ํด๋ž˜์Šค๋ฅผ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

Tip

ChatGPT์˜ ์งง์€ ๋‹ต๋ณ€: ์™œ ๊ฐ ํ—ค๋“œ๊ฐ€ ๋ชจ๋“  ํ† ํฐ์˜ ๋ชจ๋“  ์ฐจ์›์„ ํ™•์ธํ•˜๋Š” ๋Œ€์‹  ํ† ํฐ์˜ ์ฐจ์›์„ ํ—ค๋“œ ๊ฐ„์— ๋‚˜๋ˆ„๋Š” ๊ฒƒ์ด ๋” ๋‚˜์€์ง€์— ๋Œ€ํ•œ ์„ค๋ช…:

๊ฐ ํ—ค๋“œ๊ฐ€ ๋ชจ๋“  ์ž„๋ฒ ๋”ฉ ์ฐจ์›์„ ์ฒ˜๋ฆฌํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•˜๋Š” ๊ฒƒ์ด ์œ ๋ฆฌํ•ด ๋ณด์ผ ์ˆ˜ ์žˆ์ง€๋งŒ, ํ‘œ์ค€ ๊ด€ํ–‰์€ ์ž„๋ฒ ๋”ฉ ์ฐจ์›์„ ํ—ค๋“œ ๊ฐ„์— ๋‚˜๋ˆ„๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์ด ์ ‘๊ทผ ๋ฐฉ์‹์€ ๊ณ„์‚ฐ ํšจ์œจ์„ฑ๊ณผ ๋ชจ๋ธ ์„ฑ๋Šฅ์˜ ๊ท ํ˜•์„ ๋งž์ถ”๊ณ  ๊ฐ ํ—ค๋“œ๊ฐ€ ๋‹ค์–‘ํ•œ ํ‘œํ˜„์„ ํ•™์Šตํ•˜๋„๋ก ์žฅ๋ คํ•ฉ๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ ์ž„๋ฒ ๋”ฉ ์ฐจ์›์„ ๋‚˜๋ˆ„๋Š” ๊ฒƒ์ด ์ผ๋ฐ˜์ ์œผ๋กœ ๊ฐ ํ—ค๋“œ๊ฐ€ ๋ชจ๋“  ์ฐจ์›์„ ํ™•์ธํ•˜๋Š” ๊ฒƒ๋ณด๋‹ค ์„ ํ˜ธ๋ฉ๋‹ˆ๋‹ค.

References

  • https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb
