LLM-Code-Handwritten

本文主要包含 LLM 代码手撕的实现,包括 RoPE、Attention(MHA、MQA)、Decoder-only、Encoder-only、Transformer 等。

1. RoPE (Rotary Positional Embedding)

核心思想:通过旋转矩阵将绝对位置信息注入到 Q 和 K 中,且具有相对位置特性。
面试重点:复数乘法的实现方式(两两配对旋转)。
为了将位置信息注入模型中,我们将实现 旋转位置编码(RoPE)(Su 等,2021)。
对于一个 token 的 query 向量 $q^{(i)}$,它位于序列中的第 $i$ 个位置,维度为 $d$。
我们会对它施加一个成对旋转矩阵 $R^i$,得到:$q’^{(i)} = R^i q^{(i)} = R^i W_q x^{(i)}$
也就是说,矩阵 $R^i$ 会将 query 向量的每两个元素看作一个 二维向量,并按角度 $\theta_{i,k}$ 进行旋转,其中:$\theta_{i,k} = \frac{i}{\Theta^{(2k-2)/d}} \quad (k = 1, \ldots, d/2)$
$\Theta$ 是一个常数(一般取 10000,与 Transformer 的位置编码一致)。
因此,矩阵 $R^i$ 可以看作一个 分块对角矩阵,每个 2×2 小块为:
$$
R^i_k= \begin{bmatrix} \cos(\theta_{i,k}) & -\sin(\theta_{i,k}) \\
\sin(\theta_{i,k}) & \cos(\theta_{i,k}) \end{bmatrix}
$$

所以完整的旋转矩阵 $R^i$ 是:
$$
R^i= \begin{bmatrix}
R^i_1 & 0 & 0 & … & 0 \\
0 & R^i_2 & 0 & … & 0 \\
0 & 0 & R^i_3 & … & 0 \\
… & … & … & … & … \\
0 & 0 & 0 & … & R^i_{d/2}
\end{bmatrix}
$$

其中所有的 $0$ 都代表 2×2 的零矩阵

虽然我们可以显式构造整个 $d \times d$ 的旋转矩阵,但更高效的做法是利用其特性直接对向量进行旋转。

而且,由于我们只关心一个序列内部 token 的相对旋转关系,所以所有层都可以共享同一套 cos 和 sin 表。 因此,这一层通常用 self.register_buffer(persistent=False) 来保存 cos 和 sin,而不会作为可训练参数。

最终,Q 和 K 都会用对应的$R^i$ 进行旋转。 注意:这个层 没有可学习参数

  • 直接实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
from torch import nn

class Rope(nn.Module):

def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None):
"""
计算$$\theta_{i,k} = i \cdot \frac{\Theta}{10000^{2k/d}}$$
- 构造 RoPE 模块,并在需要时创建缓存(buffers)。
- `theta`: RoPE 中的 Θ 值(控制旋转角度的频率基底)。
- `d_k`: 查询(query)和键(key)向量的维度。
- `max_seq_len`: 输入序列的最大长度。
- `device`: 存储缓存张量的设备(`torch.device` 或 `None`)。
"""
super().__init__()
if d_k % 2 != 0:
raise ValueError("d_k must be even for RoPE")

self.theta = theta
self.d_k = d_k
self.max_seq_len = max_seq_len
self.device = device

f = 1.0 / (theta ** (torch.arange(0, d_k, 2, device=device).float() / d_k))

# position
p = torch.arange(max_seq_len, device=device).float()
# sinusoids
s = torch.outer(p, f)

self.register_buffer("cos_cache", s.cos(), persistent=False)
self.register_buffer("sin_cache", s.sin(), persistent=False)


def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
"""
Args:
x: shape (*, seq_len, d_k), 输入张量(支持任意批维度)
token_positions: shape (*, seq_len), 每个 token 的绝对位置索引

Returns:
shape (*, seq_len, d_k), 应用旋转编码后的输出
"""

cos = self.cos_cache[token_positions]
sin = self.sin_cache[token_positions]

# 分割输入为偶数和奇数维度: x = [x_even, x_odd]
x_even = x[..., 0::2] # 偶数索引: 0, 2, 4, ...
x_odd = x[..., 1::2] # 奇数索引: 1, 3, 5, ...

# 应用旋转公式
out_even = x_even * cos - x_odd * sin
out_odd = x_even * sin + x_odd * cos

# 交错合并: 将 (even, odd) 沿最后一维堆叠并展平
out = torch.stack([out_even, out_odd], dim=-1) # shape: (*, seq_len, d_k//2, 2)
out = out.flatten(-2) # shape: (*, seq_len, d_k)

return out

  • 复数实现
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import torch.nn as nn

class RoPE(nn.Module):
def __init__(self, dim, max_seq_len=4096, theta=10000.0):
super().__init__()
# 计算频率: theta ^ (-2i/d)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
# 生成时间步 t: [0, 1, ..., max_seq_len-1]
t = torch.arange(max_seq_len, device=freqs.device)
# 外积计算 args: (seq_len, dim/2)
freqs = torch.outer(t, freqs).float()
# 转为极坐标形式,方便后续利用复数性质计算# freqs_cis: (seq_len, dim/2) -> complex64
self.freqs_cis = torch.polar(torch.ones_like(freqs), freqs)

def forward(self, x):# x shape: (batch, seq_len, n_heads, head_dim)# 将 x 重塑为复数形式: (batch, seq_len, n_heads, head_dim/2)
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))

# 获取当前序列长度对应的频率,并利用广播机制
freqs_cis = self.freqs_cis[:x.shape[1]].view(1, x.shape[1], 1, -1)

# 复数乘法即为旋转
x_rotated = x_complex * freqs_cis

# 变回实数并展平: (batch, seq_len, n_heads, head_dim)
x_out = torch.view_as_real(x_rotated).flatten(3)
return x_out.type_as(x)

2. Attention

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import math
from torch import nn

class RoPE(nn.Module):
def __init__(self, dim, max_seq_len, theta):

super.__init__()

f = 1.0 / theta ** (torch.arrange(0, dim, 2).float() / dim)
p = torch.arrange(0, max_seq_len).float()
s = torch.outer(p, f)

self.register_buffer("cos_cache", s.cos(), persistent=False)
self.register_buffer("sin_cache", s.sin(), persistent=False)

def forward(self, x, token_positions):
cos = self.cos_cache[token_positions]
sin = self.sin_cache[token_positions]

x_even = x[..., 0::2]
x_odd = x[..., 1::2]

out_even = cos * x_even - sin * x_odd
out_odd = sin * x_even + cos * x_odd

out = torch.stack([out_even, out_odd], dim=-1)
out = out.flatten(-2)

return out

class ScaledDotProductAttention(nn.Module):
def __init__(self):
super.__init__()

def forward(self, Q, K, V, mask):

# Q, K, V shape (*, batch_size, seq_len, d_k)

d_k = Q.shape[-1]

scale = torch.sqrt(torch.tensor(d_k))
score = torch.matmul(Q, K.transpose(-2, -1)) / scale

attention_weight = torch.softmax(score)

output = torch.matmul(attention_weight, V)

return output

MHA

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):

self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads

self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model)

max_seq_len = 2048
causal_mask = torch.triu(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool), diagonal=1)
# diagonal=1 表示 从主对角线右上方开始 保留 True,其他位置置为 False。
self.register_buffer('causal_mask', causal_mask, persistent=False)

def forward(self, x):
Q = self.w_q(x) # (batch_size, seq_len, d_model)
K = self.w_k(x) # (batch_size, seq_len, d_model)
V = self.w_v(x) # (batch_size, seq_len, d_model)

batch_size, seq_len, _ = x.shape
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# (batch_size, num_heads, seq_len, head_dim)

scale = torch.sqrt(torch.tensor(self.head_dim))
scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
# (batch_size, num_heads, seq_len, head_dim) * (batch_size, num_heads, head_dim, seq_len)
# -> (batch_size, num_heads, seq_len, seq_len)

# if mask is not None:
# scores = scores.masked_fill(mask == 0, -1e9)

# 1. 根据当前输入的 seq_len 对预定义的 mask 进行切片
# self.causal_mask 的形状是 (max_len, max_len) -> 切片为 (seq_len, seq_len)
causal_mask_slice = self.causal_mask[:seq_len, :seq_len].unsqueeze(0).unsqueeze(0)

# 2. 使用 masked_fill 填充负无穷
# 注意:这里 causal_mask 为 True (上三角) 的地方会被填充 -inf
scores = scores.masked_fill(causal_mask_slice, float('-inf'))

attn_weights = torch.softmax(scores, dim=-1)

context = torch.matmul(attn_weights, V)
# (batch_size, num_heads, seq_len, seq_len) * (batch_size, num_heads, seq_len, head_dim)
# -> (batch_size, num_heads, seq_len, head_dim)

context = context.transpose(1, 2).contiguous() # -> (batch_size, seq_len, head_dim, num_heads)
context = context.view(batch_size, seq_len, self.d_model)

output = self.w_o(context)

return output


MQA

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import torch.nn as nn
import math

class MultiQueryAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads

# --- 区别 1: 线性层定义 ---
# Q: 保持多头,维度仍为 d_model (即 num_heads * head_dim)
self.w_q = nn.Linear(d_model, d_model)

# K, V: 变为单头 (Shared Head),输出维度仅为 head_dim
# 在 MQA 中,所有的 Query Head 共享同一个 Key 和 Value Head
self.w_k = nn.Linear(d_model, self.head_dim)
self.w_v = nn.Linear(d_model, self.head_dim)

self.w_o = nn.Linear(d_model, d_model)

max_seq_len = 2048
causal_mask = torch.triu(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool), diagonal=1)
self.register_buffer('causal_mask', causal_mask, persistent=False)

def forward(self, x):
batch_size, seq_len, _ = x.shape

# 1. 投影
Q = self.w_q(x) # (batch_size, seq_len, d_model) -> (B, L, H * D)
K = self.w_k(x) # (batch_size, seq_len, head_dim) -> (B, L, 1 * D)
V = self.w_v(x) # (batch_size, seq_len, head_dim) -> (B, L, 1 * D)

# 2. 变形 (Reshape & Transpose)
# Q: (B, L, H, D) -> (B, H, L, D)
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

# --- 区别 2: K, V 的维度处理 ---
# K, V: 这里的头数维度为 1
# (B, L, 1, D) -> (B, 1, L, D)
K = K.view(batch_size, seq_len, 1, self.head_dim).transpose(1, 2)
V = V.view(batch_size, seq_len, 1, self.head_dim).transpose(1, 2)

# 3. 计算 Scores (利用广播机制)
scale = math.sqrt(self.head_dim)

# Q: (B, H, L, D)
# K.transpose: (B, 1, D, L)
# PyTorch 会自动将 K 广播(Broadcast)成 (B, H, D, L) 以匹配 Q 的头数
scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
# Result: (B, H, L, L)

# 4. Masking
causal_mask_slice = self.causal_mask[:seq_len, :seq_len].unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(causal_mask_slice, float('-inf'))

# 5. Softmax
attn_weights = torch.softmax(scores, dim=-1)

# 6. 计算 Context (利用广播机制)
# attn_weights: (B, H, L, L)
# V: (B, 1, L, D) -> 广播为 (B, H, L, D)
context = torch.matmul(attn_weights, V)
# Result: (B, H, L, D)

# 7. 输出投影
context = context.transpose(1, 2).contiguous() # (B, L, H, D)
context = context.view(batch_size, seq_len, self.d_model) # (B, L, d_model)

output = self.w_o(context)

return output

GQA

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import torch.nn as nn
import math

class GroupedQueryAttention(nn.Module):
def __init__(self, d_model, num_heads, num_kv_heads):
"""
:param d_model: 模型维度
:param num_heads: Query 的头数 (例如 8)
:param num_kv_heads: Key/Value 的头数 (例如 2)。必须能被 num_heads 整除。
"""
super().__init__()

# 检查是否整除
assert num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"

self.d_model = d_model
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = d_model // num_heads

# 计算每个 KV 头对应多少个 Q 头 (Group Size)
# 例如: 8个Q头, 2个KV头 ->每组 4个Q头 共享 1个KV头
self.num_rep = num_heads // num_kv_heads

# Q 保持原样: 输出 num_heads * head_dim
self.w_q = nn.Linear(d_model, num_heads * self.head_dim)

# K, V 减少头数: 输出 num_kv_heads * head_dim
self.w_k = nn.Linear(d_model, num_kv_heads * self.head_dim)
self.w_v = nn.Linear(d_model, num_kv_heads * self.head_dim)

self.w_o = nn.Linear(d_model, d_model)

max_seq_len = 2048
causal_mask = torch.triu(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool), diagonal=1)
self.register_buffer('causal_mask', causal_mask, persistent=False)

def forward(self, x):
batch_size, seq_len, _ = x.shape

# 1. 投影
Q = self.w_q(x) # (B, L, num_heads * D)
K = self.w_k(x) # (B, L, num_kv_heads * D)
V = self.w_v(x) # (B, L, num_kv_heads * D)

# 2. Reshape 分头
# Q: (B, L, num_heads, D)
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

# K, V: (B, L, num_kv_heads, D) -> (B, num_kv_heads, L, D)
K = K.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
V = V.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

# --- GQA 核心操作: 重复/复制 KV 头 ---
# 目标: 将 K, V 从 (B, num_kv_heads, L, D) 变成 (B, num_heads, L, D) 以便和 Q 进行计算
# 使用 repeat_interleave 在 dim=1 (头维度) 进行复制
# 例如: KV头为 [K1, K2], num_rep=2 -> [K1, K1, K2, K2]
K = K.repeat_interleave(self.num_rep, dim=1)
V = V.repeat_interleave(self.num_rep, dim=1)

# 此时 K, V 的形状变成了 (B, num_heads, L, D),与 MHA 计算逻辑一致了

# 3. 计算 Scores
scale = math.sqrt(self.head_dim)
scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
# (B, H, L, D) @ (B, H, D, L) -> (B, H, L, L)

# 4. Masking
causal_mask_slice = self.causal_mask[:seq_len, :seq_len].unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(causal_mask_slice, float('-inf'))

# 5. Softmax & Context
attn_weights = torch.softmax(scores, dim=-1)
context = torch.matmul(attn_weights, V) # (B, H, L, D)

# 6. 输出
context = context.transpose(1, 2).contiguous()
context = context.view(batch_size, seq_len, self.d_model)

output = self.w_o(context)

return output

MLA

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import torch
import torch.nn as nn
import math

class MultiHeadLatentAttention(nn.Module):
def __init__(self, d_model, num_heads, q_lora_rank=None, kv_lora_rank=None, rope_dim=64):
"""
:param d_model: 模型输入维度
:param num_heads: 注意力头数
:param q_lora_rank: Query 压缩后的秩 (Latent dim),如果为None则不压缩Q
:param kv_lora_rank: KV 压缩后的秩 (Latent dim),这是MLA节省显存的核心
:param rope_dim: 专门用于 RoPE 的维度 (DeepSeek 采用解耦 RoPE 策略)
"""
super().__init__()
self.d_model = d_model
self.num_heads = num_heads

# 1. 维度定义
# content_head_dim 是不包含位置信息的“内容”维度
self.head_dim = d_model // num_heads
self.content_head_dim = self.head_dim - rope_dim
# 内容维度 = 总维度 - RoPE维度
self.rope_dim = rope_dim

# 默认压缩维度设置 (参考 DeepSeek 配置)
if q_lora_rank is None:
q_lora_rank = int(d_model * 0.5) # 仅作示例
if kv_lora_rank is None:
kv_lora_rank = 512 # 通常远小于 num_heads * head_dim

self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank

# --- Query 侧投影 (Q-LoRA) ---
# 1. Q 压缩 (Down Project)
self.w_dq = nn.Linear(d_model, q_lora_rank, bias=False)
self.q_layernorm = nn.LayerNorm(q_lora_rank) # 压缩后通常接 Norm

# 2. Q 解压缩 (Up Project) -> 分为 Content 和 RoPE 两部分
# Content 部分: num_heads * content_head_dim
self.w_uq = nn.Linear(q_lora_rank, num_heads * self.content_head_dim, bias=False)
# RoPE 部分: num_heads * rope_dim (解耦出的位置向量)
self.w_qr = nn.Linear(q_lora_rank, num_heads * self.rope_dim, bias=False)

# --- Key-Value 侧投影 (KV-LoRA) - MLA 的核心 ---
# 1. KV 压缩 (Down Project) -> 这是一个 Latent Vector,KV Cache 只需要存这个!
self.w_dkv = nn.Linear(d_model, kv_lora_rank, bias=False)
self.kv_layernorm = nn.LayerNorm(kv_lora_rank)

# 2. KV 解压缩 (Up Project)
# K Content: num_heads * content_head_dim
self.w_uk = nn.Linear(kv_lora_rank, num_heads * self.content_head_dim, bias=False)
# Value: num_heads * head_dim (Value 不需要 RoPE,所以用完整 head_dim)
self.w_uv = nn.Linear(kv_lora_rank, num_heads * self.head_dim, bias=False)

# 3. K 的 RoPE 部分 (这是直接从 input x 生成的,不经过压缩,或者另外处理)
# DeepSeek-V2 中 K_rope 也是单独生成的,通常所有头共享或者广播
self.w_kr = nn.Linear(d_model, rope_dim, bias=False)

self.w_o = nn.Linear(num_heads * self.head_dim, d_model, bias=False)

# 初始化 RoPE (使用你提供的类,假设传入的是 rope_dim)
# 注意: 这里 max_seq_len 需要足够大
self.rope_layer = RoPE(dim=rope_dim, max_seq_len=2048, theta=10000.0)

max_seq_len = 2048
causal_mask = torch.triu(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool), diagonal=1)
self.register_buffer('causal_mask', causal_mask, persistent=False)

def forward(self, x, token_positions=None):
batch_size, seq_len, _ = x.shape
if token_positions is None:
token_positions = torch.arange(seq_len, device=x.device)

# --- 1. Query 生成 (Q-LoRA) ---
# 压缩 -> Norm
c_Q = self.q_layernorm(self.w_dq(x)) # (B, L, q_lora_rank)

# 解压 -> Content (B, L, H, D_content)
q_content = self.w_uq(c_Q).view(batch_size, seq_len, self.num_heads, self.content_head_dim)
# 解压 -> RoPE part (B, L, H, D_rope)
q_rope = self.w_qr(c_Q).view(batch_size, seq_len, self.num_heads, self.rope_dim)

# --- 2. Key-Value 生成 (KV-LoRA / MLA) ---
# 压缩 -> Norm (这就是推理时 KV Cache 真正存储的东西,非常小)
c_KV = self.kv_layernorm(self.w_dkv(x)) # (B, L, kv_lora_rank)

# 解压 -> K Content (B, L, H, D_content)
k_content = self.w_uk(c_KV).view(batch_size, seq_len, self.num_heads, self.content_head_dim)
# 解压 -> Value (B, L, H, D_head)
v = self.w_uv(c_KV).view(batch_size, seq_len, self.num_heads, self.head_dim)

# 生成 -> K RoPE part (B, L, 1, D_rope)
# DeepSeek 论文中 k_rope 通常是 shared (head_dim=1) 这里的 Linear 输出是 rope_dim
k_rope = self.w_kr(x).view(batch_size, seq_len, 1, self.rope_dim)
# 广播 k_rope 到所有头以便和 q_rope 计算: (B, L, 1, D_r) -> (B, L, H, D_r)
k_rope = k_rope.expand(-1, -1, self.num_heads, -1)

# --- 3. 应用 RoPE (仅对 RoPE 部分) ---
# 调用你的 RoPE 类。注意我们需要先 reshape 把 head 维度折叠进 batch 或者 loop 处理
# 为了简单适配你的 RoPE 类 (它接受 [..., dim]),我们这里手动处理一下维度
# q_rope: (B, L, H, D_r) -> 调整为 (B*H, L, D_r) 传入 RoPE 再变回来

q_rope_flat = q_rope.transpose(1, 2).reshape(-1, seq_len, self.rope_dim)
k_rope_flat = k_rope.transpose(1, 2).reshape(-1, seq_len, self.rope_dim)

# 应用旋转
q_rope_rotated = self.rope_layer(q_rope_flat, token_positions)
k_rope_rotated = self.rope_layer(k_rope_flat, token_positions)

# 变回 (B, H, L, D_r) 方便后续 concat
q_rope = q_rope_rotated.view(batch_size, self.num_heads, seq_len, self.rope_dim).transpose(1, 2)
k_rope = k_rope_rotated.view(batch_size, self.num_heads, seq_len, self.rope_dim).transpose(1, 2)

# --- 4. 拼接 (Concat) ---
# DeepSeek 的策略:Q = [q_content, q_rope], K = [k_content, k_rope]
# 最终维度恢复为 (B, L, H, content + rope) = (B, L, H, head_dim)
q = torch.cat([q_content, q_rope], dim=-1)
k = torch.cat([k_content, k_rope], dim=-1)

# 转置为 (B, H, L, D) 准备做 Attention
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# --- 5. 标准 Attention 计算 ---
scale = math.sqrt(self.head_dim)
scores = torch.matmul(q, k.transpose(-2, -1)) / scale

# Masking
causal_mask_slice = self.causal_mask[:seq_len, :seq_len].unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(causal_mask_slice, float('-inf'))

attn_weights = torch.softmax(scores, dim=-1)

context = torch.matmul(attn_weights, v) # (B, H, L, D)

context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
output = self.w_o(context)

return output

Explanation

Generated By Gemini-3-Pro

MLA(Multi-Head Latent Attention)不仅仅是一个代码技巧,它是 DeepSeek-V2/V3 能够以极低的显存占用(KV Cache)实现媲美 LLaMA3 MHA(多头注意力)性能的核心秘密。

要真正理解 MLA,我们需要解决两个主要矛盾:

  1. 性能矛盾:想效果好,就得用 MHA(多头),但显存爆炸。

  2. 显存矛盾:想省显存,就得用 MQA/GQA(少头),但模型表达能力可能下降。

MLA 的核心原理就是:用“矩阵分解”的数学技巧,在“显存”上模拟 MQA,在“计算”上模拟 MHA。

下面分三个关键模块进行详细拆解:

第一模块:低秩键值联合压缩 (Low-Rank Key-Value Joint Compression)

这是 MLA 的地基。

在传统的 Standard MHA 中,通过线性层把输入 $x$ 映射为 $K$ 和 $V$:

$$\mathbf{k} = x W_K, \quad \mathbf{v} = x W_V$$

假设 hidden_dim 是 4096,batch_size 和 seq_len 很大,生成的 $K$ 和 $V$ 矩阵非常巨大,都要存进显存(KV Cache)。

MLA 的思路是:$K$ 和 $V$ 矩阵其实有很多冗余信息(低秩性)。为什么不先把输入压缩成一个很小的“潜在向量”(Latent Vector),存 KV Cache 时只存这个小向量,计算时再还原回去呢?

  1. Down-Projection (压缩):

    输入 $x$ 先经过一个压缩矩阵 $W_{DKV}$,变成一个低维向量 $c_{KV}$(Latent Vector)。

    $$c_{KV} = x W_{DKV}$$

    • 关键点:推理时,KV Cache 只存这个 $c_{KV}$。它的维度(比如 512)远小于标准 KV 的维度(比如 num_heads 128 * head_dim 128 = 16384)。这使得显存占用极低。
  2. Up-Projection (还原):

    在计算 Attention Score 时,通过两个上投影矩阵 $W_{UK}$ 和 $W_{UV}$,把 $c_{KV}$ 还原成多头的形式。

    $$\mathbf{k} = c_{KV} W_{UK}, \quad \mathbf{v} = c_{KV} W_{UV}$$

此时的问题:如果仅仅是这样,计算量并没有减少,推理时还得实时把 $c_{KV}$ 乘回大矩阵再算 Attention,效率不高。于是有了下面的“矩阵吸收”。


第二模块:推理时的矩阵吸收 (Matrix Absorption)

这是 MLA 最骚的操作。利用矩阵乘法的结合律,我们不需要在推理时真正把 $K$ 还原出来。

回顾 Attention 的核心公式(忽略 Scale 和 Softmax):

$$Score = Q \cdot K^T$$

在 MLA 中,$K$ 是由 $c_{KV}$ 还原来的,代入公式:

$$Score = Q \cdot (c_{KV} \cdot W_{UK})^T = Q \cdot W_{UK}^T \cdot c_{KV}^T$$

结合律魔法:

我们可以把 $(Q \cdot W_{UK}^T)$ 这一项结合在一起!

$$Score = \underbrace{(Q \cdot W_{UK}^T)}{\text{Absorbed Query}} \cdot c{KV}^T$$

这一步意味着什么?

  • 在推理时,我们计算出 Query 后,直接把 $W_{UK}^T$(Key 的上投影矩阵)吸收(融合) 进 Query 里。

  • 这就变成了:一个变形后的 Query,直接去和 KV Cache 里存储的那个极小的 $c_{KV}$ 做点积。

  • 结果:我们完全不需要在显存里复原那个巨大的 Key 矩阵。

这也是为什么 MLA 被称为“伪装成 MQA 的 MHA” —— 它的存储量像 MQA 一样小,但它的权重矩阵 $W_{UK}$ 依然保留了多头的信息。


第三模块:解耦旋转位置编码 (Decoupled RoPE)

既然上面的“矩阵吸收”这么完美,为什么代码里还要把 RoPE 单独拆出来(peeled off)?

因为 RoPE 破坏了矩阵结合律。

RoPE 是对位置 $m$ 和 $n$ 进行旋转变换 $\mathcal{R}$。如果直接对压缩后的向量 $c_{KV}$ 或者还原后的 $K$ 加 RoPE,公式会变成:

$$Score = Q \cdot \text{RoPE}(c_{KV} \cdot W_{UK})^T$$

这里的 $\text{RoPE}$ 操作是非线性的(相对于矩阵乘法顺序而言),它夹在中间,导致 $W_{UK}$ 无法被 $Q$ 吸收。

DeepSeek 的解决方案:切分(Decouple)

它把 Query 和 Key 的向量拆成了两部分:

  1. Content Part (内容向量):承载语义信息,不加 RoPE。这部分可以使用上面的“矩阵吸收”技巧,极致压缩。

  2. RoPE Part (位置向量):承载位置信息,必须加 RoPE

具体做法

  • 生成的 Key ($k_{total}$) = [$k_{content}$, $k_{rope}$]

  • $k_{content}$ 来自压缩的 $c_{KV}$(走矩阵吸收逻辑)。

  • $k_{rope}$ 单独生成(通常头数很少,甚至像代码里那样共享头,即 head_dim=1,然后广播),只用来携带位置信息。

  • Query 同理拆分。

计算 Attention Score 时,利用数学性质:

$$Q^T K = (q_{content}^T k_{content}) + (q_{rope}^T k_{rope})$$

这意味着:我们算一部分“内容分”,算一部分“位置分”,加起来就是总分。

  • 内容分:享受极致显存压缩。

  • 位置分:虽然不压缩,但维度很小(如代码中的 64 维),且共享头,显存占用忽略不计。


总结:MLA 的本质图景

做一个简单的对比:

特性 Standard MHA GQA (LLaMA-2/3) DeepSeek MLA
KV Cache 存什么 存完整的 num_heads $\times$ d_head 存较少的 num_kv_heads $\times$ d_head 只存 1 个 latent 向量 + 少量 RoPE 向量
显存占用 巨大 (100%) 中等 (比如 1/8) 极小 (比 GQA 还小)
模型容量 强 (每个头都有独立参数) 弱 (KV 头参数减少了) 强 (投影矩阵 $W_{UK}$ 依然是多头的)
推理技巧 矩阵吸收 (Absorbed Key)
位置编码 直接作用于 K 直接作用于 K Decoupled (分离内容与位置)

一句话概括 MLA 的原理:

通过低秩压缩减少 KV 存储,利用矩阵吸收在推理时跳过解压步骤,最后通过解耦 RoPE 绕过位置编码对矩阵结合律的破坏,从而实现了“用 MQA 的显存,跑出 MHA 的性能”。

3. Utils

Softmax

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
def softmax(x: torch.Tensor, dim: int) -> torch.Tensor:
"""
Applies the softmax function along the specified dimension.
exp(x - max(x)) / sum(exp(x - max(x)))
Args:
x (torch.Tensor): Input tensor.
dim (int): Dimension along which to apply softmax.

Returns:
torch.Tensor: Tensor after applying softmax.
"""

# x: 任意形状的张量
# dim: 在哪个维度上进行 softmax(比如 dim=-1 表示对最后一维做 softmax)

# 1) 在指定维度上取最大值,用来做数值稳定性处理
# keepdim=True 的作用是保持维度不变,方便后续做广播减法
x_max = x.max(dim=dim, keepdim=True)[0] # .max 返回 (values, indices),所以取 [0] 获得最大值

# 2) 减去最大值后再做 exp,避免 exp 时数值溢出(即变得特别大)
x_exp = torch.exp(x - x_max)

# 3) 对每个向量求和,然后归一化得到 softmax 概率分布
# 同样使用 keepdim=True 保持维度不变,确保能正确广播除法
return x_exp / x_exp.sum(dim=dim, keepdim=True)

AdamW

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import math
import torch

class AdamW(torch.optim.Optimizer):
def __init__(self, params, lr=1e-3, weight_decay=0.0, betas=(0.9, 0.999), eps=1e-8):
defaults = {
"lr": lr,
"weight_decay": weight_decay,
"betas": betas,
"eps": eps
}
super().__init__(params, defaults)

def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
lr = group['lr']
weight_decay = group['weight_decay']
beta1, beta2 = group['betas']
eps = group['eps']

for p in group['params']:
if p.grad is None:
continue
if p.grad.is_sparse:
raise RuntimeError('AdamW does not support sparse gradients')

# 1. 得到梯度g
grad = p.grad.data
state = self.state[p]

# state init
if len(state) == 0:
state["step"] = 0
state["exp_avg"] = torch.zeros_like(p.data)
state["exp_avg_sq"] = torch.zeros_like(p.data)


# 2. 得到一阶矩m
# 3. 得到二阶矩v
exp_avg = state["exp_avg"]
exp_avg_sq = state["exp_avg_sq"]

# update biased first and second moment estimates
# exp_avg.mul_(beta1).add_(grad, alpha=(1 - beta1))
# exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))

new_exp_avg = exp_avg * beta1 + grad * (1.0 - beta1)
new_exp_avg_sq = exp_avg_sq * beta2 + (grad * grad) * (1.0 - beta2)

# write back state with new tensors (non-inplace)
state["exp_avg"] = new_exp_avg
state["exp_avg_sq"] = new_exp_avg_sq

# increment step and compute bias-corrected learning rate
state["step"] += 1
step = state["step"]

# 4. 计算偏差修正后的学习率
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step

# compute step_size as in original Adam: lr * sqrt(1 - beta2^t) / (1 - beta1^t)
step_size = lr * (math.sqrt(bias_correction2) / bias_correction1)

# 5. 更新参数
denom = torch.sqrt(new_exp_avg_sq) + eps
update = new_exp_avg / denom
# parameter update: θ = θ - step_size * exp_avg / denom
new_p_data = p.data - step_size * update

# 6. 权重衰减
# decoupled weight decay (AdamW): θ = θ - lr * weight_decay * θ
if weight_decay != 0:
new_p_data = new_p_data - lr * weight_decay * p.data

# write back parameter with new data
p.data = new_p_data

return loss

LR

1
2
3
4
5
6
7
8
import math

def learning_rate_schedule(it, max_learning_rate, min_learning_rate, warmup_iters, cosine_cycle_iters):
if it < warmup_iters:
return max_learning_rate * it / warmup_iters
if warmup_iters <= it <= cosine_cycle_iters:
return min_learning_rate + (max_learning_rate - min_learning_rate) * 0.5 * (1 + math.cos(math.pi * (it - warmup_iters) / (cosine_cycle_iters - warmup_iters)))
return min_learning_rate

RMSNorm

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""
Implement RMSNorm as a torch.nn.Module
"""

import torch
from torch import nn

class RMSNorm(nn.Module):

def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
"""
Construct the RMSNorm module. This function should accept the following parameters:
- d_model: int Hidden dimension of the model
- eps: float = 1e-5 Epsilon value for numerical stability
- device: torch.device | None = None Device to store the parameters on
- dtype: torch.dtype | None = None Data type of the parameters
"""
super().__init__()
self.d_model = d_model
self.eps = eps
fractory_kwargs = {"device": device, "dtype": dtype}

# 创建可学习的缩放权重参数,形状为 (d_model,),初始化为全 1
# 每个特征维度有一个对应的缩放因子
weight_shape = (d_model,)
weight_tensor = torch.ones(weight_shape, **fractory_kwargs) # g_i
self.weight = nn.Parameter(weight_tensor)


def forward(self, x: torch.Tensor) -> torch.Tensor:
input_dtype = x.dtype

x = x.to(dtype=torch.float32)

# 将输入提升到 float32 以保证归一化过程中的数值精度
x = x.to(torch.float32)

# 计算最后一个维度(特征维度)上的均方值(即方差,但不减均值)
# shape: (batch_size, sequence_length, 1)
variance = x.pow(2).mean(-1, keepdim=True)

# 使用均方根(RMS)进行归一化:x / sqrt(variance + eps)
# torch.rsqrt 是 1 / sqrt(x) 的高效实现
x = x * torch.rsqrt(variance + self.eps)

# 将归一化后的结果与可学习的权重相乘(逐通道缩放)
# 权重会自动广播到 batch 和 sequence 维度
x = self.weight * x

# 将结果转换回原始输入的数据类型(例如从 float32 回 float16)
return x.to(input_dtype)

Author

Alan Zeng

Posted on

2025-12-15

Updated on

2026-01-19

Licensed under

Comments