4. Mecanismos de Atención

Mecanismos de Atención y Autoatención en Redes Neuronales

Los mecanismos de atención permiten que las redes neuronales se enfoquen en partes específicas de la entrada al generar cada parte de la salida. Asignan diferentes pesos a diferentes entradas, ayudando al modelo a decidir cuáles entradas son más relevantes para la tarea en cuestión. Esto es crucial en tareas como la traducción automática, donde entender el contexto de toda la oración es necesario para una traducción precisa.

tip

El objetivo de esta cuarta fase es muy simple: Aplicar algunos mecanismos de atención. Estos van a ser muchas capas repetidas que van a capturar la relación de una palabra en el vocabulario con sus vecinos en la oración actual que se está utilizando para entrenar el LLM.
Se utilizan muchas capas para esto, por lo que muchos parámetros entrenables van a estar capturando esta información.

Entendiendo los Mecanismos de Atención

En los modelos tradicionales de secuencia a secuencia utilizados para la traducción de lenguajes, el modelo codifica una secuencia de entrada en un vector de contexto de tamaño fijo. Sin embargo, este enfoque tiene dificultades con oraciones largas porque el vector de contexto de tamaño fijo puede no capturar toda la información necesaria. Los mecanismos de atención abordan esta limitación al permitir que el modelo considere todos los tokens de entrada al generar cada token de salida.

Ejemplo: Traducción Automática

Considera traducir la oración en alemán "Kannst du mir helfen diesen Satz zu übersetzen" al inglés. Una traducción palabra por palabra no produciría una oración en inglés gramaticalmente correcta debido a las diferencias en las estructuras gramaticales entre los idiomas. Un mecanismo de atención permite que el modelo se enfoque en partes relevantes de la oración de entrada al generar cada palabra de la oración de salida, lo que lleva a una traducción más precisa y coherente.

Introducción a la Autoatención

La autoatención, o intra-atención, es un mecanismo donde la atención se aplica dentro de una única secuencia para calcular una representación de esa secuencia. Permite que cada token en la secuencia asista a todos los demás tokens, ayudando al modelo a capturar dependencias entre tokens sin importar su distancia en la secuencia.

Conceptos Clave

  • Tokens: Elementos individuales de la secuencia de entrada (por ejemplo, palabras en una oración).
  • Embeddings: Representaciones vectoriales de tokens, capturando información semántica.
  • Pesos de Atención: Valores que determinan la importancia de cada token en relación con los demás.

Cálculo de Pesos de Atención: Un Ejemplo Paso a Paso

Consideremos la oración "Hello shiny sun!" y representemos cada palabra con un embedding de 3 dimensiones:

  • Hello: [0.34, 0.22, 0.54]
  • shiny: [0.53, 0.34, 0.98]
  • sun: [0.29, 0.54, 0.93]

Nuestro objetivo es calcular el vector de contexto para la palabra "shiny" utilizando autoatención.

Paso 1: Calcular Puntuaciones de Atención

tip

Simplemente multiplica cada valor de dimensión de la consulta con el correspondiente de cada token y suma los resultados. Obtienes 1 valor por par de tokens.

Para cada palabra en la oración, calcula la puntuación de atención con respecto a "shiny" calculando el producto punto de sus embeddings.

Puntuación de Atención entre "Hello" y "shiny"

Puntuación de Atención entre "shiny" y "shiny"

Puntuación de Atención entre "sun" y "shiny"

Paso 2: Normalizar Puntuaciones de Atención para Obtener Pesos de Atención

tip

No te pierdas en los términos matemáticos, el objetivo de esta función es simple, normalizar todos los pesos para que suman 1 en total.

Además, se utiliza la función softmax porque acentúa las diferencias debido a la parte exponencial, facilitando la detección de valores útiles.

Aplica la función softmax a las puntuaciones de atención para convertirlas en pesos de atención que sumen 1.

Calculando los exponentes:

Calculando la suma:

Calculando los pesos de atención:

Paso 3: Calcular el Vector de Contexto

tip

Simplemente toma cada peso de atención y multiplícalo por las dimensiones del token relacionado y luego suma todas las dimensiones para obtener solo 1 vector (el vector de contexto)

El vector de contexto se calcula como la suma ponderada de los embeddings de todas las palabras, utilizando los pesos de atención.

Calculando cada componente:

  • Embedding Ponderado de "Hello":
  • Embedding Ponderado de "shiny":
  • Embedding Ponderado de "sun":

Sumando los embeddings ponderados:

vector de contexto=[0.0779+0.2156+0.1057, 0.0504+0.1382+0.1972, 0.1237+0.3983+0.3390]=[0.3992,0.3858,0.8610]

Este vector de contexto representa el embedding enriquecido para la palabra "shiny", incorporando información de todas las palabras en la oración.

Resumen del Proceso

  1. Calcular Puntuaciones de Atención: Utiliza el producto punto entre el embedding de la palabra objetivo y los embeddings de todas las palabras en la secuencia.
  2. Normalizar Puntuaciones para Obtener Pesos de Atención: Aplica la función softmax a las puntuaciones de atención para obtener pesos que sumen 1.
  3. Calcular el Vector de Contexto: Multiplica el embedding de cada palabra por su peso de atención y suma los resultados.

Autoatención con Pesos Entrenables

En la práctica, los mecanismos de autoatención utilizan pesos entrenables para aprender las mejores representaciones para consultas, claves y valores. Esto implica introducir tres matrices de peso:

La consulta es la información a utilizar como antes, mientras que las matrices de claves y valores son solo matrices aleatorias entrenables.

Paso 1: Calcular Consultas, Claves y Valores

Cada token tendrá su propia matriz de consulta, clave y valor multiplicando sus valores de dimensión por las matrices definidas:

Estas matrices transforman los embeddings originales en un nuevo espacio adecuado para calcular la atención.

Ejemplo

Suponiendo:

  • Dimensión de entrada din=3 (tamaño del embedding)
  • Dimensión de salida dout=2 (dimensión deseada para consultas, claves y valores)

Inicializa las matrices de peso:

python
import torch.nn as nn

d_in = 3
d_out = 2

W_query = nn.Parameter(torch.rand(d_in, d_out))
W_key = nn.Parameter(torch.rand(d_in, d_out))
W_value = nn.Parameter(torch.rand(d_in, d_out))

Calcular consultas, claves y valores:

python
queries = torch.matmul(inputs, W_query)
keys = torch.matmul(inputs, W_key)
values = torch.matmul(inputs, W_value)

Paso 2: Calcular la Atención de Producto Escalar

Calcular Puntuaciones de Atención

Similar al ejemplo anterior, pero esta vez, en lugar de usar los valores de las dimensiones de los tokens, usamos la matriz de claves del token (ya calculada usando las dimensiones):. Así que, para cada consulta qi​ y clave kj​:

Escalar las Puntuaciones

Para evitar que los productos punto se vuelvan demasiado grandes, escálalos por la raíz cuadrada de la dimensión de la clave dk​:

tip

La puntuación se divide por la raíz cuadrada de las dimensiones porque los productos punto pueden volverse muy grandes y esto ayuda a regularlos.

Aplicar Softmax para Obtener Pesos de Atención: Al igual que en el ejemplo inicial, normaliza todos los valores para que sumen 1.

Paso 3: Calcular Vectores de Contexto

Al igual que en el ejemplo inicial, simplemente suma todas las matrices de valores multiplicando cada una por su peso de atención:

Ejemplo de Código

Tomando un ejemplo de https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb puedes revisar esta clase que implementa la funcionalidad de auto-atención de la que hablamos:

python
import torch

inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your     (x^1)
[0.55, 0.87, 0.66], # journey  (x^2)
[0.57, 0.85, 0.64], # starts   (x^3)
[0.22, 0.58, 0.33], # with     (x^4)
[0.77, 0.25, 0.10], # one      (x^5)
[0.05, 0.80, 0.55]] # step     (x^6)
)

import torch.nn as nn
class SelfAttention_v2(nn.Module):

def __init__(self, d_in, d_out, qkv_bias=False):
super().__init__()
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

def forward(self, x):
keys = self.W_key(x)
queries = self.W_query(x)
values = self.W_value(x)

attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

context_vec = attn_weights @ values
return context_vec

d_in=3
d_out=2
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

note

Tenga en cuenta que en lugar de inicializar las matrices con valores aleatorios, se utiliza nn.Linear para marcar todos los pesos como parámetros a entrenar.

Atención Causal: Ocultando Palabras Futuras

Para los LLMs, queremos que el modelo considere solo los tokens que aparecen antes de la posición actual para predecir el siguiente token. La atención causal, también conocida como atención enmascarada, logra esto modificando el mecanismo de atención para prevenir el acceso a tokens futuros.

Aplicando una Máscara de Atención Causal

Para implementar la atención causal, aplicamos una máscara a las puntuaciones de atención antes de la operación softmax para que las restantes sumen 1. Esta máscara establece las puntuaciones de atención de los tokens futuros en negativo infinito, asegurando que después del softmax, sus pesos de atención sean cero.

Pasos

  1. Calcular Puntuaciones de Atención: Igual que antes.
  2. Aplicar Máscara: Utilizar una matriz triangular superior llena de negativo infinito por encima de la diagonal.
python
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * float('-inf')
masked_scores = attention_scores + mask
  1. Aplicar Softmax: Calcular los pesos de atención utilizando las puntuaciones enmascaradas.
python
attention_weights = torch.softmax(masked_scores, dim=-1)

Enmascarando Pesos de Atención Adicionales con Dropout

Para prevenir el sobreajuste, podemos aplicar dropout a los pesos de atención después de la operación softmax. El dropout anula aleatoriamente algunos de los pesos de atención durante el entrenamiento.

python
dropout = nn.Dropout(p=0.5)
attention_weights = dropout(attention_weights)

Un abandono regular es de aproximadamente 10-20%.

Ejemplo de Código

Ejemplo de código de https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb:

python
import torch
import torch.nn as nn

inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your     (x^1)
[0.55, 0.87, 0.66], # journey  (x^2)
[0.57, 0.85, 0.64], # starts   (x^3)
[0.22, 0.58, 0.33], # with     (x^4)
[0.77, 0.25, 0.10], # one      (x^5)
[0.05, 0.80, 0.55]] # step     (x^6)
)

batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

class CausalAttention(nn.Module):

def __init__(self, d_in, d_out, context_length,
dropout, qkv_bias=False):
super().__init__()
self.d_out = d_out
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New

def forward(self, x):
b, num_tokens, d_in = x.shape
# b is the num of batches
# num_tokens is the number of tokens per batch
# d_in is the dimensions er token

keys = self.W_key(x) # This generates the keys of the tokens
queries = self.W_query(x)
values = self.W_value(x)

attn_scores = queries @ keys.transpose(1, 2) # Moves the third dimension to the second one and the second one to the third one to be able to multiply
attn_scores.masked_fill_(  # New, _ ops are in-place
self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)  # `:num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_size
attn_weights = torch.softmax(
attn_scores / keys.shape[-1]**0.5, dim=-1
)
attn_weights = self.dropout(attn_weights)

context_vec = attn_weights @ values
return context_vec

torch.manual_seed(123)

context_length = batch.shape[1]
d_in = 3
d_out = 2
ca = CausalAttention(d_in, d_out, context_length, 0.0)

context_vecs = ca(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

Ampliando la Atención de Cabeza Única a Atención de Múltiples Cabezas

La atención de múltiples cabezas en términos prácticos consiste en ejecutar múltiples instancias de la función de autoatención, cada una con sus propios pesos, de modo que se calculen diferentes vectores finales.

Ejemplo de Código

Podría ser posible reutilizar el código anterior y simplemente agregar un envoltorio que lo ejecute varias veces, pero esta es una versión más optimizada de https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb que procesa todas las cabezas al mismo tiempo (reduciendo el número de costosos bucles for). Como puedes ver en el código, las dimensiones de cada token se dividen en diferentes dimensiones de acuerdo con el número de cabezas. De esta manera, si un token tiene 8 dimensiones y queremos usar 3 cabezas, las dimensiones se dividirán en 2 arreglos de 4 dimensiones y cada cabeza usará uno de ellos:

python
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert (d_out % num_heads == 0), \
"d_out must be divisible by num_heads"

self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout)
self.register_buffer(
"mask",
torch.triu(torch.ones(context_length, context_length),
diagonal=1)
)

def forward(self, x):
b, num_tokens, d_in = x.shape
# b is the num of batches
# num_tokens is the number of tokens per batch
# d_in is the dimensions er token

keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
queries = self.W_query(x)
values = self.W_value(x)

# We implicitly split the matrix by adding a `num_heads` dimension
# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)

# Compute scaled dot-product attention (aka self-attention) with a causal mask
attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

# Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

# Use the mask to fill attention scores
attn_scores.masked_fill_(mask_bool, -torch.inf)

attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)

# Shape: (b, num_tokens, num_heads, head_dim)
context_vec = (attn_weights @ values).transpose(1, 2)

# Combine heads, where self.d_out = self.num_heads * self.head_dim
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec) # optional projection

return context_vec

torch.manual_seed(123)

batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

Para otra implementación compacta y eficiente, podrías usar la clase torch.nn.MultiheadAttention en PyTorch.

tip

Respuesta corta de ChatGPT sobre por qué es mejor dividir las dimensiones de los tokens entre las cabezas en lugar de que cada cabeza verifique todas las dimensiones de todos los tokens:

Si bien permitir que cada cabeza procese todas las dimensiones de embedding podría parecer ventajoso porque cada cabeza tendría acceso a toda la información, la práctica estándar es dividir las dimensiones de embedding entre las cabezas. Este enfoque equilibra la eficiencia computacional con el rendimiento del modelo y fomenta que cada cabeza aprenda representaciones diversas. Por lo tanto, dividir las dimensiones de embedding se prefiere generalmente sobre permitir que cada cabeza verifique todas las dimensiones.

References