CS336-HW1

本文为对Stanford CS336 Assignment 1 的实现

Chapter #2 Byte-Pair Encoding (BPE) Tokenizer

BPE Tokenizer Training

我来用一个简单的例子为你详细解释 BPE 算法,以及 word_freqspair_freqs 的含义。

📚 BPE 算法详解 - 简单例子

一个简单的例子来解释整个算法流程:

假设输入文本

1
"aab aab ab"

📝 第一步:初始化

  1. word_freqs(词频表)

记录每个词(word)被分解成字节序列后的频率

1
2
3
4
word_freqs = {
(b'a', b'a', b'b'): 2, # "aab" 出现 2 次
(b'a', b'b'): 1 # "ab" 出现 1 次
}

解释​:

  • 键:词的字节序列,每个字节独立存储为 tuple
  • 值:这个词在文本中出现的次数
  1. pair_freqs(字节对频率表)

记录所有相邻字节对在整个语料中的出现频率

1
2
3
4
pair_freqs = {
(b'a', b'a'): 2, # "aa" 在 "aab"(2次) 中出现 2 次
(b'a', b'b'): 3 # "ab" 在 "aab"(2次) + "ab"(1次) 中出现 3 次
}

解释​:

  • 键:相邻的两个字节
  • 值:这个字节对在所有词中的总出现次数

🔄 第二步:迭代合并过程

迭代 1:合并频率最高的 pair

  1. 找到最高频 pair
1
best_pair = (b'a', b'b')  # 频率 = 3(最高)
  1. 合并这个 pair

将所有包含 (b'a', b'b') 的词进行合并:

1
2
3
4
5
6
7
8
9
10
11
# 更新前
word_freqs = {
(b'a', b'a', b'b'): 2, # 包含 ab
(b'a', b'b'): 1 # 包含 ab
}

# 更新后(b'a' + b'b' → b'ab')
word_freqs = {
(b'a', b'ab'): 2, # "aab" → "a"+"ab"
(b'ab',): 1 # "ab" → "ab"
}
  1. 更新 pair_freqs

删除旧的 pairs:

  • (b'a', b'a', b'b') 中删除:
    • (b'a', b'a') 频率 -2 → 0(删除)
    • (b'a', b'b') 频率 -2 → 1
  • (b'a', b'b') 中删除:
    • (b'a', b'b') 频率 -1 → 0(删除)

添加新的 pairs:

  • (b'a', b'ab') 中添加:
    • (b'a', b'ab') 频率 +2 → 2
1
2
3
4
# 更新后的 pair_freqs
pair_freqs = {
(b'a', b'ab'): 2 # 新的最高频 pair
}

迭代 2:继续合并

  1. 找到最高频 pair
1
best_pair = (b'a', b'ab')  # 频率 = 2
  1. 合并这个 pair
1
2
3
4
5
6
7
8
9
10
11
# 更新前
word_freqs = {
(b'a', b'ab'): 2,
(b'ab',): 1
}

# 更新后(b'a' + b'ab' → b'aab')
word_freqs = {
(b'aab',): 2, # "a"+"ab" → "aab"
(b'ab',): 1 # 不受影响
}
  1. 更新 pair_freqs
1
2
# 删除 (b'a', b'ab'),没有新的 pairs 产生(单个 token)
pair_freqs = {} # 空了,停止合并

完整流程可视化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
初始状态:
文本: "aab aab ab"
├─ word_freqs: {(b'a',b'a',b'b'): 2, (b'a',b'b'): 1}
└─ pair_freqs: {(b'a',b'a'): 2, (b'a',b'b'): 3}

迭代 1: 合并 (b'a', b'b')
├─ word_freqs: {(b'a',b'ab'): 2, (b'ab',): 1}
└─ pair_freqs: {(b'a',b'ab'): 2}

迭代 2: 合并 (b'a', b'ab')
├─ word_freqs: {(b'aab',): 2, (b'ab',): 1}
└─ pair_freqs: {} ← 没有更多 pairs,停止

最终 merges: [(b'a',b'b'), (b'a',b'ab')]
最终 vocab: {
0: b'a', 1: b'b', ..., 255: b'\xff', # 初始 256 个字节
256: b'ab', # 第一次合并
257: b'aab' # 第二次合并
}

为什么要增量更新?

原始方法(慢)

只使用merge_pairsget_pairs

后续每一次更新都去做一次完整的merge_pairs, 并在下一次迭代时获取一次get_pairs

实际应该仅仅在初始化时候获取一次, 接下来只对需要更新的部分进行更新

1
2
3
4
for 每次迭代:
pairs = get_pairs(word_freqs) # 重新计算所有 pairs O(N×M)
best_pair = 找最高频 pair
word_freqs = merge_pairs(best_pair, word_freqs)

优化方法(快)

1
2
3
4
pair_freqs = get_pairs(word_freqs)  # 只计算一次
for 每次迭代:
best_pair = 从 pair_freqs 中找最高频 pair # O(P)
word_freqs, pair_freqs = update_pairs_incremental(...) # 只更新受影响的

关键优势​:

  • 只在第一次计算全部 pairs

  • 之后每次只更新被 merge 影响的词和 pairs

  • 时间复杂度从

    $O(N×M×iterations)$

    降到

    $O(N×M + iterations×affected_words)$

code

原始代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
'''
Problem (train_bpe): BPE Tokenizer Training (15 points)

Deliverable: Write a function that, given a path to an input text file, trains a (byte-level) BPE
tokenizer. Your BPE training function should handle (at least) the following input parameters:
- input_path: str Path to a text file with BPE tokenizer training data.
- vocab_size: int A positive integer that defines the maximum final vocabulary size (including the
initial byte vocabulary, vocabulary items produced from merging, and any special tokens).
- special_tokens: list[str] A list of strings to add to the vocabulary. These special tokens do not
otherwise affect BPE training.

Your BPE training function should return the resulting vocabulary and merges:
- vocab: dict[int, bytes] The tokenizer vocabulary, a mapping from int (token ID in the vocabulary) to bytes (token bytes).
- merges: list[tuple[bytes, bytes]] A list of BPE merges produced from training. Each list item
is a tuple of bytes (<token1>, <token2>), representing that <token1> was merged with
<token2>. The merges should be ordered by order of creation.
To test your BPE training function against our provided tests, you will first need to implement the
test adapter at [adapters.run_train_bpe]. Then, run uv run pytest tests/test_train_bpe.py.
Your implementation should be able to pass all tests. Optionally (this could be a large time-investment),
you can implement the key parts of your training method using some systems language, for instance
C++ (consider cppyy for this) or Rust (using PyO3). If you do this, be aware of which operations
require copying vs reading directly from Python memory, and make sure to leave build instructions, or
make sure it builds using only pyproject.toml. Also note that the GPT-2 regex is not well-supported
in most regex engines and will be too slow in most that do. We have verified that Oniguruma is
reasonably fast and supports negative lookahead, but the regex package in Python is, if anything,
even faster.
'''

from collections import defaultdict
import os
from typing import Counter
import re
import regex

# 处理步骤

def split_on_special_tokens(
text: str,
special_tokens: list[str]
) -> list[str]:
if not special_tokens:
return [text]
escaped_tokens = [re.escape(token) for token in special_tokens]
pattern = "|".join(escaped_tokens)

chunks = re.split(pattern, text)
chunks = [chunk for chunk in chunks if chunk]

return chunks

def get_pairs(
word_freqs: dict[tuple[bytes, ...], int]
) -> dict[tuple[bytes, bytes], int]:

pairs = defaultdict(int)
# 默认列表:
# 当尝试访问一个不存在的键时,普通的字典会抛出一个 KeyError 错误。而 defaultdict 不会抛出错误,它会自动为该键生成一个默认值。
# 工厂函数是 int: 当调用 int() 时,它会返回整数的默认值,也就是 0。
for w, f in word_freqs.items():
for i in range(len(w) - 1):
pair = (w[i], w[i + 1])
pairs[pair] += f

return dict(pairs)

def merge_pairs(
best_pair: tuple[bytes, bytes],
word_freqs: dict[tuple[bytes, ...], int]
) -> dict[tuple[bytes, ...], int]:
new_word_freqs = {}
merged = best_pair[0] + best_pair[1]

for w, f in word_freqs.items():
new_word = []
i = 0

while i < len(w):
if i < len(w) - 1 and w[i] == best_pair[0] and w[i+1] == best_pair[1]:
new_word.append(merged)
i += 2
else:
new_word.append(w[i])
i += 1

new_word_freqs[tuple(new_word)] = f

return new_word_freqs

def train_bpe(
input_path: str,
vocab_size: int,
special_tokens: list[str],
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
# assert isinstance(special_tokens, bytes), "Must represent special token as a bytestring"

# Pre-tokenization
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

with open(input_path, "r", encoding="utf-8") as f:
text = f.read()

# split_on_special_tokens
chunks = split_on_special_tokens(text=text, special_tokens=special_tokens)

# Get frequency table
word_freqs_str = Counter()

for chunk in chunks:
chunk_freqs = Counter(
match.group() for match in regex.finditer(PAT, chunk)
)
word_freqs_str.update(chunk_freqs)

word_freqs = {}
for w, f in word_freqs_str.items():
word_byte = tuple(bytes([b]) for b in w.encode("utf-8"))
word_freqs[word_byte] = f

# 初始化Vocabulary
vocab = {i: bytes([i]) for i in range(256)}
next_id = 256

for t in special_tokens:
vocab[next_id] = t.encode("utf-8")
next_id += 1

# print(f"initial vocabulary size: {len(vocab)}")
# print("===step: Train BPE (Merge)===")
# print(f"initial word_freqs: {word_freqs}")

merge_times = vocab_size - len(vocab)
merges = []

for i in range(merge_times):
pairs = get_pairs(word_freqs)

if not pairs:
# print(f"没有更多字节对可以合并,在第 {i} 次合并后停止")
break

max_freq = max(pairs.values())
best_pair = max([p for p, f in pairs.items() if f == max_freq])
# print(f"> best_pair: {best_pair}")
merges.append(best_pair)
vocab[next_id] = best_pair[0] + best_pair[1]
next_id += 1

word_freqs = merge_pairs(best_pair, word_freqs)
# print(f"> word_freqs: {word_freqs}")
# print(f"--------------------------------")

return vocab, merges

if __name__ == "__main__":
input_path = "data/test_simple.txt"
vocab, merges = train_bpe(
input_path=input_path,
vocab_size=300,
# special_tokens=[]
special_tokens=["<|endoftext|>"]
)
# print(f"vocab: {vocab}")
# print(f"merges: {merges}")
# print(f"vocab size: {len(vocab)}")

# print("\n前16次合并:")
# for i, (left, right) in enumerate(merges[:15]):
# left_str = left.decode("utf-8", errors="replace")
# right_str = right.decode("utf-8", errors="replace")
# merged = (left + right).decode("utf-8", errors="replace")
# print(f"{i+1}. ({repr(left_str)}, {repr(right_str)}) → {repr(merged)} \n==== merges ==== {merges[i]}")
优化代码(增量更新)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
'''
Problem (train_bpe): BPE Tokenizer Training (15 points)

Deliverable: Write a function that, given a path to an input text file, trains a (byte-level) BPE
tokenizer. Your BPE training function should handle (at least) the following input parameters:
- input_path: str Path to a text file with BPE tokenizer training data.
- vocab_size: int A positive integer that defines the maximum final vocabulary size (including the
initial byte vocabulary, vocabulary items produced from merging, and any special tokens).
- special_tokens: list[str] A list of strings to add to the vocabulary. These special tokens do not
otherwise affect BPE training.

Your BPE training function should return the resulting vocabulary and merges:
- vocab: dict[int, bytes] The tokenizer vocabulary, a mapping from int (token ID in the vocabulary) to bytes (token bytes).
- merges: list[tuple[bytes, bytes]] A list of BPE merges produced from training. Each list item
is a tuple of bytes (<token1>, <token2>), representing that <token1> was merged with
<token2>. The merges should be ordered by order of creation.
To test your BPE training function against our provided tests, you will first need to implement the
test adapter at [adapters.run_train_bpe]. Then, run uv run pytest tests/test_train_bpe.py.
Your implementation should be able to pass all tests. Optionally (this could be a large time-investment),
you can implement the key parts of your training method using some systems language, for instance
C++ (consider cppyy for this) or Rust (using PyO3). If you do this, be aware of which operations
require copying vs reading directly from Python memory, and make sure to leave build instructions, or
make sure it builds using only pyproject.toml. Also note that the GPT-2 regex is not well-supported
in most regex engines and will be too slow in most that do. We have verified that Oniguruma is
reasonably fast and supports negative lookahead, but the regex package in Python is, if anything,
even faster.
'''

from collections import defaultdict
from typing import Counter
import re
import regex

# 处理步骤

def split_on_special_tokens(
text: str,
special_tokens: list[str]
) -> list[str]:
if not special_tokens:
return [text]
escaped_tokens = [re.escape(token) for token in special_tokens]
pattern = "|".join(escaped_tokens)

chunks = re.split(pattern, text)
chunks = [chunk for chunk in chunks if chunk]

return chunks

def get_pairs(
word_freqs: dict[tuple[bytes, ...], int]
) -> dict[tuple[bytes, bytes], int]:
"""计算所有相邻字节对的频率"""
pairs = defaultdict(int)
for w, f in word_freqs.items():
for i in range(len(w) - 1):
pair = (w[i], w[i + 1])
pairs[pair] += f
return dict(pairs)

def update_pairs_incremental(
best_pair: tuple[bytes, bytes],
word_freqs: dict[tuple[bytes, ...], int],
pair_freqs: dict[tuple[bytes, bytes], int]
) -> tuple[dict[tuple[bytes, ...], int], dict[tuple[bytes, bytes], int]]:
"""增量更新:只更新受影响的词和 pairs"""
merged = best_pair[0] + best_pair[1]
new_word_freqs = {}

# 直接修改 pair_freqs,不做复制(更高效)
# 处理每个包含 best_pair 的词
for w, f in word_freqs.items():
# 检查这个词是否包含要合并的 pair
has_pair = False
for i in range(len(w) - 1):
if w[i] == best_pair[0] and w[i+1] == best_pair[1]:
has_pair = True
break

if not has_pair:
# 这个词不受影响,直接保留
new_word_freqs[w] = f
else:
# 需要移除旧词的所有 pairs
for i in range(len(w) - 1):
old_pair = (w[i], w[i + 1])
pair_freqs[old_pair] -= f
if pair_freqs[old_pair] <= 0:
del pair_freqs[old_pair]

# 合并 pair 生成新词
new_word = []
i = 0
while i < len(w):
if i < len(w) - 1 and w[i] == best_pair[0] and w[i+1] == best_pair[1]:
new_word.append(merged)
i += 2
else:
new_word.append(w[i])
i += 1

new_word_tuple = tuple(new_word)
new_word_freqs[new_word_tuple] = f

# 添加新词的 pairs
for i in range(len(new_word) - 1):
new_pair = (new_word[i], new_word[i + 1])
pair_freqs[new_pair] = pair_freqs.get(new_pair, 0) + f

return new_word_freqs, pair_freqs

def merge_pairs(
best_pair: tuple[bytes, bytes],
word_freqs: dict[tuple[bytes, ...], int]
) -> dict[tuple[bytes, ...], int]:
new_word_freqs = {}
merged = best_pair[0] + best_pair[1]

for w, f in word_freqs.items():
new_word = []
i = 0

while i < len(w):
if i < len(w) - 1 and w[i] == best_pair[0] and w[i+1] == best_pair[1]:
new_word.append(merged)
i += 2
else:
new_word.append(w[i])
i += 1

new_word_freqs[tuple(new_word)] = f

return new_word_freqs

def train_bpe(
input_path: str,
vocab_size: int,
special_tokens: list[str],
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
# assert isinstance(special_tokens, bytes), "Must represent special token as a bytestring"

# Pre-tokenization
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

with open(input_path, "r", encoding="utf-8") as f:
text = f.read()

# split_on_special_tokens
chunks = split_on_special_tokens(text=text, special_tokens=special_tokens)

# Get frequency table
word_freqs_str = Counter()

for chunk in chunks:
chunk_freqs = Counter(
match.group() for match in regex.finditer(PAT, chunk)
)
word_freqs_str.update(chunk_freqs)

word_freqs = {}
for w, f in word_freqs_str.items():
word_byte = tuple(bytes([b]) for b in w.encode("utf-8"))
word_freqs[word_byte] = f

# 初始化Vocabulary
vocab = {i: bytes([i]) for i in range(256)}
next_id = 256

for t in special_tokens:
vocab[next_id] = t.encode("utf-8")
next_id += 1

# print(f"initial vocabulary size: {len(vocab)}")
# print("===step: Train BPE (Merge)===")
# print(f"initial word_freqs: {word_freqs}")

merge_times = vocab_size - len(vocab)
merges = []

# 初始化 pair 频率字典(只计算一次)
pair_freqs = get_pairs(word_freqs)

for i in range(merge_times):
if not pair_freqs:
# print(f"没有更多字节对可以合并,在第 {i} 次合并后停止")
break

# 找到频率最高的 pair
max_freq = max(pair_freqs.values())
best_pair = max([p for p, f in pair_freqs.items() if f == max_freq])
# print(f"> best_pair: {best_pair}")

merges.append(best_pair)
vocab[next_id] = best_pair[0] + best_pair[1]
next_id += 1

# 使用增量更新而不是重新计算所有 pairs
word_freqs, pair_freqs = update_pairs_incremental(best_pair, word_freqs, pair_freqs)
# print(f"> word_freqs: {word_freqs}")
# print(f"--------------------------------")

return vocab, merges

if __name__ == "__main__":
input_path = "data/test_simple.txt"
vocab, merges = train_bpe(
input_path=input_path,
vocab_size=300,
# special_tokens=[]
special_tokens=["<|endoftext|>"]
)
# print(f"vocab: {vocab}")
# print(f"merges: {merges}")
# print(f"vocab size: {len(vocab)}")

# print("\n前16次合并:")
# for i, (left, right) in enumerate(merges[:15]):
# left_str = left.decode("utf-8", errors="replace")
# right_str = right.decode("utf-8", errors="replace")
# merged = (left + right).decode("utf-8", errors="replace")
# print(f"{i+1}. ({repr(left_str)}, {repr(right_str)}) → {repr(merged)} \n==== merges ==== {merges[i]}")
继续优化(heapq)

Chapter #3 Transformer Language Model Architecture

Linear

简单版本 worked

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
from torch import nn
from einops import einsum, rearrange

class Linear(nn.Module):
def __init__(self, in_features: int, out_features: int, device=None, dtype=None):
"""
Constructs a linear transformation module without bias.

Args:
in_features (int): Size of each input sample.
out_features (int): Size of each output sample.
device (torch.device | None): Device to store the parameters on.
dtype (torch.dtype | None): Data type of the parameters.
"""
super().__init__()

self.in_features = in_features
self.out_features = out_features
self.device = device
self.dtype = dtype
factory_kwargs = {"device": device, "dtype": dtype}
self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))

std = (2 / (in_features + out_features)) ** 0.5
nn.init.trunc_normal_(self.weight, std=std, a=-3*std, b=3*std)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Applies the linear transformation to the input: y = x @ W.T

Args:
x (torch.Tensor): Input tensor of shape (... , in_features)

Returns:
torch.Tensor: Output tensor of shape (... , out_features)
"""

return einsum(x, self.weight, "... in_features, out_features in_features -> ... out_features")

这部分代码与nn.Linear()区别

  • 区别一:偏置 (Bias) - 最大的区别
    • nn.Linear​ (默认): 默认情况下,它有一个偏置 (bias) 参数 self.bias, 即

      $y = xW^T + b$

    • 这里的代码:没有实现偏置项。计算公式是:

      $y = xW^T$

    • 此处代码在功能上并不等同于 nn.Linear(...),而是等同于 ​**nn.Linear(..., bias=False)**​。

  • 区别二:初始化 (Initialization) - 较小的区别
    • nn.Linear​ (默认): 它在 reset_parameters 方法中使用的是 Kaiming (He) 初始化 (nn.init.kaiming_uniform_)。这是为 ReLU 激活函数专门优化的。
    • 这里的代码: 您明确使用了 Xavier (Glorot) 初始化 (nn.init.trunc_normal_ 配合 std=sqrt(2/(in+out)))。这通常是为 tanhsigmoid 激活函数优化的。

Key notes:

  1. torch.empty( (out_features, in_features) , **factory_kwargs )
    它的含义是:

    1. 调用 ​torch.empty​ 函数 (外层 ())
    2. 传入第一个参数: 形状 (shape)。这个形状是一个元组 (out_features, in_features) (内层 ())。
    3. 传入第二个“参数”: **factory_kwargs。这不是一个参数,而是“解包”操作。它会把 {"device": device, "dtype": dtype} 这个字典拆开,变成 device=device, dtype=dtype 作为后续的命名参数传给 torch.empty
  2. std = (2 / (in_features+out_features)) ** 0.5 nn.init.trunc_normal_(self.weight, std=std, a=-3*std, b=3*std)

    1. Xavier (泽维尔) 初始化

      $std(W) = \sqrt{Var(W)} = \sqrt{\frac{2}{n_{in} + n_{out}}}$

    2. trunc_normal_是一个更安全的随机数生成器(截断正态分布)它在生成随机数时,会设定一个“边界” 在这里 a=-3*stdb=3*std 就是这个边界

    3. _​ (下划线): ​是 PyTorch 的一个约定,意思是 ​**”in-place” (原地操作)*​。不会返回一个新的张量,而是会直接修改* self.weight 张量本身,用新的随机数填充它。

  3. return einsum(x, self.weight, "... in_features, out_features in_features -> ... out_features")

    1. einsum 的核心是字符串,格式是:输入1, 输入2, … -> 输出
    2. 等同于以下代码
    1
    2
    3
    4
    5
    6
    7
    8
    # self.weight 的形状是 (out_features, in_features)
    # self.weight.T (转置) 的形状是 (in_features, out_features)
    # x 的形状是 (..., in_features)

    # x @ self.weight.T
    # (..., in_features) @ (in_features, out_features) -> (..., out_features)

    return torch.matmul(x, self.weight.T)
    1. … (省略号): 这是一个非常有用的通配符,意思是“匹配这里的所有其他维度”。

严格参考要求版本:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# 导入 PyTorch 库
import torch
import torch.nn as nn # nn = Neural Network (神经网络)
import math

# 1. class Linear(nn.Module):
# "我要定义一个新型号的'积木',它继承自 PyTorch
# 里最基础的'积木' (nn.Module)。"
#
# 继承 nn.Module 会“免费”给我们的类带来很多超能力:
# - 自动追踪所有“参数” (比如我们的 W)
# - 方便地在 CPU 和 GPU 间移动 (.to(device))
# - 方便地保存和加载 (.state_dict())
class Linear(nn.Module):

# 2. def __init__(...):
# "这是'积木'的'说明书'(构造函数)"
# "当别人要'制造'这块积木时 (例如 my_layer = Linear(2, 3)),
# 告诉 PyTorch 需要执行以下步骤。"
def __init__(self, in_features: int, out_features: int, device=None, dtype=None):

# 3. super().__init__()
# "必须的步骤!"
# "告诉'积木'的基座(父类) nn.Module:'我来报到了!请做好准备!'"
super().__init__()

# 4. self.in_features = ...
# "把传入的'输入数量'和'输出数量'记在小本本上,
# 后面可能会用到。"
self.in_features = in_features
self.out_features = out_features
factory_kwargs = {'device': device, 'dtype': dtype}

# 5. weight_tensor = torch.empty((in_features, out_features), ...)
# "关键一步!按照作业要求,我们要创建一个'架子' (Tensor)"
# Tensor 是 PyTorch 里的“数据容器”,你可以把它想象成
# 一个多维数组 (Numpy array)。
# torch.empty: 只分配内存,不管里面是啥“垃圾数据”。
# 形状 (shape): (in_features, out_features)
# 这就是作业要求的 W (不是 W.T)!
weight_tensor = torch.empty((in_features, out_features), **factory_kwargs)

# 6. self.W = nn.Parameter(weight_tensor)
# "!!! 深度学习的魔法核心 !!!"
# "我要给这个'架子'贴上一个'可学习'的标签 (nn.Parameter)。"
#
# 区别:
# - torch.Tensor: 只是一个普通的数据容器。
# - nn.Parameter: 是一个特殊的 Tensor。
#
# 这个 nn.Parameter 标签告诉 nn.Module (我们的'积木'基座):
# "嗨!我是'大脑'的一部分 (权重/旋钮),当模型开始'学习'
# (反向传播) 时,你必须计算我的梯度,并且'优化器'
# (Adam/SGD) 必须更新我!"
#
# (你把它命名为 self.weight 还是 self.W 都可以,
# 只要它是 nn.Parameter 就行。这里为了对应作业叫 self.W)
self.W = nn.Parameter(weight_tensor)

# 7. self.reset_parameters()
# "因为第5步的'架子'里装的是'垃圾数据',
# 我们现在需要用一些'有意义'的随机数来填满它。"
# 这叫“参数初始化”。
self.reset_parameters()

# 8. def reset_parameters(self):
# "这是一个辅助函数,用来做第7步的工作"
def reset_parameters(self):
# nn.init.trunc_normal_ 是一种初始化方法(截断正态分布)
# 意思是“用接近 0 的、比较均匀的随机数来填充 self.W”
# 好的初始化能让模型学得更快。
# _ (下划线) 在 PyTorch 中通常表示 "in-place" (原地修改)
# 也就是它会直接修改 self.W,而不是返回一个新张量。
with torch.no_grad(): # 初始化时不需要计算梯度
std = math.sqrt(1.0 / self.in_features)
nn.init.trunc_normal_(self.W, mean=0.0, std=std, a=-2*std, b=2*std)

# 9. def forward(self, x: torch.Tensor):
# "定义'数据流'"
# "这定义了当数据'流过'这块积木时,应该发生什么。"
# "当你写 output = my_layer(input) 时,PyTorch
# 会自动调用这个 my_layer.forward(input) 函数。"
def forward(self, x: torch.Tensor) -> torch.Tensor:

# 10. return x @ self.W
# "执行计算!"
# `@` 符号是 PyTorch 中 `torch.matmul` (矩阵乘法) 的简写。
#
# 检查一下维度:
# - x 的形状: (..., in_features)
# - self.W 的形状: (in_features, out_features)
#
# 矩阵乘法 `(..., in) @ (in, out)` 的结果:
# - 形状是 (..., out_features) —— 完全正确!
#
# 注意:我们根本不需要 .T (转置),因为我们存储的
# self.W (在第5步定义) 的形状就是我们想要的!
# 这就是作业要求您做的。
return x @ self.W

Embeddings

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
"""
As discussed above, the first layer of the Transformer is an embedding layer that maps integer token IDs
into a vector space of dimension d_model. We will implement a custom Embedding class that inherits from
torch.nn.Module (so you should not use nn.Embedding). The forward method should select the embedding
vector for each token ID by indexing into an embedding matrix of shape (vocab_size, d_model) using a
torch.LongTensor of token IDs with shape (batch_size, sequence_length).
"""

import torch
from torch import nn

class Embedding(nn.Module):
def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
'''
Construct an embedding module. This function should accept the following parameters:
- num_embeddings: int Size of the vocabulary
- embedding_dim: int Dimension of the embedding vectors, i.e., dmodel
- device: torch.device | None = None Device to store the parameters on
- dtype: torch.dtype | None = None Data type of the parameters
'''
super().__init__()
self.vocab_size = num_embeddings
self.d_model = embedding_dim
factory_kwargs = {"device": device, "dtype": dtype}
weight_shape = (num_embeddings, embedding_dim)
weight_tensor = torch.empty(weight_shape, **factory_kwargs)
self.weight = nn.Parameter(weight_tensor)

self.reset_parameters()

def reset_parameters(self) -> None:
with torch.no_grad():
nn.init.trunc_normal_(self.weight, std=1, a=-3, b=3)

def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
return self.weight[token_ids]

Pre-Norm Transformer Block

Pre-Norm Transformer Block 从输入嵌入到最终输出的路径中,会形成一条未经任何归一化处理的纯净 “残差流”,这被认为能优化梯度传播。如今,这种预归一化Transformer已成为语言模型的标准配置(如 GPT-3、LLaMA、PaLM等),因此我们将采用该架构。接下来我们将按顺序逐个解析预归一化Transformer模块的各个组件。

每个 Transformer 模块包含两个子层:

  1. 多头自注意力机制(Multi-Head Self-Attention)
  2. 位置级前馈网络(Position-wise Feed-Forward Network)。

我们采用​预归一化(pre-norm)结构​:在每个子层之前先进行层归一化。具体来说,若模块输入为 ,则模块执行如下操作:

  1. 自注意力子层​:

    $y = x + \mathrm{MultiHeadSelfAttention}(\mathrm{RMSNorm}(x))$

  2. 前馈网络子层​:

    $z = y + \mathrm{FFN}(\mathrm{RMSNorm}(y))$

$y = \mathrm{LayerNorm}(x + \mathrm{Sublayer}(x))$

  • 而Pre-Norm Transformer Block中

    $y = x + \mathrm{Sublayer}(\mathrm{RMSNorm}(x))$

数学区别:LayerNorm vs RMSNorm

假设我们有一个输入向量

$x = [x_1, x_2, \dots, x_n]$

(比如一个 token 的 d_model 维向量)。$\epsilon$ 是一个很小的常数(如 $1e-5$),防止除以零。LayerNorm (层归一化) 分四步完成:中心化缩放重缩放平移

  1. 计算均值 (μ):(中心化步骤)

    $\mu = \frac{1}{n} \sum_{i=1}^{n} x_i$

  2. 计算方差 (σ²):

    $\sigma^2 = \frac{1}{n} \sum_{i=1}^{n} (x_i - \mu)^2$

  3. 归一化 (x̂):

    $\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$

  4. 缩放与平移 (y): LayerNorm 有两个可学习的参数:增益 $\gamma$ 和偏置 $\beta$。

    $y_i = \gamma \hat{x}_i + \beta$

RMSNorm (均方根归一化) 是一种简化的 LayerNorm。它移除了第 1 步(均值计算)和第 4 步中的平移 $\beta$

  1. 计算均方根 (RMS):(不减均值!)

    $
    RMS(x) = \sqrt{\frac{1}{n} \sum_{i=1}^{n} x_i^2}
    $

  2. 归一化 (x̂):

    $
    \hat{x}_i = \frac{x_i}{\sqrt{RMS(x)^2 + \epsilon}}
    $

(注:上面的

$
\sqrt{RMS(x)^2
$

只是为了和 $$\sqrt{\sigma^2$$ 在形式上对应,它就等于 $$RMS(x$$)
3. 缩放 (y): RMSNorm 只有一个可学习的参数:增益

$
amm
$

$
i = \gamma \hat{x}_i
$

Root Mean Square Layer Normalization

步骤 1:计算 $RMS(a)$ (均方根)

$RMS(a) = \sqrt{\frac{1}{d_{\text{model}}} \sum_{i=1}^{d_{\text{model}}} a_i^2 + \epsilon}$

  • $a$:一个 $d_{\text{model}}$ 维的激活向量(比如一个 token 的 embedding)。
  • $\sum a_i^2$:把向量中每个元素平方后​相加​​。
  • $\frac{1}{d_{\text{model}}} …$:取​**均值 (Mean)**​。
  • $\sqrt{…}$:取​**平方根 (Root)**​。
  • $\epsilon$: 只是为了防止 $RMS(a)$ 的计算结果为 0(这会导致除零错误)

这三个操作合起来就是 ​**Root Mean Square (均方根)​。它计算的是这个向量 $a$ 的​整体”强度”或”大小”**​。

步骤 2:归一化 (Normalize) 和缩放 (Rescale)

$\text{RMSNorm}(a_i) = \frac{a_i}{\text{RMS}(a)} g_i$

  • $\frac{a_i}{\text{RMS}(a)}$:这是归一化步骤。用向量的每个元素 $a_i$ 除以整个向量的”平均强度” $\text{RMS}(a)$。这会将向量的”强度”缩放回 1 附近。
  • $g_i$:这是一个​可学习的”增益”(gain) 参数​。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
"""
Implement RMSNorm as a torch.nn.Module
"""

import torch
from torch import nn

class RMSNorm(nn.Module):

def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
"""
Construct the RMSNorm module. This function should accept the following parameters:
- d_model: int Hidden dimension of the model
- eps: float = 1e-5 Epsilon value for numerical stability
- device: torch.device | None = None Device to store the parameters on
- dtype: torch.dtype | None = None Data type of the parameters
"""
super().__init__()
self.d_model = d_model
self.eps = eps
fractory_kwargs = {"device": device, "dtype": dtype}

# 创建可学习的缩放权重参数,形状为 (d_model,),初始化为全 1
# 每个特征维度有一个对应的缩放因子
weight_shape = (d_model,)
weight_tensor = torch.ones(weight_shape, **fractory_kwargs) # g_i ​
self.weight = nn.Parameter(weight_tensor)

def forward(self, x: torch.Tensor) -> torch.Tensor:
# 将输入提升到 float32 以保证归一化过程中的数值精度
input_dtype = x.dtype
x = x.to(dtype=torch.float32)

# 计算最后一个维度(特征维度)上的均方值(即方差,但不减均值)
# shape: (batch_size, sequence_length, 1)
variance = x.pow(2).mean(-1, keepdim=True)

# 使用均方根(RMS)进行归一化:x / sqrt(variance + eps)
# torch.rsqrt 是 1 / sqrt(x) 的高效实现
x = x * torch.rsqrt(variance + self.eps)

# 将归一化后的结果与可学习的权重相乘(逐通道缩放)
# 权重会自动广播到 batch 和 sequence 维度
x = self.weight * x

# 将结果转换回原始输入的数据类型(例如从 float32 回 float16)
return x.to(input_dtype)

Key notes:

  1. weight_tensor = torch.ones(weight_shape, **fractory_kwargs)# g_i

    1. torch.ones(...)全1初始化
    2. RMSNorm 层的 self.weight (即 $g_i$) 是一个”增益”或”缩放”因子。归一化步骤 x * rsqrt(...) 已经把向量的”强度”调整到了 1 附近。在训练刚开始时,我们最不希望的就是这个 $g_i$ 参数立即扭曲这个好不容易才归一化好的信号。因此做全1初始化
  2. input_dtype = x.dtype x = x.to(dtype=torch.float32)

    1. 模型的大部分计算可能在 float16bfloat16 下进行以节省显存和提高速度。但是,x.pow(2) 很容易在 float16 下​**溢出 (Overflow)**​(超过 65504)。
    2. 因此,这里:
      • x 安全地提升到 float32
      • float32 下完成所有敏感的归一化计算。
      • 最后再转换回原始的 input_dtype,以便下一层可以继续高效计算
  3. variance = x.pow(2).mean(-1, keepdim=True)x 形状: (B, S, D) -> (Batch, Sequence, d_model)

    1. .mean(-1, ...): 沿最后一个维度(-1)求均值。这就是

      $
      (1/d_{model}) * \sum(a_i^2)
      $

    2. keepdim=True: 保持维度,输出形状为 (B, S, 1), 这对于下一步的广播至关重要

  4. x = x * torch.rsqrt(variance + self.eps)

    1. x / torch.sqrt(variance + self.eps)x * torch.rsqrt(variance + self.eps)在数学上等价,但乘法和 rsqrt(倒数平方根)的组合通常比除法 sqrt 在 GPU 上执行得更快。
  5. x = self.weight * x

Position-Wise Feed-Forward Network 位置级前馈网络

在原始的Transformer论文(Vaswani等人[2017],第3.3节)中,Transformer的前馈网络(Feed-Forward Network, FFN)由两个线性变换组成,中间使用ReLU激活函数

$ReLU(x) = max(0, x)$

。通常情况下,内部前馈层的维度是输入维度的4倍。 然而,现代语言模型相较于这一原始设计引入了两个主要变化:使用了不同的激活函数,并采用了门控机制。

具体来说,我们将实现一种名为“SwiGLU”的激活函数,该函数已被诸如Llama 3 [Grattafiori et al., 2024] 和 Qwen 2.5 [Yang et al., 2024] 等大语言模型(LLM)所采用。

SwiGLU结合了SiLU(常被称为Swish)激活函数和一种称为门控线性单元(Gated Linear Unit, GLU)的门控机制。

此外,我们还将省略线性层中有时使用的偏置项(bias),这是自PaLM [Chowdhery et al., 2022] 和 LLaMA [Touvron et al., 2023] 以来大多数现代大语言模型的做法。

SiLU(或称Swish)激活函数 [Hendrycks 和 Gimpel, 2016; Elfwing 等, 2017] 定义如下:

$\text{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$

如图所示,SiLU激活函数与ReLU激活函数类似,但在零点处是平滑的。

门控线性单元(GLU)最初由Dauphin等人[2017]提出,其定义为一个经过Sigmoid函数变换的线性变换与另一个线性变换之间的逐元素乘积:

$\text{GLU}(x, W_1, W_2) = \sigma(W_1 x) \odot W_2 x$

其中 $\odot$ 表示逐元素相乘。

$a \odot b$ (逐元素相乘) 的结果是:

$\begin{bmatrix} 1 \\ 2 \\ 3 \\ 4 \end{bmatrix} \odot \begin{bmatrix} 5 \\ 6 \\ 7 \\ 8 \end{bmatrix} = \begin{bmatrix} 1 \times 5 \\ 2 \times 6 \\ 3 \times 7 \\ 4 \times 8 \end{bmatrix} = \begin{bmatrix} 5 \\ 12 \\ 21 \\ 32 \end{bmatrix}$

“逐元素相乘” $\odot$ 是 GLU 实现其”门控”功能的​执行机制​。 它允许 $σ(W_1 x)$(门控)来动态地、逐个元素地控制 $W_2 x$(内容)中​有多少信息可以流向下一层​。

  • (Sigmoid 函数) 是这里的关键。Sigmoid 会把任何输入的数字都压缩到 ​0 到 1 之间​。这个分支不关心”内容”,只关心”​哪些内容是重要的​”。它负责为”内容分支”的每个元素学习一个”​通过系数​”。
  • Eg. $\text{result} = \text{gate} \odot \text{value}$

$\begin{bmatrix} 1 \\ 0 \\ 0.5 \end{bmatrix} \odot \begin{bmatrix} 100 \\ -50 \\ 0.5 \end{bmatrix} = \begin{bmatrix} 1 \times 100 \\ 0 \times -50 \\ 0.5 \times 0.5 \end{bmatrix} = \begin{bmatrix} 100 \\ 0 \\ 0.25 \end{bmatrix}$

门控线性单元被认为可以通过提供一条线性的梯度通路,同时保留非线性能力,从而“减轻深层架构中的梯度消失问题”。

将 SiLU/Swish 激活函数与 GLU 机制结合起来,就得到了 SwiGLU,我们将用它来构建前馈网络:

$\text{FFN}(x) = \text{SwiGLU}(x, W_1, W_2, W_3) = W_2 \left( \text{SiLU}(W_1 x) \odot W_3 x \right)$

其中 $x \in \mathbb{R}^{d_{\text{model}}}$

, $$W_1, W_3 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}$$, $$W_2 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}$$, 且通常设定 $$d_{\text{ff}} = \frac{8}{3} d_{\text{model}$$。 Shazeer [2020] 首次提出了将SiLU/Swish激活函数与GLU结合的思路,并通过实验表明,在语言建模任务上,SwiGLU的表现优于ReLU以及无门控的SiLU等基线方法。

在本作业的后续部分,你也将对SwiGLU和SiLU进行比较。

尽管我们已经提到了这些组件的一些启发式理由(相关论文也提供了更多支持性证据),但保持实证视角仍然很重要。Shazeer论文中有一句如今广为流传的话:

“我们并不解释为何这些架构似乎有效;我们将它们的成功归因于——如同其他一切一样——神的仁慈。”

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
from torch import nn
from cs336_basics.linear import Linear

class SwiGLU(nn.Module):
"""
SwiGLU 激活函数实现:FFN = W2 * (SiLU(W1 x) ⊙ W3 x)
- x: d_model
- W1: d_ff x d_model
- W2: d_model x d_ff
- W3: d_ff x d_model
- 通常 d_ff = 8/3 * d_model
"""

def __init__(self, d_model, d_ff, device=None, dtype=None):
super().__init__()
self.d_model = d_model
self.d_ff = d_ff
factory_kwargs = {"device": device, "dtype": dtype}
# x @ self.W.T
self.w1 = Linear(d_model, d_ff, device=device, dtype=dtype) # x(d_model) @ W1.T(d_model, d_ff) -> (d_ff)
self.w2 = Linear(d_ff, d_model, device=device, dtype=dtype) # (SiLU(W1 x) ⊙ W3 x)(d_ff) @ W2.T(d_ff, d_model) -> (d_model)
self.w3 = Linear(d_model, d_ff, device=device, dtype=dtype) # x(d_model) @ W3.T(d_model, d_ff) -> (d_ff)
# self.w1 = nn.Linear(d_model, d_ff, bias=False) ​
# self.w2 = nn.Linear(d_ff, d_model, bias=False) ​
# self.w3 = nn.Linear(d_model, d_ff, bias=False) ​

def _silu(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(x)

def _glu(self, x: torch.Tensor) -> torch.Tensor:
return self._silu(self.w1(x)) * self.w3(x)

def forward(self, x):
# SwiGLU: W2(SiLU(W1 x) ⊙ W3 x)
# return self.w2(self.w1(x).silu() * self.w3(x))
return self.w2(self._glu(x))

Rotary Positional Embeddings (RoPE)

sinusoidal positional encoding

正弦位置编码公式

设模型的隐藏维度为 $d$,位置为 $\text{pos}$,维度下标为 $i$,位置编码向量为 $PE(\text{pos})$:

$PE(\text{pos}, 2i) = \sin\left(\frac{\text{pos}}{10000^{\frac{2i}{d}}}\right)$

$PE(\text{pos}, 2i+1) = \cos\left(\frac{\text{pos}}{10000^{\frac{2i}{d}}}\right)$

也就是说:

  • 偶数维使用 $\sin$
  • 奇数维使用 $\cos$
  • 维度越大,周期越长 → 能编码长序列

为什么要用正弦和余弦

周期性: ​序列中的远距离位置差仍可通过周期关系表达出来。

平移不变性: ​任意两个位置的差可以通过角度差得到:

$\sin(a+b) = \sin(a)\cos(b) + \cos(a)\sin(b)$

因此模型能 ​直接从向量差中推断相对位置​,不需要额外学习。

构造一个简单示例

设序列长度为 4,隐藏维度 $d = 4$,位置 $\text{pos} = 0,1,2,3$,维度 $i = 0,1$,计算:

pos i 使用公式 计算结果
0 even(0) $\sin(0 / 10000^{0/4}) = \sin(0)=0$ →$PE(0,0)=0$
0 odd(1) $\cos(0 / 10000^{1/4}) = \cos(0)=1$ →$PE(0,1)=1$

对于 pos = 1:

$PE(1,0)=\sin(1)=0.84 \quad (\approx)$

$PE(1,1)=\cos(1)=0.54 \quad (\approx)$

维度 i=2,3 的项会使用更大的周期:

$PE(\text{pos},2) = \sin\left(\frac{\text{pos}}{10000^{\frac{2}{4}}}\right) = \sin\left(\frac{\text{pos}}{100}\right)$

$PE(\text{pos},3) = \cos\left(\frac{\text{pos}}{100}\right)$

最终得到的位置编码矩阵示例

对序列位置 0~3:

$PE = \begin{bmatrix} 0 & 1 & 0 & 1 \\ \sin(1) & \cos(1) & \sin(0.01) & \cos(0.01) \\ \sin(2) & \cos(2) & \sin(0.02) & \cos(0.02) \\ \sin(3) & \cos(3) & \sin(0.03) & \cos(0.03) \end{bmatrix}$

把位置编码向量加到 token embedding 上​:

$x_{\text{with pos}} = x_{\text{token}} + PE(\text{pos})$

这样做的问题:

  • 位置是加进去的,语义和位置信息混在一起,不可分离
  • Attention 中计算相似度时:$QK^T$ 无法直接表达 “相对位置差”

RoPE

为了将位置信息注入模型中,我们将实现 ​**旋转位置编码(RoPE)**​(Su 等,2021)。

对于一个 token 的 query 向量 $q^{(i)}$,它位于序列中的第 $i$ 个位置,维度为 $d$。 我们会对它施加一个成对旋转矩阵 $R^i$,得到:$q’^{(i)} = R^i q^{(i)} = R^i W_q x^{(i)}$

也就是说,矩阵 $R^i$ 会将 query 向量的每两个元素看作一个 ​二维向量​,并按角度 $\theta_{i,k}$ 进行旋转,其中:

$\theta_{i,k} = \frac{i}{\Theta^{(2k-2)/d}} \quad (k = 1, \ldots, d/2)$

$\Theta$ 是一个常数(一般取 10000,与 Transformer 的位置编码一致)。因此,矩阵 $R^i$ 可以看作一个 ​分块对角矩阵​,每个 2×2 小块为:

$$R^i_k= \begin{bmatrix} \cos(\theta_{i,k}) & -\sin(\theta_{i,k}) \\ \sin(\theta_{i,k}) & \cos(\theta_{i,k}) \end{bmatrix}$$

所以完整的旋转矩阵 $R^i$ 是:

$$R^i= \begin{bmatrix} R^i_1 & 0 & 0 & … & 0\\ 0 & R^i_2 & 0 & … & 0\\ 0 & 0 & R^i_3 & … & 0\\ … & … & … & … & …\\ 0 & 0 & 0 & … & R^i_{d/2} \end{bmatrix}$$

其中所有的 $0$ 都代表 ​2×2 的零矩阵​。

虽然我们可以显式构造整个 $d \times d$ 的旋转矩阵,但更高效的做法是利用其特性直接对向量进行旋转。而且,由于我们只关心一个序列内部 token 的​相对旋转关系​,所以所有层都可以​共享同一套 cos 和 sin 表​。 因此,这一层通常用 self.register_buffer(persistent=False) 来保存 cos 和 sin,而不会作为可训练参数。

最终,Q 和 K 都会用对应的 $R^i$ 进行旋转。 注意:这个层 ​没有可学习参数​。Code

  • theta: RoPE 中的 Θ 值(控制旋转角度的频率基底)
  • d_k: 查询(query)和键(key)向量的维度
  • max_seq_len: 输入序列的最大长度
  • device: 存储缓存张量的设备(torch.deviceNone

$
\theta_{i,k} = \frac{i}{\Theta^{(2k-2)/d}} \quad (k = 1, \ldots, d/2)
$

1
2
3
4
5
6
7
8
9
10
1.0 / (theta ** (torch.arange(0, d_k, 2, device=device).float() / d_k))
'''
(function) def arange(
start: Number,
end: Number,
step: Number,
) -> Tensor
'''
# torch.arange(0, d_k, 2).float()
# tensor([ 0., 2., 4., 6., 8., 10., 12., 14., 16., 18.])
1
2
3
4
5
6
positions = torch.arange(max_seq_len, device=device)
sinusoids = torch.outer(positions, freqs) # outer product
# ------------------
# Shape of Frequency torch.Size([d_k // 2])
# Shape of Position torch.Size([max_seq_len])
# Shape of Sinusoids torch.Size([max_seq_len, d_k // 2])
$$0$$ $$1$$ $$2$$
$$\frac{1}{\Theta^{0/d}$$
$$\frac{1}{\Theta^{2/d}}$$
$$\frac{1}{\Theta^{4/d}}$$
$$..$$
$$\frac{1}{\Theta^{(d-2)/d}}$$

可见最后得到的sinusoids矩阵就是

$
\theta_{i,k
$

1
2
3
# 缓存 cos 和 sin 编码,不参与训练
self.register_buffer("cos_cache", sinusoids.cos(), persistent=False)
self.register_buffer("sin_cache", sinusoids.sin(), persistent=False)
  • register_buffer(name, tensor) 会把这个 tensor 作为模块的一部分存储下来 ​但不作为可训练参数​。
  • ​*persistent*​=False​ 表示不要将这些 buffer 写入最终模型文件​(比如用 .save_pretrained() 时)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
cos = self.cos_cache[token_positions]
sin = self.sin_cache[token_positions]

# 分割输入为偶数和奇数维度: x = [x_even, x_odd]
x_even = x[..., 0::2] # 偶数索引: 0, 2, 4, ...
x_odd = x[..., 1::2] # 奇数索引: 1, 3, 5, ...

# 应用旋转公式
out_even = x_even * cos - x_odd * sin
out_odd = x_even * sin + x_odd * cos

# 交错合并: 将 (even, odd) 沿最后一维堆叠并展平
out = torch.stack([out_even, out_odd], dim=-1) # shape: (*, seq_len, d_k//2, 2)
out = out.flatten(-2) # shape: (*, seq_len, d_k)

# eg.
# out_even = [10, 20, 30] # 来自偶数位 (0,2,4)
# out_odd = [ 1, 2, 3] # 来自奇数位 (1,3,5)
# out = torch.stack([out_even, out_odd], dim=-1)
# 结果形状: (3, 2)
# 具体值:
# out =
# [[10, 1],
# [20, 2],
# [30, 3]]
# 等价关系:out[i, 0] = out_even[i], out[i, 1] = out_odd[i]

# out.flatten(-2) → [10, 1, 20, 2, 30, 3]
# 注意是交错的:even0, odd0, even1, odd1, even2, odd2
  • x:要旋转的向量(如 Q 或 K)。常见形状:(batch, seq_len, n_heads, d_k);你这段代码里只用到了​最后一维 d_k​,前面的维度用 ... 代表(不关心但会保留)。
  • cos_cache/sin_cache:预计算好的 cos(position*freq) / sin(position*freq),形状通常是 (max_seq_len, d_k//2)
  • token_positions:每个 token 的​位置索引​,形状通常是 (batch, seq_len)(或和你的 x 的第 0、1 维对齐)。

RoPE(旋转位置编码)数值例子推导(可计算)

本例展示 RoPE 如何让注意力分数只依赖 相对位置差 (q - p)。

设定参数

  • 向量维度:$d = 4$ 分成两组二维向量:$(x_1, x_2), (x_3, x_4)$
  • RoPE 旋转频率:
    • 第一组角度:$\theta_{1}(i) = i$
    • 第二组角度:$\theta_{2}(i) = i / 100$
  • 选两个 token 的位置:
    • Query 在 $p = 2$
    • Key 在 $q = 5$
  • 相对位置差:${\Delta} = q - p = 3$
  • 原始向量(未旋转):
    • Query:$(1, 0, 2, 0)$
    • Key:$(0, 1, 0, 3)$

Step 1:计算旋转角

分组 Query p=2 Key q=5 相对角度差 Δθ
第一组 $\theta_{1}(2) = 2$ $\theta_{1}(5) = 5$ $3$
第二组 $\theta_{2}(2) = 0.02$ $\theta_{2}(5) = 0.05$ $0.03$

Step 2:对每组二维向量进行旋转

二维旋转公式:

$(u,v) \rightarrow (u\cos\theta - v\sin\theta,; u\sin\theta + v\cos\theta)$

旋转 Query(p=2):

原向量 角度 结果
$(1, 0)$ $\theta = 2$ $(-0.4161,; 0.9093)$
$(2, 0)$ $\theta = 0.02$ $(1.9996,; 0.0400)$

$Q_{p} \approx (-0.4161,; 0.9093,; 1.9996,; 0.0400)$

旋转 Key(q=5):

原向量 角度 结果
$$(0, 1$$ $$\theta = $$ $$(0.9589,\; 0.2837$$
$$(0, 3$$ $$\theta = 0.0$$ $$(-0.1499,\; 2.9963$$

$
K_{q} \approx (0.9589,; 0.2837,; -0.1499,; 2.9963)
$

Step 3:直接计算注意力分数(点积)

$
Q_{p}^{\top} K_{q} \approx -0.321093
$

Step 4:只用“相对位置差”旋转 Key 来重算

使用相对角度:

分组 Δθ
第一组 $$$$
第二组 $$0.0$$

旋转后 Key:

$
y’ \approx (-0.1411,; -0.9900,; -0.0900,; 2.9987)
$

点积:

$
x^{\top} y’ \approx -0.321093
$

与 Step 3 ​完全一致​。

关键结论: ​

$(R_{\theta(p)} x)^{\top} (R_{\theta(q)} y) = x^{\top} R_{\theta(q - p)} y$

即注意力分数只取决于:$q - p$

Scaled Dot-Product Attention

Softmax

$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
def softmax(x: torch.Tensor, dim: int) -> torch.Tensor:
"""
Applies the softmax function along the specified dimension.
exp(x - max(x)) / sum(exp(x - max(x)))
Args:
x (torch.Tensor): Input tensor.
dim (int): Dimension along which to apply softmax.

Returns:
torch.Tensor: Tensor after applying softmax.
"""

# x: 任意形状的张量
# dim: 在哪个维度上进行 softmax(比如 dim=-1 表示对最后一维做 softmax)

# 1) 在指定维度上取最大值,用来做数值稳定性处理
# keepdim=True 的作用是保持维度不变,方便后续做广播减法
x_max = x.max(dim=dim, keepdim=True)[0] # .max 返回 (values, indices),所以取 [0] 获得最大值

# 2) 减去最大值后再做 exp,避免 exp 时数值溢出(即变得特别大)
x_exp = torch.exp(x - x_max)

# 3) 对每个向量求和,然后归一化得到 softmax 概率分布
# 同样使用 keepdim=True 保持维度不变,确保能正确广播除法
return x_exp / x_exp.sum(dim=dim, keepdim=True)

为了确保稳定性:

Note that $exp(v_i)$ can become inf for large values (then, $inf/inf = NaN$).

We can avoid this by noticing that the softmax operation is invariant to adding any constant to all inputs.

We can leverage this property for numerical stability—typically, we will subtract the largest entry of $o_i$ from all elements of $o$, making the new largest entry $0$. You will now implement softmax, using this trick for numerical stability.如果 x 里有很大的数,例如 100、200,$e^{200}$ 会非常大,接近 ​无穷大​,会导致:

  • 数值溢出(overflow)
  • 结果变成 $nan$ 或 $inf$
  • 模型梯度不稳定

softmax ​只在乎相对差值,不在乎绝对值​。

看下面这个数学变换:

令 $c = \max(x)$,那么:

$\text{softmax}(x_i - c) = \frac{e^{x_i - c}}{\sum_j e^{x_j - c}} = \frac{e^{x_i} \cdot e^{-c}}{\sum_j e^{x_j} \cdot e^{-c}}$

注意到分子分母都有 $e^{-c}$,可以抵消:$= \frac{e^{x_i}}{\sum_j e^{x_j}}$

SDPA

$
\text{Attention}(Q,K,V)=\operatorname{softmax}\Big(\frac{QK^\top}{\sqrt{d_k}}\Big)V
$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
from torch import nn

class ScaledDotProductAttention(nn.Module):
def __init__(self, device=None):
super().__init__()
self.device = device

def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
"""
计算 Scaled Dot-Product Attention。

Args:
Q: 查询张量,shape (*, seq_len_q, d_k)
K: 键张量,shape (*, seq_len_kv, d_k)
V: 值张量,shape (*, seq_len_kv, d_v)
mask: 布尔掩码,shape (seq_len_q, seq_len_kv) 或可广播形状,True 表示保留

Returns:
输出张量,shape (*, seq_len_q, d_v)
"""
d_k = Q.shape[-1]

score = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k))
# K的最后两维进行交换 这样Q和K才能相乘

if mask is not None:
score = score.masked_fill(~mask, -1e9)
# torch.masked_fill(tensor, mask, value): ​
# 将 tensor 中 mask 为 True 的位置替换为 value
# ~mask: 取反操作,将 True 变为 False,False 变为 True
# 原来True表示保留, False的位置需要去掉 因此先取反 把要替换的变为True

from cs336_basics.softmax import softmax
attn_weights = softmax(score, dim=-1)

return torch.matmul(attn_weights, V)

Causal Multi-Head Self-Attention

We will implement multi-head self-attention as described in section 3.2.2 of Vaswani et al. (2017). Recall that, mathematically, the operation of applying multi-head attention is defined as follows:

$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)$

$\text{head}_i = \text{Attention}(Q_i, K_i, V_i)$

with $Q_i, K_i, V_i$ being slice number $i \in {1, \ldots, h}$ of size $d_k$ or $d_v$ of the embedding dimension for $Q$, $K$, and $V$ respectively. With Attention being the scaled dot-product attention operation defined in §3.5.4. From this we can form the multi-head self-attention operation:

$\text{MultiHeadSelfAttention}(x) = W_O \cdot \text{MultiHead}(W_Q x, W_K x, W_V x)$

Here, the learnable parameters are:

$W_Q \in \mathbb{R}^{h d_k \times d_{\text{model}}}, \quad W_K \in \mathbb{R}^{h d_k \times d_{\text{model}}}, \quad W_V \in \mathbb{R}^{h d_v \times d_{\text{model}}}, \quad W_O \in \mathbb{R}^{d_{\text{model}} \times h d_v}$

Since the $Q$s, $K$s, and $V$s are sliced in the multi-head attention operation, we can think of $W_Q, W_K$ and $W_V$ as being separated for each head along the output dimension. When you have this working, you should be computing the key, value, and query projections in a total of three matrix multiplies.

Causal masking. Your implementation should prevent the model from attending to future tokens in the sequence. In other words, if the model is given a token sequence $t_1, \ldots, t_n$ and we want to calculate the next-word predictions for the prefix $t_1, \ldots, t_i \quad (i < n)$ the model should not be able to access (attend to) the token representations at positions $t_{i+1}, \ldots, t_n$ since it will not have access to these tokens when generating text during inference (and these future tokens leak information about the identity of the true next word, trivializing the language modeling pre-training objective).

For an input token sequence $t_1, \ldots, t_n$, we can naively prevent access to future tokens by running multi-head self-attention $n$ times (for the $n$ unique prefixes in the sequence). Instead, we’ll use causal attention masking, which allows token $i$ to attend to all positions $j \le i$ in the sequence.You can use torch.triu or a broadcasted index comparison to construct this mask, and you should take advantage of the fact that your scaled dot-product attention implementation from §3.5.4 already supports attention masking.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from torch import nn

from cs336_basics.linear import Linear
from cs336_basics.positionwise_feedforward import SwiGLU
from cs336_basics.rmsnorm import RMSNorm
from cs336_basics.multihead_self_attention import MultiHeadSelfAttention
from cs336_basics.rope import Rope

class TransformerBlock(nn.Module):
def __init__(
self,
d_model: int,
num_heads: int,
d_ff: int,
max_seq_len: int,
theta: float,
device: str = None,
):
super().__init__()

self.d_model = d_model
self.num_heads = num_heads
self.d_ff = d_ff

rope = Rope(theta=theta, d_k=d_model // num_heads, max_seq_len=max_seq_len)
self.rms_norm1 = RMSNorm(d_model, device=device)
self.rms_norm2 = RMSNorm(d_model, device=device)
self.ffn = SwiGLU(d_model, d_ff, device=device)
self.attn = MultiHeadSelfAttention(d_model, num_heads, rope)

def forward(self, x: torch.Tensor) -> torch.Tensor:
y = self.attn(self.rms_norm1(x)) + x # 残差连接
z = self.ffn(self.rms_norm2(y)) + y

return z

The Full Transformer LM

步骤 模块 输入形状 输出形状 说明 code
1 Token Embedding (B, L) (B, L, D) 词 ID → 词向量 self.token_embedding = Embedding(vocab_size, self.d_model, **factory_kwargs)
2 N× TransformerBlock (B, L, D) (B, L, D) 注意力 + 前馈 + 残差 self.transformer_blocks = nn.ModuleList([`` TransformerBlock(`` d_model=self.d_model,`` num_heads=self.num_heads,`` d_ff=self.d_ff,`` max_seq_len=max_seq_len,`` theta=rope_theta,`` device=device`` )`` for _ in range(self.num_layers)`` ])
3 RMSNorm (B, L, D) (B, L, D) 正则化 self.norm = RMSNorm(d_model=self.d_model, **factory_kwargs)
4 Linear 输出头 (B, L, D) (B, L, V) 映射到词表 logits self.output_embedding = Linear(self.d_model, vocab_size, **factory_kwargs)
5 Softmax (B, L, V) same 得到概率

B = batch_size

L = seq_len

D = d_model

V = vocab_size

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
from torch import nn
from cs336_basics.transformer_block import TransformerBlock
from cs336_basics.embedding import Embedding
from cs336_basics.rmsnorm import RMSNorm
from cs336_basics.linear import Linear
from cs336_basics.softmax import softmax

class TransformerLM(nn.Module):
def __init__(
self,
vocab_size,
context_length,
num_layers,
d_model: int,
num_heads: int,
d_ff: int,
rope_theta: float,
max_seq_len: int = 2048,
device: str = None,
dtype=None
):
super().__init__()
# self.vocab_size = vocab_size
self.context_length = context_length
self.num_layers = num_layers
self.d_model = d_model
self.num_heads = num_heads
self.d_ff = d_ff
factory_kwargs = {"device": device, "dtype": dtype}
# self.max_seq_len = max_seq_len
# self.rope_theta = rope_theta
# self.device = device
# self.dtype = dtype

self.token_embedding = Embedding(vocab_size, self.d_model, **factory_kwargs)
# self.transformer_block = TransformerBlock(
# d_model=self.d_model,
# num_heads=self.num_heads,
# d_ff=self.d_ff,
# max_seq_len=max_seq_len,
# theta=rope_theta,
# device=device
# )
self.transformer_blocks = nn.ModuleList([
TransformerBlock(
d_model=self.d_model,
num_heads=self.num_heads,
d_ff=self.d_ff,
max_seq_len=max_seq_len,
theta=rope_theta,
device=device
)
for _ in range(self.num_layers)
])
self.norm = RMSNorm(d_model=self.d_model, **factory_kwargs)
self.output_embedding = Linear(self.d_model, vocab_size, **factory_kwargs)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.token_embedding(x)
for block in self.transformer_blocks:
x = block(x)

x = self.norm(x)
x = self.output_embedding(x)

return x

Resource accounting

It is useful to be able to understand how the various parts of the Transformer consume compute and memory. We will go through the steps to do some basic “FLOPs accounting.”

The vast majority of FLOPS in a Transformer are matrix multiplies, so our core approach is simple:

  1. Write down all the matrix multiplies in a Transformer forward pass
  2. Convert each matrix multiply into FLOPs required

For this second step, the following facts will be useful:

Rule: Given $A \in \mathbb{R}^{m \times n}$ and $B \in \mathbb{R}^{n \times p}$, the matrix-matrix product $AB$ requires $2mnp$ FLOPs.

To see this, note that

$$(AB)[i, j] = A[i, :] \cdot B[:, j]$$

and that this dot product requires $n-1$ additions and $n$ multiplications (a total of $2n-1 \approx 2n$ FLOPs).

考虑矩阵乘法里单个元素的计算, 点积计算方式是:

$$A[i,1] \cdot B[1,j] + A[i,2] \cdot B[2,j] + \cdots + A[i,n] \cdot B[n,j]$$

可以看出:

但在 FLOPs 估算中,我们不区分 + 与 *,只按乘法 + 加法数量合计成本,并且通常把加法也算为 $n$ 次,方便估计。因此认为:一个点积 $\approx 2n$ FLOPs

Then, since the matrix-matrix product $AB$ has $m \times p$ entries, the total number of FLOPs is $(2n)(mp) = 2mnp$.

矩阵 $AB$ 的大小是:$m \times p$

也就是说,我们需要计算 $m \times p$ 个点积。每个点积成本是 $2n$,所以总 FLOPs:$(2n)(m \times p) = 2mnp$Now, before you do the next problem, it can be helpful to go through each component of your Transformer block and Transformer LM, and list out all the matrix multiplies and their associated FLOPs costs.

  1. 考虑 GPT-2 XL 模型,其配置如下:

    1. vocab_size: 50,257
    2. context_length: 1,024
    3. num_layers: 48
    4. d_model: 1,600
    5. num_heads: 25
    6. d_ff: 6,400
      假设我们使用该配置构建模型,该模型共有多少可训练参数?若每个参数以单精度浮点数(32位)存储,仅加载该模型需要多少内存?

    每层包含注意力和前馈网络参数,加上共享的词嵌入与输出层:

    $$\text{Params} = VD + N(4D^2 + 2DD_{ff} + 4D)$$

    其中:

    1. $V$:词嵌入(与输出层权重共享)
    2. $4D^2$:注意力 Q/K/V/O 投影
    3. $2DD_{ff}$:MLP 两层
    4. $4D$:两层 RMSNorm 的可训练参数(scale)
      代入数值:

    $
    \begin{aligned} \text{Params} &= 50257 \times 1600 + 48 \times (4 \times 1600^2 + 2 \times 1600 \times 6400 + 4 \times 1600) \ &= 80,411,200 + 48 \times (10,240,000 + 20,480,000 + 6,400) \ &= 80,411,200 + 48 \times 30,726,400 \ &= 80,411,200 + 1,474,867,200 \ &= 1,555,278,400 \end{aligned}
    $

内存需求(单精度):

每个参数占 4 字节,总内存:

$
1,555,278,400 \times 4 = 6,221,113,600 \text{ bytes} \approx 6.22 \text{ GB}
$

逐项解释如下(默认​权重共享的词嵌入/输出层​、忽略各处 bias——占比很小):

  1. 词嵌入(与输出层权重共享)

    1. 词表大小 $V$,隐层维度 $D$。
    2. 权重矩阵形状:$V \times D$,参数数目 $VD$。
    3. 因为权重绑定(tied weights),​只算一次​,而不是”嵌入 + 输出”各算一遍。
  2. ​**注意力投影 $Q/K/V/O$**​

    1. 每个都是线性映射:$\mathbb{R}^D \to \mathbb{R}^D$,权重矩阵 $D \times D$。
    2. 共 4 个:$W_Q, W_K, W_V, W_O$。
    3. 参数合计 $4D^2$。
    4. 这里 ​多头数不改变参数量级​:实现上通常是把 $W_Q$ 等做成 $(D \times D)$ 的”拼接版”,拆成多头只是​重排维度​,不是多复制权重。
  3. 前馈网络(MLP/FFN)

  4. 两层线性:$D \to D_{ff}$ 和 $D_{ff} \to D$,权重分别 $D \times D_{ff}$、$D_{ff} \times D$。

  5. 参数合计 $2DD_{ff}$。

  6. 激活(如 GeLU/SiLU)本身​不引入可训练参数​。

  7. 归一化(每层两次)

  8. 你写的是 RMSNorm 的 scale;如果确实是 RMSNorm(只有 scale),​**每层应是 $2D$**​(两次 RMSNorm)。

  9. 你公式里用了 $4D$,这相当于把每个 Norm 当作 LayerNorm(有 $\gamma, \beta$ 两个向量)来计数:每个 Norm $2D$,两次就是 $4D$。

  10. 两种写法差异很小:每层相差 $2D$,总共差 $2ND$ 个参数(对本例仅 15.36 万),对 15 亿量级近乎可以忽略。

若严格按 GPT-2 原论文,它用的是 ​LayerNorm​,那用 (4D) 是合理的;若实现真用 ​RMSNorm​,把 (4D) 改成 (2D) 更严谨。

Chapter #4 Training a Transformer LM

We now have the steps to preprocess the data (via tokenizer) and the model (Transformer). What remains is to build all of the code to support training. This consists of the following:

  • Loss​: we need to define the loss function (cross-entropy).
  • Optimizer​: we need to define the optimizer to minimize this loss (AdamW).
  • Training loop​: we need all the supporting infrastructure that loads data, saves checkpoints, andmanages training.

Cross-entropy loss

Recall that the Transformer language model defines a distribution

$
p_θ(x_{i+1}|x_{1:i})
$

for each sequence $$$$ of length $$m+ $$and $$i= 1, . . . , $$.

即预测下一个词

$
x_{i+1
$

在给定前文 $$x_{1:i$$ 的情况下的概率。

Given a training set consisting of sequences of length , we define the standard cross-entropy (negative log-likelihood) loss function:

$
\ell(\theta; D) = \frac{1}{|D| m} \sum_{x \in D} \sum_{i=1}^{m} -\log p_{\theta}(x_{i+1} \mid x_{1:i})
$

注意:Transformer 只需一次前向传播,就能同时得到序列中每个位置 的预测

$
p_\theta(x_{i+1} \mid x_{1:i}
$

  • $
    \ell(\theta; D
    $

    :整个模型的平均损失;

  • $
    |D
    $

    :训练集中样本数量;

  • :每个序列的长度;

  • $
    p_{\theta}(x_{i+1} \mid x_{1:i}
    $

    :模型预测的下一个 token 的概率;

  • 内层求和:对序列中每个位置 求负 log 概率;

  • 外层求和:对所有样本求平均。

$
p(x_{i+1} \mid x_{1:i}) = \text{softmax}(o_i)[x_{i+1}] = \frac{\exp(o_i[x_{i+1}])}{\sum_{a=1}^{\text{vocab_size}} \exp(o_i[a])}
$

具体来说,Transformer 在每个位置 会输出一个logits 向量

$
o_i \in \mathbb{R}^{\text{vocab_size}}
$

,其中:

  • 每个维度

    $
    o_i[a
    $

    对应词表中第 个 token 的“原始分数”;

  • softmax 将这些 logits 转为概率分布;

  • 真正的预测概率是 softmax 后,目标词

    $
    x_{i+1
    $

    所对应的分量。

在训练时, 我们有「真实的」token 序列

$
x_1, x_2, \ldots, x_
$

$
xi+
$

已经是已知的(来自训练数据),固定不变;我们只是在评估模型对真实 token 的概率;

即:

$
p(x_{i+1} | x_{1:i}) = \frac{e^{o_i[x_{i+1}]}}{\sum_{a=1}^{V} e^{o_i[a]}}
$

交叉熵损失在实现时,直接基于 logits 向量 $o_i$ 和目标 token $x_{i+1}$ 来定义。也就是说,在代码中不需要手动计算 softmax → log → loss 的全部过程,而是用内置的函数如:

1
loss = F.cross_entropy(o_i, target)

它内部自动执行 softmax + log + 求平均的步骤,并处理数值稳定性问题。

在实现交叉熵损失时,需要注意数值稳定性问题,就像实现 softmax 时一样。

比如,直接计算 $\exp(o_i)$ 可能会导致溢出,因此通常会减去最大值:

$$\text{softmax}(o_i) = \frac{\exp(o_i - \max(o_i))}{\sum_a \exp(o_i[a] - \max(o_i))}$$

从而保持稳定。

步骤 内容 数学公式 / 数值 说明
1. 模型输出 logits $o = [1.2, 0.9, -0.1, 2.0]$ Transformer 对每个词输出原始得分
2. softmax 概率 $p = [0.311, 0.230, 0.084, 0.693]$ $p[a] = \frac{e^{o[a]}}{\sum e^{o}}$ 转换为词表上的概率分布
3. 真实标签 “B” → 索引 1 表示目标是第 1 类
4. 交叉熵损失 $\ell = -\log p[\text{B}] = -\log(0.230) = 1.4$ 概率越小,损失越大
5. 改进预测 若 logits = [0.2, 3.5, -0.3, 1.1] → $p[\text{B}] = 0.915$ $\ell = -\log(0.915) = 0.08$ 预测更准,损失更小

交叉熵损失的设计本质是​最大似然估计(MLE)​的实现形式:它通过最小化真实标签的负对数概率,让模型学习在真实词上分配更高概率。数学上,交叉熵等价于最小化模型分布与真实分布之间的 ​KL 散度​,使模型预测尽可能接近真实数据分布。它的梯度形式简单、数值稳定、惩罚“自信但错误”的预测,是概率建模和分类任务中最自然、最有效的目标函数。

Code

1
2
3
from torch.nn import functional as F
def cross_entropy(inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
return F.cross_entropy(inputs, targets)

input

第 b 个样本、第 t 个时间步,对词表中第 v 个词的未归一化预测分数(logit)

1
inputs.shape == (batch_size, sequence_length, vocab_size)

targets

targets[b, t] 是第 b 个样本在第 t 个位置的 ​正确类别索引​(整数,范围 [0, vocab_size - 1]

1
targets.shape = (batch_size, sequence_length)

最后计算

$$\text{loss}_{b,t} = -\log\text{softmax}(inputs[b,t])[target[b,t]] = -\log\frac{e^{inputs[b,t,targets[b,t]]}}{\sum_v e^{inputs[b,t,v]}}$$

Perplexity

Perplexity 困惑度 = exp(average cross-entropy)

$$\text{perplexity} = \exp\left(\frac{1}{m}\sum_{i=1}^{m}\ell_i\right)$$

符号 含义
$m$ 序列长度(即 token 数)
$\ell_i$ 第 $i$ 个 token 的 交叉熵损失(negative log-likelihood)
$\frac{1}{m}\sum_i \ell_i$ 每个 token 的平均 cross-entropy
$\exp(\cdot)$ 对平均 cross-entropy 取指数,即转回”原空间”
  • 训练语言模型时,我们最小化 cross-entropy:

    $$\text{Loss} = \frac{1}{m}\sum_{i=1}^{m}\ell_i = -\frac{1}{m}\sum_{i=1}^{m}\log p(x_i | x_{1:i-1})$$

  • 困惑度(Perplexity)定义为该平均交叉熵的指数:

    $$\text{PPL} = e^{\text{Loss}}$$

  • 困惑度可以理解为模型在预测下一个 token 时,平均有多少个“可能的词”在它看来是合理的。

  • 换句话说:

    • 如果模型完美预测(loss → 0),则 perplexity → 1;
    • 如果模型对所有词平均分布(即非常不确定),则 perplexity ≈ 词表大小 V

例如:

  • vocab_size = 10,000;
  • 模型输出接近均匀分布 ⇒ 每个词概率约 1/10,000;
  • 那么平均 cross-entropy ≈ ln(10,000) = 9.21;
  • perplexity = exp(9.21) ≈ 10,000。

The SGD Optimizer

The simplest gradient-based optimizer is Stochastic Gradient Descent (SGD). We start with randomly initialized parameters $\theta_0$. Then for each step $t= 0, \ldots, T-1$, we perform the following update:

$$\theta_{t+1} \leftarrow \theta_t - \alpha \nabla L(\theta_t; B_t)$$

where $B_t$ is a random batch of data sampled from the dataset $D$, and the learning rate $\alpha_t$ and batch size $|B_t|$ are hyper parameters

Implementing SGD in PyTorch

这里实现

$$\theta_{t+1} \leftarrow \theta_t - \frac{\alpha}{\sqrt{t+1}} \nabla L(\theta_t; B_t)$$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from collections.abc import Callable, Iterable
from typing import Optional
import torch
import math

class SGD(torch.optim.Optimizer):
def __init__(self, params, lr=1e-3):
if lr<0:
raise ValueError(f"Invalid learning rate:{lr}")
defaults={"lr": lr}
super().__init__(params, defaults) # 将默认超参数(这里只有 lr)传给父类,父类会把 params 组织成 param_groups,以便多个参数组使用不同超参

def step(self, closure: Optional[Callable]=None):
loss = None if closure is None else closure() # 如果传入 closure,调用它并把返回的损失赋给 loss;否则 loss=None。
for group in self.param_groups:
lr = group["lr"] # Get the learning rate.
for p in group["params"]:
if p.grad is None:
continue # 跳过没有梯度的参数(例如冻结参数)

state = self.state[p]# Get state associated with p.
# self.state 是一个字典,把 p 映射到其状态字典(例如动量缓冲、迭代计数等)。如果之前没有,则 state 会是空字典 dict()(父类维护该结构)。
t = state.get("t", 0)
# Get iteration number from the state, or initial value.
# 从 state 里获取某个计数器 t(表示该参数已更新的次数),若不存在则默认 0。
grad = p.grad.data # Get the gradient of loss with respect to p.
p.data -= lr / math.sqrt(t + 1) * grad # Update weight tensor in-place.
state["t"] = t + 1 # Increment iteration number.
return loss

更好的版本:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from collections.abc import Callable, Iterable
from typing import Optional
import torch
import math

class SGD(torch.optim.Optimizer):
def __init__(self, params: Iterable[torch.nn.Parameter], lr: float = 1e-3):
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
defaults = {"lr": lr}
super().__init__(params, defaults)

def step(self, closure: Optional[Callable] = None):
"""Performs a single optimization step.
Uses lr / sqrt(t+1) scheduling per-parameter where t is the parameter's update count.
"""
loss = None
if closure is not None:
# closure should re-evaluate the model and return the loss.
loss = closure()

for group in self.param_groups:
lr = group["lr"]
for p in group["params"]:
if p.grad is None:
continue

# Ensure state dict exists
state = self.state[p]
if "t" not in state:
state["t"] = 0 # initialize iteration count

t = state["t"]
# compute scalar scaling factor (Python float)
scale = lr / math.sqrt(t + 1)

# safer update: do it under torch.no_grad() so autograd is not tracking.
with torch.no_grad():
# in-place subtract: p = p - scale * grad
# p.add_(-scale, p.grad) also works
p.add_(p.grad, alpha=-scale)

state["t"] = t + 1

return loss

Key notes:

  1. p.add_(p.grad, alpha=-scale)
    1. Tensor.add_(other, alpha=1): 将 other * alpha 加到当前张量上,就地更新
    2. Ps. Tensor.copy_(src): 将另一个张量 src 的内容复制到当前张量中(形状必须一致),同样是原地修改。

p​**p.data的区别, 以及使用with torch.no_grad()的区别**

  • 在优化器里,p 通常是一个 torch.nn.Parameter,它是 torch.Tensor 的子类, 也就是说:

    • Parameter本质上就是一个带标记的 Tensor​;它告诉 nn.Module:“这是需要被训练的权重变量”;
    • 因此它有:
      • .data 属性(一个真正存数据的 Tensor)
      • .grad 属性(存储反向传播得到的梯度)
      • .requires_grad=True
        | 属性 | 说明 |

| ——————————————– | ————————————————- |
| p.data | 与之共享数据存储的裸 Tensor(不追踪计算图) |
| p.grad | 反向传播得到的梯度 Tensor(与 p 形状相同) |
| p.requires_grad | 是否启用 autograd(通常为 True) |
| p.grad_fn | 指向产生它的计算节点(如果是叶子节点则为 None) |
| p.is_leaf | 是否是计算图的叶节点(Parameter 通常是 True) |
| p.device/p.dtype/p.shape | 数据类型、设备、形状 |

  • p.data 是 ​另一个 Tensor 对象​,它与 p 共享同一块内存​**(storage), ​但它 ​不被 autograd 图追踪**​。

    | 属性 | 说明 |

| —————————— | ———————- |
| .shape | 张量形状 |
| .dtype | 数据类型 |
| .device | 存储在哪个设备 |
| .requires_grad | 永远为 False |
| .grad | 永远为 None |
| .grad_fn | None(不在计算图中) |
| .is_leaf | True |
| .data_ptr() | 底层数据内存地址 |
| .add_(),.copy_() | 可以原地操作 |

  • 对比

    | 属性 | p | p.data |

| ———————— | ————————————– | ——————————————— |
| 类型 | torch.nn.Parameter (继承自 Tensor) | torch.Tensor |
| 是否被 Autograd 追踪 | ✅ 是(requires_grad=True) | ❌ 否 (永远不被追踪) |
| 数据存储(内存) | 相同 | 相同 |
| shape / dtype / device | 完全相同 | 完全相同 |
| 常见用途 | 模型参数参与计算图 | 临时绕过 autograd、直接改数据 |
| 危险点 | 修改 p 会被图记录 | 修改 p.data 会绕过 autograd,可能破坏计算图 |

  • .datawith torch.no_grad() 的区别

    | 特性 | .data | with torch.no_grad(): |

| ————————- | ——————- | —————————– |
| 是否创建新上下文 | 否 | 是 |
| 是否安全 | ❌ 容易破坏计算图 | ✅ 官方推荐 |
| 是否改变 requires_grad | 不会 | 不会 |
| 是否被 autograd 追踪 | 永远不会 | 块内所有操作都不追踪 |
| 是否会导致隐式错误 | 经常 | 基本不会 |

backward()step() 的角色分工

操作 作用 对象 是否构建计算图
loss.backward() 反向传播,计算梯度 各参数的.grad ✅ 构建/使用计算图
optimizer.step() 使用梯度更新参数 各参数的.datap ❌ 不参与计算图

AdamW

Adam (Adaptive Moment Estimation)

Adaptive Moment Estimation: “对梯度的一阶矩(平均)和二阶矩(方差)进行自适应估计”

符号 含义 对应统计量
$$m_$$ 一阶矩估计 梯度的动量(平均值)
$$v_$$ 二阶矩估计 梯度的方差(平方的平均值)

步骤:

  1. 计算梯度

    $
    g_t = \nabla_{\theta_t} \ell(\theta_t)
    $

  2. 更新一阶矩(动量)

    1. $
      m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
      $

    2. 含义: 对梯度求指数加权平均. 若

      $
      \beta_1 = 0.
      $

      ,则最近的梯度权重大,过去的影响逐渐衰减

  3. 更新二阶矩(平方梯度平均)

    1. $
      v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
      $
    2. 含义: 追踪梯度的方差, 用于自适应调整每个参数的学习率. 例如某个参数的梯度波动大, 就让它学习率更小
  4. 偏差校正(bias correction)因为

    $
    m_0 = v_0 =
    $

    ,在训练初期会被低估。于是引入校正项:

    1. $
      \hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}
      $

    2. 这些分母接近 1 时(如

      $
      t \to \infty
      $

      ),就不再影响结果。

  5. 参数更新

  6. $
    \theta_{t+1} = \theta_t - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
    $

  7. 含义:

  * 分母部分调节学习率(variance-based scaling);
  * 分子是平滑过的动量;
  * 每个参数维度都有自己的学习率。
    | 部分                                     | 直觉解释                                                     |

| —————————————— | ————————————————————– |
| $$m_$$ | 相当于“速度”,平滑方向;防止震荡。 |
| $$v_$$ | 表示梯度变化幅度,抑制波动大的方向。 |
| $$\hat{m}_t / \sqrt{\hat{v}_t$$ | 用平滑的梯度方向除以变化程度,使每个参数的步伐适应性地缩放。 |
| $$\epsilo$$ | 防止除 0,稳定性项。 |

AdamW

AdamW = Adam + 正确实现的 weight decay

Adam+权重衰减: ​把权重衰减(L2 正则)加在梯度里。 → weight decay 会被梯度方差

$
v_t
$

缩放,衰减力度不一致。AdamW: 把权重衰减与梯度更新​解耦​,单独在参数更新后执行。 → 纯粹地让参数按固定比例缩小,正则化效果更稳定。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import math
import torch

class AdamW(torch.optim.Optimizer):
def __init__(self, params, lr=1e-3, weight_decay=0.0, betas=(0.9, 0.999), eps=1e-8):
defaults = {
"lr": lr,
"weight_decay": weight_decay,
"betas": betas,
"eps": eps
}
super().__init__(params, defaults)

def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
lr = group['lr']
weight_decay = group['weight_decay']
beta1, beta2 = group['betas']
eps = group['eps']

for p in group['params']:
if p.grad is None:
continue
if p.grad.is_sparse:
raise RuntimeError('AdamW does not support sparse gradients')

# 1. 得到梯度g
grad = p.grad.data
state = self.state[p]

# state init
if len(state) == 0:
state["step"] = 0
state["exp_avg"] = torch.zeros_like(p.data)
state["exp_avg_sq"] = torch.zeros_like(p.data)


# 2. 得到一阶矩m
# 3. 得到二阶矩v
exp_avg = state["exp_avg"]
exp_avg_sq = state["exp_avg_sq"]

# update biased first and second moment estimates
# exp_avg.mul_(beta1).add_(grad, alpha=(1 - beta1))
# exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))

new_exp_avg = exp_avg * beta1 + grad * (1.0 - beta1)
new_exp_avg_sq = exp_avg_sq * beta2 + (grad * grad) * (1.0 - beta2)

# write back state with new tensors (non-inplace)
state["exp_avg"] = new_exp_avg
state["exp_avg_sq"] = new_exp_avg_sq

# increment step and compute bias-corrected learning rate
state["step"] += 1
step = state["step"]

# 4. 计算偏差修正后的学习率
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step

# compute step_size as in original Adam: lr * sqrt(1 - beta2^t) / (1 - beta1^t)
step_size = lr * (math.sqrt(bias_correction2) / bias_correction1)

# 5. 更新参数
denom = torch.sqrt(new_exp_avg_sq) + eps
update = new_exp_avg / denom
# parameter update: θ = θ - step_size * exp_avg / denom
new_p_data = p.data - step_size * update

# 6. 权重衰减
# decoupled weight decay (AdamW): θ = θ - lr * weight_decay * θ
if weight_decay != 0:
new_p_data = new_p_data - lr * weight_decay * p.data

# write back parameter with new data
p.data = new_p_data

return loss

PyTorch 官方文档对 .data 的态度是:

“仅在实现优化器或底层库代码时使用。 普通训练逻辑中请使用 with torch.no_grad()。”

Learning rate scheduling

在训练过程中,能使损失值最快下降的学习率通常会随着训练进展而变化。 在训练 Transformer 时,常常使用一种​**学习率调度(learning rate schedule)**​: 我们从一个较大的学习率开始,使模型在初期能更快地更新; 随后逐渐减小学习率,使得在训练后期更新更加平稳。

在本次作业中,我们将实现 LLaMA(Touvron 等人, 2023) 所采用的​**余弦退火调度(cosine annealing schedule)**​。

一个调度器(scheduler)本质上就是一个函数:

它接受当前的训练步数 以及其他相关参数(例如初始与最终学习率),

然后返回在第 步时应该使用的学习率。

最简单的调度器是常数函数,也就是不论 为多少,始终返回同一个学习率。

余弦退火(cosine annealing)学习率调度包括以下参数:

(i) 当前迭代步 ,

(ii) 最大学习率$\alpha_{\max}$, (iii) 最小(最终)学习率

$
\alpha_{\min}
$

, (iv) 预热(warm-up)步数

$
T_
$

, (v) 余弦退火的总迭代步数

$
T_
$

。在第 ( t ) 步时的学习率

$
\alpha_
$

定义如下:

  • **预热阶段 Warm-up)**如果

    $
    t < T_
    $

    ,则

    $
    \alpha_t = \frac{t}{T_w} \alpha_{\max}
    $

  • **余弦退火阶段 Cosine annealing)**如果

    $
    T_w \le t \le T_
    $

    ,则

    $
    \alpha_t = \alpha_{\min} + \frac{1}{2} \left( 1 + \cos\left(\frac{t - T_w}{T_c - T_w} \pi\right) \right) (\alpha_{\max} - \alpha_{\min})
    $

  • **退火结束阶段 Post-annealing)**如果

    $
    t > T_
    $

    ,则

    $
    \alpha_t = \alpha_{\min
    $

1
2
3
4
5
6
7
8
import math

def learning_rate_schedule(it, max_learning_rate, min_learning_rate, warmup_iters, cosine_cycle_iters):
if it < warmup_iters:
return max_learning_rate * it / warmup_iters
if warmup_iters <= it <= cosine_cycle_iters:
return min_learning_rate + (max_learning_rate - min_learning_rate) * 0.5 * (1 + math.cos(math.pi * (it - warmup_iters) / (cosine_cycle_iters - warmup_iters)))
return min_learning_rate

Gradient clipping

在训练过程中,我们有时会遇到某些训练样本导致梯度过大的情况,这会使训练过程不稳定。 为了解决这个问题,实践中常用的一种技术叫做​梯度裁剪(gradient clipping)​。 其核心思想是在每次反向传播之后、优化器更新参数之前,​对梯度的范数设置上限​。

  • 设所有参数的梯度为 ,我们计算它的

    $
    \ell_
    $

    -范数

    $
    | g_i |2 = \sqrt{\sum_j g{i,j}^2}
    $

  • 如果该范数小于最大允许值 ,则保持 不变;

  • 否则,将 按比例缩小,缩放因子为:

    $
    \frac{M}{| g |_2 + \epsilon}
    $

其中

$
\epsilo
$

是一个很小的常数(例如 $$10^{-6$$),用于​数值稳定性​。这样裁剪之后的梯度,其范数将略小于 $$$$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
"""
Args:
parameters (Iterable[torch.nn.Parameter]): collection of trainable parameters.
max_l2_norm (float): a positive value containing the maximum l2-norm.
"""
def gradient_clipping(parameters, max_l2_norm, eps=1e-6):
total_norm = 0.0
for p in parameters:
if p.grad is not None:
total_norm += p.grad.data.norm(2).item() ** 2
total_norm = total_norm ** 0.5

# 计算缩放系数
clip_coef = max_l2_norm / (total_norm + eps)

# 如果 norm 超过上限,则缩放梯度
if clip_coef < 1.0:
for p in parameters:
if p.grad is not None:
p.grad.data.mul_(clip_coef) # 就地修改梯度

其中:

步骤 数学意义 PyTorch 实现
计算每个梯度的 L2 范数
$$\ g_i \ _2 = \sqrt{\sum_j g_{i,j}^2}$$
累加平方和 $$\sum_i \ g_i \
求整体 L2 范数 $$\ g\

假设我们的模型有 2 个参数:

  • 参数

    $
    p_
    $

    的梯度为 tensor([3.0, 4.0])

  • 参数

    $
    p_
    $

    的梯度为 tensor([1.0, 2.0, 2.0])

那我们手动算一下:

$
| g_1 |_2 = \sqrt{3^2 + 4^2} =
$

$
| g_2 |_2 = \sqrt{1^2 + 2^2 + 2^2} =
$

所以整体梯度范数为:

$
|g |_2 = \sqrt{5^2 + 3^2} = \sqrt{34} ≈ 5.830
$

代码中计算过程对应:

1
2
3
4
total_norm = 0
total_norm += 5**2 # 来自第一个参数
total_norm += 3**2 # 来自第二个参数
total_norm = sqrt(total_norm) # sqrt(34)

Chapter #5 Training Loop

Data Loader

被分词后的数据(例如你在 tokenizer_experiments 中准备的那个)是一个​单一的 token 序列

$
x = (x_1, x_2, …, x_n)
$

即使原始数据可能由多个独立文档组成(例如,不同的网页、页面或源代码文件),一种常见的做法是:​把所有文档的 token 拼接成一个大的序列​,并在它们之间加入一个分隔符(比如 <|endoftext|> token)。接着,一个 data loader(数据加载器) 会把这个大序列转化为一个个 ​**批次(batches)**​。

每个 batch 包含 ​B 个序列​,每个序列的长度为 ​m​,并且这些序列都有各自的下一个 token 序列(同样长度为 m)作为预测目标。

例如,当

$
B = 1, \quad m =
$

时,$$([x_2, x_3, x_4], [x_3, x_4, x_5])$$就是一个可能的训练样本。以这种方式加载数据可以简化训练,原因如下:

  1. 对于任何

    $
    1 \le i < n - m
    $

    ,都能生成一个合法的训练序列,因此采样非常简单。

  2. 所有训练序列的长度都一样,所以不需要对序列进行填充(padding),这提高了硬件利用率(同时还能增加 batch size B)。

  3. 最后,我们也不需要一次性把整个数据集都加载进内存来采样训练数据,这让我们可以轻松处理那些无法完全装入内存的大型数据集。

概念解释与思路讲解:

  1. 「tokenized data」是什么?

经过 分词器(tokenizer) 处理后的文本,不再是文字,而是整数序列。 例如文本:

“The cat sat on the mat.” 可能变成: [1012, 234, 345, 87, 2001, 9]

这就是 token 序列

$
x = (x_1, …, x_n
$


2. 「拼接所有文档」的原因

如果我们有很多独立的文件(例如网页、书页、代码文件),每个都太短。

若每次训练都从头单独加载,就会打乱 GPU 的高效训练流程。

所以一种通用做法是:

直接把所有 token 串在一起形成一个长序列 再在文档间插入特殊符号 <|endoftext|> 来表示“文档结束”。

这样模型看到的就是一条连续的长序列,训练就更高效了。

  1. 「data loader」的工作

数据加载器会把这条大序列切成许多小的片段来做训练样本。 每个样本长度为 ​m​(上下文长度)。

例如:

1
2
大序列: [x1, x2, x3, x4, x5]
m = 3

那么可能的训练输入输出对就是:

这样模型就学会:

给定前 m 个 token,预测下一个 token。

  1. 「不需要 padding」

因为每个样本长度都固定为 m,

所以不用在 batch 中对齐不同长度的句子(不像 BERT 那样)。

这节省内存、提升 GPU 并行效率。

  1. 「不需要一次加载整个数据集」

大部分实现只需要一个指针在大 token 序列上滑动即可。

可以流式(streaming)地生成训练样本,

不需要把所有数据一次性载入内存。

这就是大模型(如 GPT)能训练超大语料的关键技巧之一。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
import numpy as np

def get_batch(dataset: torch.Tensor, batch_size: int, context_length: int, device: str = None) -> tuple[torch.Tensor, torch.Tensor]:
"""
Given a dataset (a 1D numpy array of integers) and a desired batch size and
context length, sample language modeling input sequences and their corresponding
labels from the dataset.

Args:
dataset (np.array): 1D numpy array of integer token IDs in the dataset.
batch_size (int): Desired batch size to sample.
context_length (int): Desired context length of each sampled example.
device (str): PyTorch device string (e.g., 'cpu' or 'cuda:0') indicating the device
to place the sampled input sequences and labels on.

Returns:
Tuple of torch.LongTensors of shape (batch_size, context_length). The first tuple item
is the sampled input sequences, and the second tuple item is the corresponding
language modeling labels.
"""
dataset_length = dataset.shape[0]
assert dataset_length >= context_length, "Dataset length must be greater than or equal to context length"
start_idx = np.random.randint(0, dataset_length - context_length, size=batch_size)
inputs = np.stack([dataset[s: s + context_length] for s in start_idx], dtype=np.int64)
targets = np.stack([dataset[s + 1: s + context_length + 1] for s in start_idx], dtype=np.int64)
return torch.from_numpy(inputs).to(device), torch.from_numpy(targets).to(device)

Using Memory-Map

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import numpy as np
import torch

def get_batch_from_memmap(path: str, batch_size: int, context_length: int, device: str = None):
# path: .npy file path
data = np.load(path, mmap_mode='r') # memory-mapped, does not load all into RAM
dataset_length = data.shape[0]
assert dataset_length >= context_length, "dataset length must >= context_length"

max_start = dataset_length - context_length + 1 # inclusive last start
# 当 dataset_length == context_length 时,max_start == 1, randint(0,1) -> 0 合法
start_idx = np.random.randint(0, max_start, size=batch_size)

# 预先分配 numpy arrays(避免在 loop 中多次小分配)
inputs = np.empty((batch_size, context_length), dtype=np.int64)
targets = np.empty((batch_size, context_length), dtype=np.int64)

for i, s in enumerate(start_idx):
inputs[i] = data[s: s + context_length]
targets[i] = data[s + 1: s + context_length + 1]

inputs_t = torch.from_numpy(inputs).long()
targets_t = torch.from_numpy(targets).long()

if device is not None:
inputs_t = inputs_t.to(device)
targets_t = targets_t.to(device)

return inputs_t, targets_t

Checkpointing

除了加载数据之外,我们在训练过程中也需要​**保存模型(save models)**​。

在运行训练任务时,我们通常希望能够在训练因某种原因中断后(例如任务时间到期、机器故障等)​恢复训练​。

即使训练过程一切顺利,我们有时也希望能够保存训练中间阶段的模型(例如,为了在事后研究训练动态、在不同阶段抽样模型等)。

一个 checkpoint(检查点) 应该包含我们恢复训练所需的所有状态。

我们当然希望至少能恢复模型的权重(weights)。

如果使用的是​**有状态优化器(stateful optimizer)**​(例如 AdamW),我们还需要保存优化器的状态(例如在 AdamW 中的动量估计)。

此外,为了恢复学习率调度(learning rate schedule),我们还需要知道上一次训练停止时的迭代步数(iteration number)。

PyTorch 提供了很方便的机制来保存这些内容:

  • 每个 nn.Module 都有一个 state_dict() 方法,该方法返回一个字典,里面包含所有可学习参数的权重
  • 我们可以稍后使用对应的 load_state_dict() 方法来恢复这些权重。
  • 同样的逻辑也适用于任何 nn.optim.Optimizer 对象。

最后,torch.save(obj, dest) 可以将一个对象(例如一个包含张量、整数等 Python 对象的字典)保存到一个文件路径或类文件对象中。

之后,我们可以使用 torch.load(src) 将其重新加载回内存中。

Training loop

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import torch
import numpy as np
import wandb
from tqdm import tqdm
from loguru import logger

from cs336_basics.module import *
from cs336_basics.utils import *
from cs336_basics.tokenizer import *
# from cs336_basics.data import ​

def train():
logger.add("/home/fdse/zjh/cs336/assignment1/assignment1-basics/cs336_basics/log/train_tinystories_v0.log", rotation="1 day", retention="7 days", level="INFO")
# rotation 自动切分,retention 过期删除,compression 自动压缩

# 设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 1. 初始化模型配置
model_config = {
"vocab_size": 10000, # 词汇表大小
"context_length": 256, # 上下文长度
"num_layers": 4, # Transformer Block数
"d_model": 512, # 嵌入空间维度
"num_heads": 16, # 注意力头数
"d_ff": 1344, # 前馈网络维度
"rope_theta": 10000, # RoPE参数
}

# 2. 初始化优化器配置
optimizer_config = {
"lr": 3e-4, # 学习率
"weight_decay": 1e-2, # 权重衰减
"betas": (0.9, 0.999), # AdamW的beta参数
"eps": 1e-8, # AdamW的epsilon参数
"max_norm": 1.0, # 梯度裁剪的最大范数
}

# 3. 初始化训练配置
train_config = {
"batch_size": 16, # 批次大小
"total_epochs": 0.5, # 训练轮数
"checkpoint_freq": 2000, # 每隔多少步保存一次检查点
"log_freq": 10, # 每隔多少步记录一次日志
"val_freq": 400, # 每隔多少步在验证集上评估
"val_batch_size": 16, # 验证时的批次大小
"val_batches": 20, # 验证时使用的批次数量
}

# 数据路径配置
data_paths = {
"train_data": "/home/fdse/zjh/cs336/assignment1/assignment1-basics/cs336_basics/data/tinystories/train_tokens.npy",
"val_data": "/home/fdse/zjh/cs336/assignment1/assignment1-basics/cs336_basics/data/tinystories/val_tokens.npy",
"checkpoint_dir": "/home/fdse/zjh/cs336/assignment1/assignment1-basics/cs336_basics/data/tinystories/checkpoints"
}

# 4. 初始化wandb
wandb.init(
project="cs336-assignment-1",
name="train_v1",
config={
"model": model_config,
"optimizer": optimizer_config,
"training": train_config,
}
)

# 5. 初始化模型
logger.info("Init Model")
model = TransformerLM(
vocab_size=model_config["vocab_size"],
context_length=model_config["context_length"],
num_layers=model_config["num_layers"],
d_model=model_config["d_model"],
num_heads=model_config["num_heads"],
d_ff=model_config["d_ff"],
rope_theta=model_config["rope_theta"],
)
model = model.to(device) # 将模型移动到设备上
logger.info(f"Model moved to device: {device}")

# 6. 初始化优化器
logger.info("Init Optimizer")

optimizer = AdamW(
# params=optimizer_config["params"],
model.parameters(),
lr=optimizer_config["lr"],
weight_decay=optimizer_config["weight_decay"],
betas=optimizer_config["betas"],
eps=optimizer_config["eps"]
)

# 7. 初始化训练数据集
logger.info("Init Training Dataset")

# 方法1: 如果数据已经被tokenized并保存为 .npy 文件
# 使用内存映射方式加载,适合大型数据集
# training_dataset = np.load("./data/train_tokens.npy", mmap_mode='r')

# 方法2: 使用 memmap (更推荐,内存效率更高)
# training_dataset = np.memmap("./data/train_tokens.npy", dtype=np.int64, mode='r')

training_dataset = np.load(data_paths['train_data'], mmap_mode='r')
logger.info(f"Training dataset loaded, size: {len(training_dataset)}")

# 8. 初始化验证数据集
logger.info("Init Validation Dataset")
validation_dataset = np.load(data_paths['val_data'], mmap_mode='r')
logger.info(f"Validation dataset loaded, size: {len(validation_dataset)}")

# 9. 训练循环
logger.info("Init Training Loop")

total_tokens = len(training_dataset)
total_steps = int(train_config["total_epochs"] * total_tokens) // (train_config["batch_size"] * model_config["context_length"])

logger.info(f"Total tokens: {total_tokens}, Total steps: {total_steps}")
logger.info(f"Total epochs: {train_config['total_epochs']}, Batch size: {train_config['batch_size']}, Context length: {model_config['context_length']}")

logger.info("Start Training")
# for epoch in range(train_config["total_epochs"]):
# pass
for step in tqdm(range(total_steps), desc="Training", unit="step"):
# 清空梯度
optimizer.zero_grad()

# 使用余弦退火更新学习率
lr_new = learning_rate_schedule(
it=step,
max_learning_rate=optimizer_config["lr"],
min_learning_rate=optimizer_config["lr"] * 0.01,
warmup_iters=int(0.05 * total_steps),
cosine_cycle_iters=total_steps,
)
for param_group in optimizer.param_groups:
param_group['lr'] = lr_new

# 获取batch数据
inputs, targets = get_batch(
training_dataset,
batch_size=train_config["batch_size"],
context_length=model_config["context_length"],
device=device
)

# 前向传播
logits = model(inputs) # shape: (batch_size, context_length, vocab_size)

# 计算损失
# 需要将 logits 和 targets 重塑为 cross_entropy 期望的形状
# logits: (batch_size, context_length, vocab_size) -> (batch_size * context_length, vocab_size)
# targets: (batch_size, context_length) -> (batch_size * context_length,)
batch_size, context_length, vocab_size = logits.shape
logits_reshaped = logits.view(batch_size * context_length, vocab_size)
targets_reshaped = targets.view(batch_size * context_length)
loss = cross_entropy(logits_reshaped, targets_reshaped)

# 反向传播
loss.backward()

# 梯度裁剪
gradient_clipping(model.parameters(), max_l2_norm=optimizer_config["max_norm"])

# 优化器更新
optimizer.step()

# 日志记录
if step % train_config["log_freq"] == 0:
logger.info(f"Step {step}, Loss: {loss.item()}")

# 使用wandb记录损失和学习率
wandb.log({"train_loss": loss.item(), "lr": lr_new, "step": step})

# 保存检查点
if step % train_config["checkpoint_freq"] == 0:
checkpoint_path = f"{data_paths['checkpoint_dir']}/checkpoint_tinystories_v0_{step}.pt"
save_checkpoint(model, optimizer, step, checkpoint_path)
logger.info(f"Checkpoint saved: {checkpoint_path}")


# 10. 保存模型
logger.info("Save Model")
final_model_path = f"{data_paths['checkpoint_dir']}/final_model_tinystories_v0.pt"
save_checkpoint(model, optimizer, total_steps, final_model_path)
logger.info(f"Final model saved: {final_model_path}")

# 关闭 wandb
wandb.finish()


if __name__ == "__main__":
train()

Chapter #6

Generating text

既然我们已经能够训练模型,我们需要补齐的最后一块拼图就是从模型中生成文本的能力。回忆一下,语言模型接收一个长度为 (sequence_length) 的整数序列(可能是分批处理的),并生成一个大小为 (sequence_length

$
\time
$

vocab_size) 的矩阵,其中序列的每个元素都是一个概率分布,用于预测该位置之后的下一个词。我们现在将编写几个函数,将其转化为生成新序列的采样方案。​Softmax​:按照标准惯例,语言模型的输出是最终线性层的输出(即“logits”),因此我们需要通过 softmax 操作将其转化为归一化的概率,我们在前面的公式 10 中已经见过这个操作。

​**解码 (Decoding)**​:为了从我们的模型中生成文本(解码),我们将为模型提供一个前缀标记序列(即“提示词/prompt”),并要求它在词汇表上生成一个概率分布,以此预测序列中的下一个词。然后,我们将从这个词汇表的分布中进行采样,以确定下一个输出标记。

具体来说,解码过程的一步应该接收一个序列

$
x_{1\dots t
$

,并通过以下方程返回一个标记 $$x_{t+1$$:

$
P(x_{t+1} = i \mid x_{1\dots t}) = \frac{\exp(v_i)}{\sum_j \exp(v_j)}
$

$
v = \text{TransformerLM}(x_{1\dots t}) \space \space t \in \mathbb{R}^{\text{vocab_size}}
$

其中 TransformerLM 是我们的模型,它接收长度为 sequence_length 的序列作为输入,并生成大小为 (sequence_length

$
\time
$

vocab_size) 的矩阵;我们取该矩阵的最后一个元素,因为我们要寻找的是第 $$$$个位置的下一个词的预测。这为我们提供了一个基本的解码器,通过重复从这些单步条件概率中采样(将我们之前生成的输出标记追加到下一个解码时间步的输入中),直到生成序列结束标记 <|endoftext|>(或达到用户指定的最大生成标记数)。

​**解码技巧 (Decoder tricks)**​:我们将尝试使用小模型,而小模型有时会生成质量很低的文本。两个简单的解码技巧可以帮助解决这些问题。首先,在 温度缩放 (temperature scaling) 中,我们用温度参数

$
\ta
$

修改 softmax,新的 softmax 为:

$
\text{softmax}(v, \tau)i = \frac{\exp(v_i/\tau)}{\sum{j=1}^{|\text{vocab_size}|} \exp(v_j/\tau)} \tag{24}
$

请注意,当设定

$
\tau \rightarrow
$

时,会导致 $$$$ 中最大的元素占据主导地位,softmax 的输出变成集中在这个最大元素上的独热向量 (one-hot vector)。其次,另一个技巧是 *核采样 (nucleus sampling)​ 或 ​*​top-p 采样**,我们通过截断低概率词来修改采样分布。设 是我们从(经过温度缩放的)大小为 (vocab_size) 的 softmax 中得到的概率分布。带有超参数 的核采样根据以下方程生成下一个标记:

$
P(x_{t+1} = i|q) = \begin{cases} \frac{q_i}{\sum_{j \in V(p)} q_j} & \text{如果 } i \in V(p) \ 0 & \text{否则} \end{cases}
$

其中

$
V(p
$

是使得 $$\sum_{j \in V(p)} q_j \ge $$ 的最小索引集合。你可以很容易地计算这个量:首先按大小对概率分布 进行排序,然后选择最大的词汇元素,直到达到目标水平

$
\alph
$

(此处原文应指代 $$$$)。

举例说明(假设 p=0.9):

  • 累积概率 0.4 ≤ 0.9 → 保留 Token A
  • 累积概率 0.7 ≤ 0.9 → 保留 Token B
  • 累积概率 0.85 ≤ 0.9 → 保留 Token C
  • 累积概率 0.95 > 0.9 → 移除 Token D
  • 累积概率 1.0 > 0.9 → 移除 Token E

最终只从 A、B、C 这三个高概率词中采样,既保证了生成质量,又保留了一定的随机性和多样性。这就是为什么 Top-p 采样比简单的贪心搜索(总选最大概率)或完全随机采样更好——它在质量和多样性之间取得了平衡!

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import torch
import torch.nn.functional as F
from typing import Optional
from cs336_basics.module.transformer_lm import TransformerLM
from cs336_basics.tokenizer.tokenizer import Tokenizer

def top_p_sampling(logits: torch.Tensor, p: float = 0.9) -> torch.Tensor:
"""
Top-p (nucleus) sampling implementation.

Args:
logits: Tensor of shape (vocab_size,) containing logits
p: Cumulative probability threshold

Returns:
sampled_token: Sampled token index
"""
# 对 logits 进行 softmax 得到概率分布
probs = F.softmax(logits, dim=-1)

# 按概率降序排序
sorted_probs, sorted_indices = torch.sort(probs, descending=True)

# 计算累积概率
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

# 找到累积概率超过 p 的位置
# 保留累积概率 <= p 的 tokens
sorted_indices_to_remove = cumulative_probs > p

# 保证至少保留一个 token (即使第一个 token 的概率就超过了 p)
sorted_indices_to_remove[0] = False

# 将要移除的 token 的概率设为 0
sorted_probs[sorted_indices_to_remove] = 0.0

# 重新归一化概率
sorted_probs = sorted_probs / sorted_probs.sum()

# 从筛选后的分布中采样
sampled_sorted_index = torch.multinomial(sorted_probs, num_samples=1)
sampled_token = sorted_indices[sampled_sorted_index]

return sampled_token

def generate(
model: TransformerLM,
prompt_tokens: torch.Tensor,
max_new_tokens: int = 50,
temperature: float = 1.0,
top_p: Optional[float] = None,
eos_token_id: Optional[int] = None,
device: str = "cuda"
) -> torch.Tensor:
"""
从语言模型生成文本补全。

Args:
model: 训练好的 TransformerLM 模型
prompt_tokens: 输入的 prompt tokens,形状为 (seq_len,) 或 (1, seq_len)
max_new_tokens: 最大生成的 token 数量
temperature: 温度参数,用于控制采样的随机性
- temperature > 1: 更随机
- temperature < 1: 更确定
- temperature = 1: 标准采样
top_p: Top-p (nucleus) 采样的阈值,如果为 None 则不使用
eos_token_id: 结束 token 的 ID (如 <|endoftext|>),遇到则停止生成
device: 设备 ('cuda' 或 'cpu')

Returns:
generated_tokens: 生成的完整 token 序列 (包括 prompt),形状为 (total_seq_len,)
"""
current_model = model
current_model.eval() # 设置为评估模式

# 确保 prompt_tokens 是 2D 张量 (batch_size=1, seq_len)
current_prompt = prompt_tokens
if current_prompt.dim() == 1:
current_prompt = current_prompt.unsqueeze(0)

current_prompt = current_prompt.to(device)
generated = current_prompt.clone()

with torch.no_grad():
for _ in range(max_new_tokens):
# 获取当前序列的最后 context_length 个 tokens
# 避免超过模型的最大上下文长度
if generated.shape[1] > current_model.context_length:
input_tokens = generated[:, -current_model.context_length:]
else:
input_tokens = generated

# 前向传播获取 logits
logits = current_model(input_tokens) # shape: (1, seq_len, vocab_size)

# 只需要最后一个位置的 logits
next_token_logits = logits[:, -1, :] # shape: (1, vocab_size)
next_token_logits = next_token_logits.squeeze(0) # shape: (vocab_size,)

# 应用温度缩放
if temperature != 1.0:
next_token_logits = next_token_logits / temperature

# 采样下一个 token
if top_p is not None and top_p < 1.0:
# 使用 top-p 采样
next_token = top_p_sampling(next_token_logits, p=top_p)
else:
# 标准的多项式采样
probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)

# 将新生成的 token 添加到序列中
# next_token 形状是 (1,),需要变成 (1, 1) 才能与 generated (1, seq_len) 拼接
generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1)

# 检查是否遇到结束 token
if eos_token_id is not None and next_token.item() == eos_token_id:
break

return generated.squeeze(0) # 返回 1D tensor

def load_model_from_checkpoint(
checkpoint_path: str,
model_config: dict,
device: str = "cuda"
) -> TransformerLM:
"""
从 checkpoint 加载模型。

Args:
checkpoint_path: checkpoint 文件路径
model_config: 模型配置字典
device: 设备

Returns:
loaded_model: 加载好的模型
"""
# 创建模型实例
loaded_model = TransformerLM(
vocab_size=model_config["vocab_size"],
context_length=model_config["context_length"],
num_layers=model_config["num_layers"],
d_model=model_config["d_model"],
num_heads=model_config["num_heads"],
d_ff=model_config["d_ff"],
rope_theta=model_config["rope_theta"],
)

# 加载 checkpoint
checkpoint = torch.load(checkpoint_path, map_location=device)
loaded_model.load_state_dict(checkpoint['model_state_dict'])
loaded_model = loaded_model.to(device)
loaded_model.eval()

return loaded_model

if __name__ == "__main__":
# 使用示例
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 模型配置(需要与训练时保持一致)
model_config = {
"vocab_size": 10000,
"context_length": 256,
"num_layers": 4,
"d_model": 512,
"num_heads": 16,
"d_ff": 1344,
"rope_theta": 10000,
}

# 加载模型
checkpoint_path = "/home/fdse/zjh/cs336/assignment1/assignment1-basics/cs336_basics/data/tinystories/checkpoints/final_model_tinystories_v0.pt"
tokenizer_dir = "/home/fdse/zjh/cs336/assignment1/assignment1-basics/cs336_basics/data/tinystories/tokenizer"

print("正在加载模型...")
model = load_model_from_checkpoint(checkpoint_path, model_config, device)
print("✓ 模型加载成功")

print("正在加载 tokenizer...")
tokenizer = Tokenizer.from_directory(tokenizer_dir, special_tokens=["<|endoftext|>"])
print("✓ Tokenizer 加载成功")

# 获取 EOS token ID
eos_token_bytes = "<|endoftext|>".encode("utf-8")
eos_token_id = tokenizer.inverse_vocab.get(eos_token_bytes)
print(f"EOS token ID: {eos_token_id}")

# 准备 prompt
prompt_text = "Once upon a time"
print(f"\n输入 Prompt: \"{prompt_text}\"")

prompt_tokens = torch.tensor(tokenizer.encode(prompt_text), dtype=torch.long)
print(f"Prompt tokens: {prompt_tokens.tolist()}")

# 生成文本
print("\n开始生成...")
generated_tokens = generate(
model=model,
prompt_tokens=prompt_tokens,
max_new_tokens=100,
temperature=0.8,
top_p=0.9,
eos_token_id=eos_token_id,
device=device
)

print(f"\n生成的 tokens 数量: {len(generated_tokens)}, 生成的 tokens: {generated_tokens.tolist()}")

# 解码回文本
generated_text = tokenizer.decode(generated_tokens.tolist())
print(f"\n生成的文本:\n{generated_text}")

Summary

暂时无法在Lark文档外展示此内容

TransformerLM 的结构。首先让我查看一下相关的依赖模块,以便完整理解每个组件的输入输出。

好的!现在我来为你详细梳理 TransformerLM 的整体结构,包括每个模块的输入输出形状和参数。

📊 TransformerLM 整体架构

整体流程

1
输入 token_ids → Embedding → [TransformerBlock × num_layers] → RMSNorm → Linear → 输出 logits

🔍 各模块详细分析

1️⃣ Embedding 层 (Token Embedding)

输入:

  • token_ids: (batch_size, seq_len) - 整数型 token ID

参数:

  • weight: (vocab_size, d_model) - 可学习的嵌入矩阵

输出:

  • (batch_size, seq_len, d_model)

作用: 将离散的 token ID 映射到连续的向量空间

2️⃣ TransformerBlock × num_layers

每个 TransformerBlock 包含以下组件(按执行顺序):

2.1 第一个 RMSNorm + MultiHeadSelfAttention + 残差连接

1
y = self.attn(self.rms_norm1(x)) + x
  • RMSNorm1
    • 输入: (batch_size, seq_len, d_model)
    • 参数: weight: (d_model,)
    • 输出: (batch_size, seq_len, d_model)
  • MultiHeadSelfAttention
    • 输入: (batch_size, seq_len, d_model)
    • 参数:
      • W_Q: (d_model, d_model)
      • W_K: (d_model, d_model)
      • W_V: (d_model, d_model)
      • W_O: (d_model, d_model)
    • 内部过程:
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      Q = W_Q @ x  →  (batch_size, seq_len, d_model)
      K = W_K @ x → (batch_size, seq_len, d_model)
      V = W_V @ x → (batch_size, seq_len, d_model)

      # 重塑为多头
      Q → (batch_size, num_heads, seq_len, head_dim) # head_dim = d_model // num_heads
      K → (batch_size, num_heads, seq_len, head_dim)
      V → (batch_size, num_heads, seq_len, head_dim)

      # 应用 RoPE (旋转位置编码)
      Q, K → RoPE 编码后仍为原形状

      # Scaled Dot-Product Attention (带因果掩码)
      Attention(Q, K, V) → (batch_size, num_heads, seq_len, head_dim)

      # 合并多头
      output → (batch_size, seq_len, d_model)

      # 输出投影
      W_O @ output → (batch_size, seq_len, d_model)
    • 输出: (batch_size, seq_len, d_model)
  • 残差连接​: y = attention_output + x

2.2 第二个 RMSNorm + SwiGLU (FFN) + 残差连接

1
z = self.ffn(self.rms_norm2(y)) + y
  • RMSNorm2
    • 输入: (batch_size, seq_len, d_model)
    • 参数: weight: (d_model,)
    • 输出: (batch_size, seq_len, d_model)
  • SwiGLU (Position-wise Feed Forward)
    • 输入: (batch_size, seq_len, d_model)
    • 参数:
      • W_1: (d_ff, d_model)
      • W_2: (d_model, d_ff)
      • W_3: (d_ff, d_model)
    • 内部计算:
      1
      2
      3
      4
      5
      6
      FFN(x) = W_2 @ (SiLU(W_1 @ x) ⊙ W_3 @ x)

      W_1 @ x → (batch_size, seq_len, d_ff)
      W_3 @ x → (batch_size, seq_len, d_ff)
      SiLU(W_1 @ x) ⊙ W_3 @ x → (batch_size, seq_len, d_ff) # ⊙ 是逐元素乘法
      W_2 @ (...) → (batch_size, seq_len, d_model)
    • 输出: (batch_size, seq_len, d_model)
  • 残差连接​: z = ffn_output + y

TransformerBlock 总输出: (batch_size, seq_len, d_model)

3️⃣ 最终 RMSNorm

输入:

  • (batch_size, seq_len, d_model) - 来自最后一个 TransformerBlock 的输出

参数:

  • weight: (d_model,)

输出:

  • (batch_size, seq_len, d_model)

4️⃣ Output Linear 层 (Language Model Head)

输入:

  • (batch_size, seq_len, d_model)

参数:

  • weight: (vocab_size, d_model)

输出:

  • (batch_size, seq_len, vocab_size) - 每个位置的词汇表概率 logits

📝 完整示例

假设参数为:

  • vocab_size = 50000
  • context_length = 512
  • num_layers = 12
  • d_model = 768
  • num_heads = 12
  • d_ff = 2048
  • batch_size = 8
  • seq_len = 128

数据流动:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
输入: (8, 128) [整数 token IDs]

Embedding: (8, 128) → (8, 128, 768)

TransformerBlock #1: (8, 128, 768) → (8, 128, 768)
TransformerBlock #2: (8, 128, 768) → (8, 128, 768)
...
TransformerBlock #12: (8, 128, 768) → (8, 128, 768)

RMSNorm: (8, 128, 768) → (8, 128, 768)

Output Linear: (8, 128, 768) → (8, 128, 50000)

输出: (8, 128, 50000) [每个位置的词汇表 logits]

关键设计特点

  1. RMSNorm​: 用于层归一化,比 LayerNorm 更简单高效
  2. RoPE​: 旋转位置编码,在 attention 的 Q 和 K 上应用
  3. SwiGLU​: 改进的前馈网络,比传统 ReLU FFN 性能更好
  4. Causal Mask​: 自回归的因果掩码,确保只能看到之前的 token
  5. 残差连接​: 在 attention 和 FFN 之后都有残差连接,便于梯度传播
Author

Alan Zeng

Posted on

2025-12-19

Updated on

2026-02-22

Licensed under

Comments