3. Token Embeddings

Reading time: 8 minutes

Token Embeddings

在对文本数据进行分词后,为像 GPT 这样的训练大型语言模型(LLMs)准备数据的下一个关键步骤是创建 token embeddings。Token embeddings 将离散的标记(如单词或子词)转换为模型可以处理和学习的连续数值向量。此解释分解了 token embeddings、它们的初始化、使用以及位置嵌入在增强模型对标记序列理解中的作用。

tip

这个第三阶段的目标非常简单:为词汇表中每个先前的标记分配一个所需维度的向量以训练模型。 词汇表中的每个单词将在 X 维空间中有一个点。
请注意,最初每个单词在空间中的位置只是“随机”初始化,这些位置是可训练的参数(将在训练过程中改进)。

此外,在 token embedding 过程中 创建了另一层嵌入,它表示(在这种情况下)单词在训练句子中的绝对位置。这样,句子中不同位置的单词将具有不同的表示(含义)。

What Are Token Embeddings?

Token Embeddings 是在连续向量空间中对标记的数值表示。词汇表中的每个标记都与一个固定维度的唯一向量相关联。这些向量捕捉了关于标记的语义和句法信息,使模型能够理解数据中的关系和模式。

  • Vocabulary Size: 模型词汇表中唯一标记的总数(例如,单词、子词)。
  • Embedding Dimensions: 每个标记向量中的数值(维度)数量。更高的维度可以捕捉更细微的信息,但需要更多的计算资源。

Example:

  • Vocabulary Size: 6 tokens [1, 2, 3, 4, 5, 6]
  • Embedding Dimensions: 3 (x, y, z)

Initializing Token Embeddings

在训练开始时,token embeddings 通常用小的随机值初始化。这些初始值在训练过程中进行调整(微调),以更好地表示标记的含义,基于训练数据。

PyTorch Example:

python
import torch

# Set a random seed for reproducibility
torch.manual_seed(123)

# Create an embedding layer with 6 tokens and 3 dimensions
embedding_layer = torch.nn.Embedding(6, 3)

# Display the initial weights (embeddings)
print(embedding_layer.weight)

输出:

lua
luaCopy codeParameter containing:
tensor([[ 0.3374, -0.1778, -0.1690],
[ 0.9178,  1.5810,  1.3010],
[ 1.2753, -0.2010, -0.1606],
[-0.4015,  0.9666, -1.1481],
[-1.1589,  0.3255, -0.6315],
[-2.8400, -0.7849, -1.4096]], requires_grad=True)

解释:

  • 每一行对应词汇表中的一个标记。
  • 每一列表示嵌入向量中的一个维度。
  • 例如,索引为 3 的标记具有嵌入向量 [-0.4015, 0.9666, -1.1481]

访问标记的嵌入:

python
# Retrieve the embedding for the token at index 3
token_index = torch.tensor([3])
print(embedding_layer(token_index))

输出:

lua
tensor([[-0.4015,  0.9666, -1.1481]], grad_fn=<EmbeddingBackward0>)

解释:

  • 索引为 3 的标记由向量 [-0.4015, 0.9666, -1.1481] 表示。
  • 这些值是可训练的参数,模型将在训练过程中调整这些参数,以更好地表示标记的上下文和含义。

标记嵌入在训练中的工作原理

在训练过程中,输入数据中的每个标记都被转换为其对应的嵌入向量。这些向量随后在模型内的各种计算中使用,例如注意力机制和神经网络层。

示例场景:

  • 批量大小: 8(同时处理的样本数量)
  • 最大序列长度: 4(每个样本的标记数量)
  • 嵌入维度: 256

数据结构:

  • 每个批次表示为形状为 (batch_size, max_length, embedding_dim) 的 3D 张量。
  • 对于我们的示例,形状将是 (8, 4, 256)

可视化:

css
cssCopy codeBatch
┌─────────────┐
│ Sample 1    │
│ ┌─────┐     │
│ │Token│ → [x₁₁, x₁₂, ..., x₁₂₅₆]
│ │ 1   │     │
│ │...  │     │
│ │Token│     │
│ │ 4   │     │
│ └─────┘     │
│ Sample 2    │
│ ┌─────┐     │
│ │Token│ → [x₂₁, x₂₂, ..., x₂₂₅₆]
│ │ 1   │     │
│ │...  │     │
│ │Token│     │
│ │ 4   │     │
│ └─────┘     │
│ ...         │
│ Sample 8    │
│ ┌─────┐     │
│ │Token│ → [x₈₁, x₈₂, ..., x₈₂₅₆]
│ │ 1   │     │
│ │...  │     │
│ │Token│     │
│ │ 4   │     │
│ └─────┘     │
└─────────────┘

解释:

  • 序列中的每个令牌由一个256维的向量表示。
  • 模型处理这些嵌入以学习语言模式并生成预测。

位置嵌入:为令牌嵌入添加上下文

虽然令牌嵌入捕捉了单个令牌的含义,但它们并不固有地编码令牌在序列中的位置。理解令牌的顺序对于语言理解至关重要。这就是位置嵌入发挥作用的地方。

为什么需要位置嵌入:

  • 令牌顺序很重要: 在句子中,意义往往依赖于单词的顺序。例如,“猫坐在垫子上”与“垫子坐在猫上”。
  • 嵌入限制: 如果没有位置信息,模型将令牌视为“词袋”,忽略它们的顺序。

位置嵌入的类型:

  1. 绝对位置嵌入:
  • 为序列中的每个位置分配一个唯一的位置向量。
  • 示例: 任何序列中的第一个令牌具有相同的位置嵌入,第二个令牌具有另一个,以此类推。
  • 使用者: OpenAI的GPT模型。
  1. 相对位置嵌入:
  • 编码令牌之间的相对距离,而不是它们的绝对位置。
  • 示例: 指示两个令牌之间的距离,无论它们在序列中的绝对位置如何。
  • 使用者: 像Transformer-XL和一些BERT变体的模型。

位置嵌入是如何集成的:

  • 相同维度: 位置嵌入与令牌嵌入具有相同的维度。
  • 相加: 它们被添加到令牌嵌入中,将令牌身份与位置信息结合,而不增加整体维度。

添加位置嵌入的示例:

假设一个令牌嵌入向量是[0.5, -0.2, 0.1],其位置嵌入向量是[0.1, 0.3, -0.1]。模型使用的组合嵌入将是:

css
Combined Embedding = Token Embedding + Positional Embedding
= [0.5 + 0.1, -0.2 + 0.3, 0.1 + (-0.1)]
= [0.6, 0.1, 0.0]

位置嵌入的好处:

  • 上下文意识: 模型可以根据位置区分标记。
  • 序列理解: 使模型能够理解语法、句法和上下文相关的含义。

代码示例

以下是来自 https://github.com/rasbt/LLMs-from-scratch/blob/main/ch02/01_main-chapter-code/ch02.ipynb 的代码示例:

python
# Use previous code...

# Create dimensional emdeddings
"""
BPE uses a vocabulary of 50257 words
Let's supose we want to use 256 dimensions (instead of the millions used by LLMs)
"""

vocab_size = 50257
output_dim = 256
token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)

## Generate the dataloader like before
max_length = 4
dataloader = create_dataloader_v1(
raw_text, batch_size=8, max_length=max_length,
stride=max_length, shuffle=False
)
data_iter = iter(dataloader)
inputs, targets = next(data_iter)

# Apply embeddings
token_embeddings = token_embedding_layer(inputs)
print(token_embeddings.shape)
torch.Size([8, 4, 256]) # 8 x 4 x 256

# Generate absolute embeddings
context_length = max_length
pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)

pos_embeddings = pos_embedding_layer(torch.arange(max_length))

input_embeddings = token_embeddings + pos_embeddings
print(input_embeddings.shape) # torch.Size([8, 4, 256])

参考文献