Transformer架构详解
Zhongjun Qiu 元婴开发者

本章我们将介绍如何搭建一个完整的 Transformer 模型。

Encoder-Decoder

在 Transformer 中,注意力机制主要被应用在两个核心组件中:Encoder(编码器)和 Decoder(解码器)。事实上,后续基于 Transformer 的各类预训练语言模型,大多都是在 Encoder 和 Decoder 结构上进行改进,例如只使用 Encoder 的 BERT,以及只使用 Decoder 的 GPT 等。

Seq2Seq 模型

Seq2Seq(Sequence-to-Sequence),即序列到序列模型,是自然语言处理中的经典任务类型。其核心思想是:输入一个自然语言序列 input = (x1, x2, x3, ..., xn) 模型输出另一个自然语言序列 output = (y1, y2, y3, ..., ym) 输入和输出的长度一般并不相同。

几乎所有 NLP 任务都可以形式化为 Seq2Seq 问题。例如:

  • 文本分类任务:输出序列长度为 1(如 m = 1)。
  • 词性标注任务:输出序列与输入等长(如 m = n)。

机器翻译是最典型的 Seq2Seq 任务。例如输入中文句子“今天天气真好”,输出为英文句子 “Today is a good day.”。 Transformer 就是一个典型的 Seq2Seq 模型,它最初正是为了解决机器翻译任务而提出的。

Seq2Seq 模型的基本思路是: 1. 编码(Encoding):将输入序列映射为内部语义表示向量; 2. 解码(Decoding):将内部语义表示转换为目标语言序列。

在 Transformer 中,Encoder 负责编码,Decoder 负责解码。输入序列经 Encoder 编码后,其输出结果会传递给每一层 Decoder,经过解码后生成目标序列。

接下来,我们将介绍 Encoder 与 Decoder 中的重要组成部分:前馈神经网络(FNN)层归一化(Layer Norm)残差连接(Residual Connection),再进一步分析它们的整体结构。

前馈神经网络(Feed Forward Neural Network)

前馈神经网络(FNN)是一种典型的全连接神经网络结构,其中每一层的神经元与上下层的所有神经元相连。 在 Transformer 的每个 Layer 中,都包含一个注意力机制模块和一个前馈神经网络模块。

下面给出一个简化的 PyTorch 实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class FNN(nn.Module):
"""
Feed Forward Neural Network Module

Args:
dim: int
特征维度(token 表示的维数)
hidden_dim: int
隐藏层维度(默认 4 * dim)
dropout: float
dropout 比例(默认 0.1)
"""
def __init__(self, dim: int, hidden_dim: int = 0, dropout: float = 0.1):
super().__init__()
if hidden_dim == 0:
hidden_dim = 4 * dim
self.fc = nn.Linear(dim, hidden_dim, bias=False)
self.proj = nn.Linear(hidden_dim, dim, bias=False)
self.dropout = nn.Dropout(dropout)
self.activation = nn.GELU()

def forward(self, x: Tensor) -> Tensor:
x = self.fc(x)
x = self.activation(x)
x = self.proj(x)
x = self.dropout(x)
return x

在 Transformer 中,前馈网络通常由两个线性层和一个激活函数(如 ReLU 或 GELU)组成,同时加入 Dropout 来防止过拟合。

FNN 层的核心作用主要有两个:

1. 提供非线性变换

这是 FNN 最根本的作用。

Transformer 的架构(无论是 Encoder 还是 Decoder)都是由一层一层的模块堆叠而成的。在 FNN 之前,注意力机制(MHA)子层的计算(尤其是值的加权求和)本质上是线性的

如果我们移除 FNN,那么整个 Transformer 模块(MHA + 残差连接)就几乎等同于一系列的线性变换。无论你堆叠多少层线性变换,其效果等同于一层线性变换,这会使得深度模型失去“深度”的意义,其表达能力将大打折扣。

FNN 通过引入 ReLUGELU 这样的非线性激活函数,打破了这种线性。它允许模型学习输入和输出之间更复杂、更非线性的关系。没有 FNN,Transformer 就无法成为一个真正的深度学习模型。

2. 特征转换与信息提炼(逐位置)

FNN 的另一个关键特性是它是 Position-wise(逐位置)的。

这意味着 FNN 会独立地对序列中的每一个 Token(在 dim 维度上)进行相同的非线性变换,它不会在 Token 之间(seq_len 维度上)混合信息。

这实现了一种巧妙的“分工”:

  1. 注意力层(MHA):负责跨序列(Token-Mixing)。它让每个 Token 去“看”并“吸收”序列中其他 Token 的信息,完成上下文的聚合。
  2. 前馈层(FNN):负责逐位置(Channel-Mixing)。它接收 MHA 聚合来的信息,然后在每个 Token 自己的“小厨房”里进行深入加工、提炼和非线性转换。

FNN 的“升维-降维”结构

FNN Linear(ReLU(Linear(x))) 的结构通常是先升维再降维(例如 dim → 4 × dim → dim):

  • 升维 (Expansion):第一个线性层将 Token 特征从 dim 维(如 512)投影到一个更高维的空间(如 2048)。这为模型提供了更广阔的“思考空间”,以便检测和学习更复杂的特征组合。
  • 非线性激活 (Activation)ReLUGELU 在这个高维空间中进行筛选和变换。
  • 降维 (Projection):第二个线性层再将这些在高维空间中被提炼过的特征“压缩”回原始的 dim 维度,以便输入到下一个 Transformer 模块中。

如果说 MHA 的任务是 “从邻居那里收集信息”,那么 FNN 的核心作用就是 “关起门来消化吸收这些信息”。它通过非线性变换来提炼每个 Token 的特征,赋予模型学习复杂函数的能力,是 Transformer 架构中不可或缺的计算和表达核心。

层归一化(Layer Normalization)

归一化是深度学习中提升训练稳定性的重要手段。其核心思想是: 让每层网络的输入保持相对稳定的分布,从而加快收敛速度、减轻梯度爆炸或消失问题。

常见的归一化方式包括:

  • Batch Normalization(批归一化)
  • Layer Normalization(层归一化)

1. Batch Norm 的局限性

Batch Norm 在一个 mini-batch 上计算均值与方差: 然后将特征标准化为:

但在 NLP 中,这种做法存在以下问题:

  • BatchNorm 在 batch 维度上统计,batch 较小时统计不稳定;
  • 不同句子、不同长度、padding token 导致 batch 内分布不一致;
  • 在变长序列任务中,Batch Norm 的统计量难以统一;
  • 每步都要保存均值方差,增加计算负担。

2. Layer Norm 的改进思路

Layer Norm 不在 batch 维度上统计,而是在每个样本内部计算所有特征的均值和方差,从而独立地标准化每个 token:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class LayerNorm(nn.Module):
"""
Layer Normalization Module

Args:
dim: int
the number of features used to represent a token
eps: float
a small value to prevent division by zero (default: 1e-5)
"""
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(dim))
self.beta = nn.Parameter(torch.zeros(dim))

def forward(self, x: Tensor) -> Tensor:
mean = x.mean(-1, keepdim=True)
variance = ((x - mean) ** 2).mean(-1, keepdim=True)
x_norm = (x - mean) / torch.sqrt(variance + self.eps)
return self.gamma * x_norm + self.beta

这种方式与 Batch Norm 的公式类似,但统计范围不同;

而 LayerNorm 只对单个 token 的特征做归一化,完全不受 batch 大小和句子长度影响。

LayerNorm 中的仿射变换(γ 与 β)

在归一化后,输入被标准化为均值 0、方差 1:

虽然这能稳定训练,但也会削弱模型的表达能力。 为此,引入了可学习的仿射变换参数y = γ ⋅  + β

其中:

  • γ(gamma) 控制缩放(scale);
  • β(beta) 控制平移(shift)。

这样,模型既能享受归一化带来的数值稳定,又能通过学习恢复最优的分布,从而保持表达能力。

nn.Linear 的区别

对比项 LayerNorm 仿射变换 nn.Linear
操作对象 每个特征独立缩放和平移 所有特征线性组合
参数形状 γ, β ∈ ℝᵈ W ∈ ℝ^{d_out×d_in}, b ∈ ℝ^{d_out}
是否混合特征维度
是否改变输出维度
主要目的 保持数值稳定 + 恢复表达能力 特征映射与维度变换

简而言之,LayerNorm 的仿射变换不是为了学习新特征,而是为了让归一化后的特征重新获得灵活的表达能力

残差连接(Residual Connection)

由于 Transformer 模型结构较深、层数较多,直接堆叠多个非线性变换会导致模型训练困难,出现梯度消失或退化问题。为了解决这一问题,Transformer 借鉴了 ResNet 的思想,在每个子层中引入残差连接。

残差连接的核心思想是:下一层的输入不仅仅是上一层的输出,还直接包含上一层的输入,从而允许底层信息直接流向高层。

这样一来,高层网络只需要学习输入与输出之间的残差(Residual),使得训练更稳定、更高效。

在 Encoder 中,残差连接的实现非常典型:每个子层(例如多头注意力层或前馈神经网络层)都在输入前进行 LayerNorm 归一化,并在输出后与原输入相加: x = x + MultiHeadSelfAttention(LayerNorm(x))

output = x + FNN(LayerNorm(x))

在代码中,可以这样实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
def forward(self, x: Tensor, mask: Tensor) -> Tensor:
"""
x: (B, S, dim)
B 是 batch size
S 是序列长度
dim 是每个 token 的特征维度
"""
# 多头注意力层(带残差)
ix = self.ln1(x)
x = x + self.mha(ix, ix, ix, mask)
# 前馈层(带残差)
x = x + self.fnn(self.ln2(x))
return x

Pre-Norm 与 Post-Norm

在引入 LayerNorm 时,不同的 Transformer 实现会有两种归一化方式:Post-NormPre-Norm

归一化方式 公式(简化) 描述
Post-Norm(原版 Transformer) xout = LayerNorm(x + SubLayer(x)) 原论文版本:先执行子层计算,再加上残差,最后进行归一化。
Pre-Norm(现代 LLM 主流) xout = x + SubLayer(LayerNorm(x)) 现代版本:先归一化,再执行子层计算,最后加上残差。

为什么现代模型更倾向于使用 Pre-Norm?

(1)稳定梯度,防止训练初期爆炸

在 Post-Norm 结构中,若子层输出较大,累积的残差会在深层网络中指数放大,导致激活值和梯度爆炸,训练容易发散。 而 Pre-Norm 结构在子层计算前就对输入进行标准化,使输入始终处于稳定分布(均值 0、方差 1)范围内,显著缓解梯度爆炸问题,提高了训练稳定性。

(2)保留恒等映射路径(Identity Mapping)

Pre-Norm 结构有助于模型更容易学习到恒等映射:当子层参数初始化为 0 时,SubLayer(LayerNorm(x)) ≈ 0,因此: xout ≈ x 这意味着在训练初期,网络能自然地保留输入信息并逐步学习到更复杂的表示,不会破坏梯度传播路径,利于深层模型收敛。

Encoder 模块实现

在完成多头注意力层、前馈网络、层归一化和残差连接之后,我们可以搭建 Transformer 的 Encoder。 一个完整的 Encoder 由 NEncoder Block 堆叠而成,每个 Block 包含:

  • 一层多头自注意力(Self-Attention)
  • 一层前馈神经网络(FNN)
  • 两个 LayerNorm 层
  • 两个残差连接
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class EncoderBlock(nn.Module):
"""
Transformer Encoder Block consisting of MHA and FNN modules.

Args:
args: ModelArgs
model arguments
"""
def __init__(self, args: ModelArgs):
super().__init__()
self.mha = MHA(args)
self.fnn = FNN(args.dim, dropout=args.dropout)
self.ln1 = LayerNorm(args.dim)
self.ln2 = LayerNorm(args.dim)

def forward(self, x: Tensor, mask: Tensor) -> Tensor:
"""
x: (B, S, dim)
B is the batch size
S is the sequence length, i.e., the number of tokens
dim is the number of features used to represent a token
"""
# MHA with residual connection and layer normalization
# first layer norm then MHA
ix = self.ln1(x)
x = x + self.mha(ix, ix, ix, mask)
# FNN with residual connection and layer normalization
# first layer norm then FNN
x = x + self.fnn(self.ln2(x))
return x

Decoder 模块实现

Decoder 的结构与 Encoder 类似,但更加复杂。它在原有的前馈层(Feed Forward)和残差连接(Residual Connection)基础上,引入了两种不同的注意力机制: 一是对目标序列自身的自注意力(Masked Self-Attention),二是对源序列的交叉注意力(Cross-Attention),从而实现编码端与解码端的信息交互。

  1. Masked Self-Attention(掩码自注意力) 这一层让 Decoder 在生成第 t 个 token 时,只能看到当前及之前的 token,而无法访问未来信息。 这种约束通过 look-ahead mask 实现,是保证 Transformer 自回归生成(auto-regressive generation)能力的关键。

  2. Cross-Attention(交叉注意力) 第二个注意力层接收 Encoder 的输出作为 keyvalue,接收 Decoder 的输出作为 query。 这样,Decoder 就能“聚焦”源序列中与当前目标 token 相关的信息,实现翻译、摘要等任务中常说的“对源句的关注”。

  3. Feed Forward Layer(前馈层) 对经过注意力机制处理后的每个 token 进行非线性变换与特征增强,提高模型的表达能力。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class DecoderBlock(nn.Module):
"""
Transformer Decoder Block consisting of two MHA modules and one FNN module.

Args:
args: ModelArgs
model arguments
"""
def __init__(self, args: ModelArgs):
super().__init__()
self.self_mha = MHA(args)
self.cross_mha = MHA(args)
self.fnn = FNN(args.dim, dropout=args.dropout)
self.ln1 = LayerNorm(args.dim)
self.ln2 = LayerNorm(args.dim)
self.ln3 = LayerNorm(args.dim)

def forward(self, x: Tensor, enc_out: Tensor, src_mask: Tensor, tgt_mask: Tensor) -> Tensor:
"""
x: (B, S, dim)
B is the batch size
S is the sequence length (number of tokens)
dim is the embedding dimension
enc_out: (B, S, dim)
Encoder output (memory)
"""
# masked self-attention (带因果掩码)
ix = self.ln1(x)
x = x + self.self_mha(ix, ix, ix, tgt_mask)

# cross-attention (与 encoder 输出交互)
ix = self.ln2(x)
x = x + self.cross_mha(ix, enc_out, enc_out, src_mask)

# feed-forward layer
x = x + self.fnn(self.ln3(x))
return x

① Self-Attention 与 Encoder 的区别

在 Encoder 中,自注意力可以访问整个输入序列,因此不需要掩码; 而在 Decoder 中,必须使用 look-ahead mask,以阻止模型看到未来 token。 举例来说,在生成句子 “I love AI and Science” 的过程中,生成 “AI” 时只能看到 “I love”,而不能看到 “AI” 自身之后的词。

这种掩码通过在注意力分数矩阵中屏蔽未来位置(即将其设为负无穷)来实现,使得 softmax 后的权重为 0。

② Cross-Attention 的输入和掩码

cross_mhakeyvalue 均来自 Encoder 的输出,而 query 来自 Decoder 的上一层输出。 在这里传入的 src_maskpadding mask,用于屏蔽输入中补齐的 <pad> token。

由于每个 batch 内的序列长度需要统一,短句会被填充 <pad>,这些 token 不能参与注意力计算,否则会污染上下文。因此通过掩码机制,模型在计算注意力时“看不见”这些填充位置。

下面是经过润色后的版本,保持了与你之前章节一致的学术风格与层次逻辑,同时增强了连贯性与表达的流畅度👇


搭建一个 Transformer

在之前,我们分别深入剖析了 Attention 机制以及 Transformer 的核心组成 —— Encoder 与 Decoder 的结构。 有了这些基础组件,我们终于可以开始搭建一个完整的 Transformer 模型。

Embedding 层

在 NLP 任务中,模型无法直接理解自然语言文本,因此必须先将其转化为机器能够处理的向量表示。 承担这一任务的正是 Embedding 层(词嵌入层)。

Embedding 的本质

Embedding 层本质上是一个固定大小的 可训练查找表(lookup table),用于将离散的 token 映射为连续的向量表示。 在进入神经网络之前,自然语言输入会首先经过 分词器(Tokenizer) 处理,将文本切分成离散的 token,并转化为对应的整数索引(index)。

例如,假设词表大小为 4,输入句子为“我喜欢你”,则分词器可能将输入转化为:

1
2
3
input: 我     → 0
input: 喜欢 → 1
input: 你 → 2

Embedding 层接收的输入通常是一个形状为 (batch_size, seq_len) 的矩阵, 其中 batch_size 表示一次批处理的样本数量,seq_len 表示每个样本的序列长度。 以单个批次为例,输入可能为:

1
[[0, 1, 2]]

其对应的 Embedding 层输出则是一个 (batch_size, seq_len, embedding_dim) 的张量,其中每个整数索引都会映射到一行维度为 embedding_dim 的向量。

PyTorch 已经提供了高效的 nn.Embedding 实现,我们可以直接定义如下:

1
self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)

其中:

  • vocab_size 表示词表大小;
  • dim 表示每个 token 的嵌入向量维度。

Embedding 层的权重是可训练的参数,模型会在训练过程中不断更新这些向量,使得语义相近的词在高维空间中也能靠得更近。

Tokenizer 与 Transformer 的关系

💡 一个常见误区是:Transformer 模型会“自动”完成分词。 实际上,Transformer 并不进行分词,它只接收经过 tokenizer 处理后的整数 token 序列。

1️⃣ Transformer 不做分词

Transformer(如 GPT、BERT、T5 等)仅接受 token id(整数序列)作为输入。 原始文本需要通过 Tokenizer 转换为 token id,才能输入模型。 换句话说,模型的第一层 Embedding 只接收整数索引,而不是自然语言字符。

2️⃣ Tokenizer 是独立训练的模块

Tokenizer 通常在模型训练前单独训练,其目的是确定词表与切分规则。 常见方法如下:

方法 原理 特点
Word-level 基于空格或词典拆词 OOV 严重
Subword(BPE、WordPiece) 将词拆分为常见子词单元 兼顾泛化性与词法完整性
Character-level 以字符为单位 粒度过细,训练慢
SentencePiece(T5 / LLaMA 常用) 无需预分词的子词算法 语言无关,适用于多语言模型

典型模型对应的分词方式如下:

模型 使用的分词算法
BERT WordPiece
GPT 系列 Byte-Pair Encoding (BPE)
T5 SentencePiece
LLaMA SentencePiece + Byte Fallback

位置编码(Positional Encoding)

注意力机制天生具备并行计算能力,但它也带来了一个问题:序列的位置信息会丢失。 在 RNN 或 LSTM 中,序列是按时间步递归处理的,因此模型天然保留了词序信息。然而,在自注意力机制中,每个 token 对所有位置的 token 都是平等对待的——换句话说,“我喜欢你”和“你喜欢我”在注意力机制看来几乎没有区别。这显然与自然语言的顺序敏感性相悖。

为了弥补这一点,Transformer 引入了 位置编码(Positional Encoding),将每个 token 的位置信息编码到其词向量中,从而让模型感知序列的顺序。

正余弦位置编码(Sinusoidal PE)

Transformer 原论文中使用的是绝对位置编码,通过正弦和余弦函数计算得到: 其中:

  • pos 表示 token 在序列中的位置;
  • 2i2i+1 分别对应向量中的偶数和奇数维度;
  • d_model 是词向量的维度。

通过奇偶维度分别使用正弦和余弦函数,模型可以编码不同频率的位置信息,从而捕获相对和绝对位置信息

该编码的优势

  1. 支持长度超出训练集的序列 由于位置编码是基于函数计算的,即便测试序列比训练集更长,也能生成对应的 PE 向量。例如,训练集最长句子长度为 20,而测试句子长度为 21,公式依然可以计算出第 21 个 token 的位置向量。

  2. 便于计算相对位置 对于固定长度间距 kPE(pos+k) 可以通过 PE(pos) 线性组合得到。这是因为: 这一性质使得模型能够自然地捕获 token 之间的相对距离。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class PositionalEncoding(nn.Module):
def __init__(self, dim, max_len=1024):
super().__init__()
# 位置索引 (max_len, 1)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
# 除数项 (dim/2,)
div_term = torch.exp(torch.arange(0, dim, 2).float() * -(math.log(10000.0) / dim))

# 初始化位置编码矩阵 (1, max_len, dim)
pe = torch.zeros(1, max_len, dim)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)

# 注册为 buffer,不作为可训练参数
self.register_buffer('pe', pe)

def forward(self, x):
"""
x: (batch_size, seq_len, dim)
"""
pe = cast(Tensor, self.pe)
x = x + pe[:, :x.size(1), :]
return x

一个完整的 Transformer

将上述所有组件(Embedding、位置编码、EncoderBlock、DecoderBlock、LayerNorm 等)按照 Transformer 的结构拼接起来,我们就得到了一个完整的模型。

本实现采用了Pre-Norm结构(先 LayerNorm 再进入子层),并在 Encoder 和 Decoder 堆叠的最后分别增加了一个 ln_encln_dec,这是现代实现中保持训练稳定的常见做法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
class Transformer(nn.Module):
"""
Transformer Model consisting of multiple Encoder and Decoder Blocks.

Args:
args: ModelArgs
model arguments
"""
def __init__(self, args: ModelArgs):
super().__init__()
self.src_embedding = nn.Embedding(args.src_vocab_size, args.dim)
self.tgt_embedding = nn.Embedding(args.tgt_vocab_size, args.dim)
self.position_encoding = PositionalEncoding(args.dim, args.max_seq_len)
self.dropout = nn.Dropout(args.dropout)
self.encoder_blocks = nn.ModuleList(
[EncoderBlock(args) for _ in range(args.num_layer)]
)
self.decoder_blocks = nn.ModuleList(
[DecoderBlock(args) for _ in range(args.num_layer)]
)
self.ln_enc = LayerNorm(args.dim)
self.ln_dec = LayerNorm(args.dim)
self.fc_out = nn.Linear(args.dim, args.tgt_vocab_size, bias=False)

# share the source and target embeddings
if args.src_tgt_emb_shared:
if args.src_vocab_size != args.tgt_vocab_size:
raise ValueError("src_vocab_size must be equal to tgt_vocab_size when sharing embeddings")
self.tgt_embedding.weight = self.src_embedding.weight
# decoder output layer shares weights with target embedding
self.fc_out.weight = self.tgt_embedding.weight

# initialize parameters
self._reset_parameters()

def _reset_parameters(self):
nn.init.normal_(self.src_embedding.weight, mean=0, std=self.src_embedding.embedding_dim ** -0.5)
nn.init.normal_(self.tgt_embedding.weight, mean=0, std=self.tgt_embedding.embedding_dim ** -0.5)

for name, p in self.named_parameters():
if p.dim() > 1 and "embedding" not in name and "fc_out" not in name:
nn.init.xavier_uniform_(p)

def get_num_params(self, non_embedding=False):
n_params = sum(p.numel() for p in self.parameters())
if non_embedding:
n_params -= (self.src_embedding.weight.numel() + self.tgt_embedding.weight.numel())
return n_params

def forward(self, src: Tensor, tgt: Tensor, src_mask: Tensor, tgt_mask: Tensor) -> Tensor:
"""
src: (B, S_src)
input to the encoder
tgt: (B, S_tgt)
input to the decoder
src_mask: (B, 1, 1, S_src)
mask for the encoder input
tgt_mask: (B, 1, S_tgt, S_tgt)
mask for the decoder input
"""
# Embedding and Positional Encoding
# in original Transformer, the embedding is scaled by sqrt(dim), which is helpful for model convergence
# (B, S_src) -> (B, S_src, dim)
src = self.src_embedding(src) * math.sqrt(self.src_embedding.embedding_dim)
src = self.position_encoding(src)
src = self.dropout(src)

tgt = self.tgt_embedding(tgt) * math.sqrt(self.tgt_embedding.embedding_dim)
tgt = self.position_encoding(tgt)
tgt = self.dropout(tgt)

# Encoder
# (B, S_src, dim)
enc_out = src
for block in self.encoder_blocks:
enc_out = block(enc_out, src_mask)
enc_out: Tensor = self.ln_enc(enc_out)

# Decoder
# (B, S_tgt, dim)
dec_out = tgt
for block in self.decoder_blocks:
dec_out = block(dec_out, enc_out, src_mask, tgt_mask)
dec_out: Tensor = self.ln_dec(dec_out)

# Output layer
# (B, S_tgt, dim) -> (B, S_tgt, tgt_vocab_size)
dec_out = self.fc_out(dec_out)

return dec_out

def get_padding_mask(seq: Tensor, pad_idx: int) -> Tensor:
"""
Get the padding mask for a sequence.

Args:
seq: (B, S)
input sequence
pad_idx: int
index of the padding token
"""
# (B, 1, 1, S)
mask = (seq != pad_idx).unsqueeze(1).unsqueeze(2)
return mask

def get_causal_mask(size: int) -> Tensor:
"""
Get the causal mask for a sequence.

Args:
size: int
length of the sequence
"""
# (1, 1, S, S)
mask = torch.tril(torch.ones(1, 1, size, size))
return mask

if __name__ == "__main__":
args = ModelArgs()
model = Transformer(args)
print(f"Number of parameters: {model.get_num_params()/1e9:.3f}B")
src = torch.randint(0, args.src_vocab_size, (2, 10))
tgt = torch.randint(0, args.tgt_vocab_size, (2, 10))
src_mask = get_padding_mask(src, args.src_pad_idx)
tgt_mask = get_padding_mask(tgt, args.tgt_pad_idx) * get_causal_mask(tgt.size(1))
out: Tensor = model(src, tgt, src_mask, tgt_mask)
print(out.size()) # should be (2, 10, tgt_vocab_size)

1. 权重共享 (Weight Sharing) 的作用

代码中有两处权重共享:

self.tgt_embedding.weight = self.src_embedding.weight

  • 作用:让源语言的嵌入层和目标语言的嵌入层共享同一套权重
  • 何时使用:这通常用于词表高度重合的任务。例如,机器翻译(如中译英)一般不共享,因为词表完全不同。但对于单语种任务(如文本摘要、英译英改写)或源/目标语言非常相似(如西班牙语译葡萄牙语)且使用了共享词表 (BPE/SentencePiece) 时,共享权重可以显著减少参数量,并让模型学到更泛化的 Token 表示。

self.fc_out.weight = self.tgt_embedding.weight

  • 作用:让目标嵌入层 (Tgt Embedding) 和最终的输出线性层 (Output fc_out) 共享权重。
  • 这是为什么
    • 逻辑自洽tgt_embedding 的作用是 (Token ID Vector),即把一个词的 ID 映射到 dim 维空间。fc_out 的作用是 (Vector Logits),即把 dim 维空间的一个向量映射回词表中每个词的得分。
    • 这两个操作在逻辑上是互逆的。如果一个词(如 “cat”)在嵌入空间中由向量 vcat 表示,那么当模型在最后一步生成了向量 vcat 时,它应该给 “cat” 这个词打高分。
    • 通过共享权重,fc_out (形状 [tgt_vocab_size, dim]) 实际上就是 tgt_embedding (形状 [tgt_vocab_size, dim]) 的转置(在 PyTorch 中 nn.Linear 内部会自动处理)。
  • 好处:这被证明是一种非常有效的参数绑定 (Parameter Tying) 技巧,它大幅减少了模型的参数量fc_out 层通常是模型中最大的层之一),并有助于模型更快地收敛和提高性能。

2. 参数初始化 (Initialization) 的原因

Embedding 层:nn.init.normal_(…, std=self.dim ** -0.5)

  • 原因:这与 forward 函数中的 * math.sqrt(self.dim) 缩放配套使用
  • 在原版 Transformer 论文中,作者建议将 Embedding 层的输出乘以 (即 sqrt(dim))。
  • 为了使 Embedding 层的初始输出(即 embedding * sqrt(dim))的方差接近 1(这是保持训练稳定的理想状态),我们需要让原始 embedding 权重的方差为
  • nn.init.normal_(..., std=self.dim ** -0.5) 正是设置了标准差 ,从而使初始方差 σ2 = 1/dim
  • 总结:这是一种精细的初始化技巧,旨在确保送入 Encoder/Decoder 第一个子层的数据方差适中 (接近 1),防止训练初期梯度过大或过小。

其他层 (MHA, FNN 的线性层):nn.init.xavier_uniform_(p)

  • 原因:这是深度学习中最标准的初始化方法之一,也称为 Glorot 初始化
  • 目标:它试图使信息(激活值)和梯度在网络中前向和反向传播时,方差保持不变
  • 如果不做初始化(使用默认初始化),网络越深,激活值和梯度就越容易在逐层传递中指数级放大(梯度爆炸)或缩小(梯度消失)。
  • Xavier 初始化根据层的输入和输出维度来计算一个合适的随机范围,确保了训练初期的数值稳定性,让模型“冷启动”更加顺畅。

3. Masking 掩码函数

get_padding_mask(seq, pad_idx)

  • 目的:告诉注意力机制(MHA)忽略输入序列中所有<pad>(填充)Token。
  • 原理
    1. seq != pad_idx 会生成一个布尔矩阵 (B, S),<pad> 位置是 False,其他位置是 True
    2. unsqueeze(1).unsqueeze(2) 将其变为 (B, 1, 1, S)
  • 为什么是这个形状:MHA 中的注意力分数矩阵 attn_scores 形状是 (B, H, S_q, S_k) (H=Heads, S_q=Query序列, S_k=Key序列)。
  • 当这个 (B, 1, 1, S_k) 的掩码与 attn_scores 相加(或相乘)时,PyTorch 的广播机制会工作:
    • B 维度匹配。
    • 1 (H) 广播到 H
    • 1 (S_q) 广播到 S_q
    • S_k 维度匹配。
  • 效果:所有对应于 <pad> Token 的 Key(即 S_k 维度上的列)都会被屏蔽掉(在 softmax 之前给它们一个负无穷,或在相乘时设为 0),这样 Query 就不会去“注意”那些无意义的填充位了。

get_causal_mask(size)

  • 目的防止 Decoder “作弊”。在生成(翻译)目标序列时,模型在预测第 t 个词时,只能看到 t 以及 t 之前的词,绝不能看到 t 之后的“未来”词。
  • 原理
    1. torch.tril(torch.ones(1, 1, size, size)) 创建了一个下三角矩阵(对角线及以下为 1,以上为 0)。
    2. 形状为 (1, 1, S, S)
  • 如何工作:这个掩码同样会广播到 (B, H, S, S)。它作用于 Decoder 的第一个 MHA(自注意力)层。
  • 效果:它会屏蔽掉注意力分数矩阵的上三角部分。这意味着,第 i 行的 Query 只能“看到”第 0i 列的 Key,而第 i+1S 列(未来的 Token)全部被屏蔽了。
 REWARD AUTHOR
 Comments
Comment plugin failed to load
Loading comment plugin