树搜索与搜索蒸馏:从 AlphaZero 到 LLM 推理
树搜索与搜索蒸馏:从 AlphaZero 到 LLM 推理
上一章讲了 RL 后训练(PPO、GRPO、OAPL),核心思路是采样 rollout → 打分 → 更新策略。这一章讲另一条路线:用树搜索在推理时找到更好的答案,再把搜索结果蒸馏回模型权重,让模型越来越强。
这条路线的鼻祖是 AlphaZero——DeepMind 用它在围棋上击败了人类世界冠军。现在人们正在尝试把同样的思路迁移到语言模型推理中。
一、什么是树搜索
1.1 从穷举到智能搜索
假设你在下棋。当前轮到你走,你有 20 种合法走法。对于每种走法,对手又有 20 种回应,然后你又有 20 种…
当前局面
├── 走法 A
│ ├── 对手回应 A1
│ │ ├── 你的走法 A1a → ...
│ │ └── 你的走法 A1b → ...
│ └── 对手回应 A2
│ └── ...
├── 走法 B
│ ├── ...
│ └── ...
└── 走法 C
└── ...
这就是一棵博弈树(game tree)。穷举搜索所有分支可以找到最优走法,但问题是分支太多了:
国际象棋:平均每步 35 种走法,一盘棋约 80 步
→ 搜索空间 ≈ 35^80 ≈ 10^123
围棋:平均每步 250 种走法,一盘棋约 150 步
→ 搜索空间 ≈ 250^150 ≈ 10^360
宇宙中原子总数才约 。穷举是不可能的,我们需要智能搜索——把有限的计算资源集中在最有潜力的方向上。
1.2 树搜索的基本思想
智能树搜索的核心是两个能力:
- 评估:能估计一个中间状态的好坏(不需要搜到终局)
- 选择:能决定优先搜索哪些分支(平衡探索与利用)
有了这两个能力,搜索过程就变成了:
重复很多次:
1. 从根节点出发,选择一条路径走到叶节点
2. 评估这个叶节点的价值
3. 把评估结果回传到路径上的所有节点
搜索次数越多,对每个分支的价值估计就越准,最终选出的走法就越好。
二、蒙特卡洛树搜索(MCTS)
MCTS 是目前最成功的树搜索算法,它用随机模拟来评估节点价值。
2.1 四个步骤
每一次 MCTS 迭代包含四步:
┌─────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐
│ 选择 │ → │ 扩展 │ → │ 模拟 │ → │ 回传 │
│ Select │ │ Expand │ │ Simulate │ │ Backprop │
└─────────┘ └──────────┘ └──────────┘ └──────────┘
Step 1: 选择(Select)
从根节点开始,沿着树往下走。在每个节点,用一个公式选择最有价值的子节点(稍后详细讲这个公式)。一直走到一个还没完全展开的节点。
Step 2: 扩展(Expand)
在选中的节点上,创建一个新的子节点(尝试一个还没试过的动作)。
Step 3: 模拟(Simulate)
从新节点开始,用某种策略(比如随机走)一路走到游戏结束,得到一个结果(赢/输/平)。
Step 4: 回传(Backpropagate)
把模拟结果沿路径回传到根节点,更新路径上每个节点的统计信息:
- N(s):这个节点被访问了多少次
- W(s):这个节点累计获得的价值总和
- Q(s) = W(s) / N(s):平均价值
2.2 数值走一遍
假设一个简单的游戏树,已经搜索了一段时间:
根节点 (N=100)
├── A (N=60, W=36, Q=0.60) ← 访问 60 次,赢了 36 次
├── B (N=30, W=21, Q=0.70) ← 访问 30 次,赢了 21 次
└── C (N=10, W=2, Q=0.20) ← 访问 10 次,赢了 2 次
直觉上:
- A 被访问最多,胜率 60%
- B 被访问较少,但胜率最高 70%
- C 被访问最少,胜率最低 20%
下一次搜索应该选谁?如果只看 Q 值,永远选 B;如果只看访问次数少的,永远选 C。我们需要一个平衡探索和利用的公式。
三、UCT 与 pUCT:搜索的核心公式
3.1 UCT(Upper Confidence Bound for Trees)
UCT 公式决定了在每个节点选哪个子节点:
- :选择动作 a 后的平均价值(利用项)
- :探索常数,控制探索强度
- :父节点的访问次数
- :子节点的访问次数
第二项是探索奖励:访问次数少的子节点获得更高的探索奖励。
用上面的例子(取 c = 1.414):
UCT(A) = 0.60 + 1.414 * √(ln(100)/60) = 0.60 + 1.414 * √(4.605/60) = 0.60 + 0.392 = 0.992
UCT(B) = 0.70 + 1.414 * √(ln(100)/30) = 0.70 + 1.414 * √(4.605/30) = 0.70 + 0.554 = 1.254
UCT(C) = 0.20 + 1.414 * √(ln(100)/10) = 0.20 + 1.414 * √(4.605/10) = 0.20 + 0.959 = 1.159
结果:选 B(UCT = 1.254 最大)。B 兼具高胜率和相对较少的访问次数,值得继续探索。
3.2 pUCT(Predictor + UCT)
pUCT 是 AlphaZero 使用的变体,关键区别是引入了策略先验 P(s, a):
- :策略网络对动作 a 的先验概率
对比 UCT 和 pUCT 的探索项:
UCT 探索项: c * √(ln(N_parent) / N_child)
→ 只看访问次数,不管动作本身好不好
pUCT 探索项: c * P(s,a) * √(N_parent) / (1 + N_child)
→ 用策略先验 P(s,a) 加权,好动作获得更多探索机会
3.3 为什么 pUCT 对语言模型至关重要
在棋类中,合法走法通常几十到几百个,UCT 也能工作——反正每个动作都能被探索到。
但在语言模型中,即使在推理步级别,候选数也可以很多。更关键的是,模型的策略(logprobs)本身就包含了大量关于哪些推理步更合理的信息。
UCT 在语言模型中的问题:
候选推理步: "先算 3+5=8" (概率 0.4), "先算 1×2=2" (概率 0.3),
"把所有数加起来" (概率 0.01), "随便试一个" (概率 0.001)
UCT: 所有未访问的候选都有相同的探索优先级
→ 大量搜索资源浪费在 "随便试一个" 这类低质量候选上
pUCT 的解决:
P(s,a) = softmax(logprobs)
→ "先算 3+5=8" 的探索优先级远高于 "随便试一个"
→ 搜索资源集中在有潜力的方向上
DeepSeek-R1 团队报告 MCTS 效果不佳,一个可能原因就是他们用了 UCT 而非 pUCT,在高分支因子的语言空间中搜索效率太低。
四、AlphaZero 的搜索→蒸馏范式
4.1 AlphaZero 的训练循环
AlphaZero 的训练过程非常优雅:
初始化: 随机策略网络 π_θ + 随机价值网络 V_θ
重复:
1. 自我对弈:
- 用 MCTS + π_θ + V_θ 下很多盘棋
- 每步选择: MCTS 搜索 800 次迭代 → 按访问计数选择走法
- 记录每步的 (棋盘状态, MCTS 搜索策略, 最终胜负)
2. 训练(蒸馏):
- 策略目标: 让 π_θ(s) 逼近 MCTS 搜索策略(通常是 visit count 的归一化分布)
- 价值目标: 让 V_θ(s) 逼近实际胜负结果
- 损失 = 策略交叉熵 + 价值均方误差
3. 更新后重复
核心 insight:搜索策略比原始策略强,蒸馏让模型学会”直觉性地”做出搜索级别的决策。
4.2 为什么蒸馏有效
直觉上理解:
没搜索时的 π_θ: "根据经验,A 走法 40%,B 走法 35%,C 走法 25%"
→ 仅凭"直觉"(一次 forward pass)
搜索增强后: "搜索了 800 次后发现,A 走法其实只有 20%,
B 走法最好占 60%,C 走法 20%"
→ "深思熟虑"后的判断
蒸馏: 让 π_θ 学会直接输出 [20%, 60%, 20%]
→ 把"深思熟虑"的结论变成新的"直觉"
每轮迭代后,模型的”直觉”越来越接近”深思熟虑”的结果,搜索就能在这个更强的直觉基础上探索更深层的策略——良性循环。
4.3 数值理解蒸馏效果
训练前:
π_θ(A) = 0.4, MCTS(A) = 0.2 → 交叉熵损失推动 π_θ(A) 下降
π_θ(B) = 0.35, MCTS(B) = 0.6 → 交叉熵损失推动 π_θ(B) 上升
π_θ(C) = 0.25, MCTS(C) = 0.2 → 交叉熵损失推动 π_θ(C) 下降
训练后:
π_θ(A) ≈ 0.22, π_θ(B) ≈ 0.58, π_θ(C) ≈ 0.20
→ 不用搜索也能给出接近搜索结果的判断
五、从棋类到语言模型:关键挑战
5.1 动作空间的差异
这是最根本的差异。棋类的动作空间有限且语义清晰,语言模型的动作空间巨大且充满冗余。
国际象棋某局面: 可选走法 = [e4, d4, Nf3, c4, ...] → 约 30 个,每个语义独立
围棋某局面: 可选走法 = [A1, A2, ..., T19] → 最多 361 个
语言模型某推理步:
token 级别: 词表 ≈ 100,000 个 token
推理步级别: 可能的下一步推理 = 无限(但可以采样 K 个候选)
5.2 蒸馏的粒度不匹配
在 AlphaZero 中,搜索和策略的粒度一致——都是”走法”级别,可以直接计算 KL 散度来蒸馏。
在语言模型中:
- 搜索在推理步级别进行
- 模型策略在 token 级别定义
两者粒度不同,无法直接蒸馏。
AlphaZero 蒸馏:
搜索策略: [A: 60%, B: 30%, C: 10%] ← 走法级别
模型策略: [A: 40%, B: 35%, C: 25%] ← 走法级别
→ 直接计算 KL 散度,用交叉熵训练
语言模型蒸馏的困境:
搜索策略: [推理步X: 50%, 推理步Y: 30%, 推理步Z: 20%] ← 推理步级别
模型策略: [token1: 5%, token2: 3%, ...] ← token 级别
→ 粒度不匹配,不能直接比较
解决方案(Tree Search Distillation 的做法):不蒸馏搜索策略分布,而是选出最优轨迹,用 RL(PPO)训练模型在这些轨迹上获得高回报。
5.3 Value Function 的困难
AlphaZero 的 value function 预测胜负概率,输出范围固定在 [-1, 1],而且棋局的价值随着走子逐步确定,学习信号清晰。
语言模型推理中,中间推理步的”价值”很难定义:
问题: "用 3, 5, 7, 8 凑出 24"
推理步 1a: "先算 8-5=3" → 价值?此时不知道能不能最终凑出 24
推理步 1b: "先算 3×8=24" → 价值?看起来很有希望(已经得到 24 了)
推理步 1c: "先算 7+5=12" → 价值?还需要几步才知道
实际上 1b 看似好,但 24=3×8 用掉了 3 和 8,
剩下 5 和 7,还需要保证它们不影响结果。
这取决于具体的游戏规则。
解决方案:训练一个 value head(MLP + tanh),随 RL 训练同步更新。初始估计不准,但随着训练积累数据会逐渐改善。
六、推理步级别的 MCTS
6.1 为什么不在 token 级别搜索
直观理解:
token 级别搜索:
"先" → 搜索下一个 token
├── "算" (概率 0.3)
├── "计" (概率 0.25)
├── "把" (概率 0.15)
└── "用" (概率 0.1)
"先算" → 搜索下一个 token
├── "3" (概率 0.4)
├── "8" (概率 0.3)
└── ...
问题:
1. "先算" 和 "先计算" 导向几乎相同的推理,搜索资源被浪费
2. 搜索树极其庞大(每层 100K 分支)
3. 大多数 token 选择对推理路径没有实质影响
推理步级别搜索:
根节点 → 搜索下一个完整推理步
├── "先算 3×8=24,但还剩 5 和 7 需要处理"
├── "先算 8-5=3,得到两个 3 和一个 7"
├── "先算 7-3=4,然后考虑 5 和 8"
└── "先算 5+7=12,然后考虑 3 和 8"
优势:
1. 每个分支语义差异显著
2. 搜索树规模可控(每层 K=4 个候选)
3. 搜索资源全部用于探索不同的推理路径
6.2 搜索过程详解
以 Countdown 任务为例:用 3, 5, 7, 8 凑出目标 24。
根节点: "用 3, 5, 7, 8 凑出 24"
│
├── [Step 1a] "8 × 3 = 24。但我需要用所有四个数..."
│ │ → V(s) = 0.3 (value head 估计)
│ │
│ ├── [Step 2a] "24 + 5 - 7 = 22,不对"
│ │ → 终端节点,r = 1 - 2×|24-22|/24 = 0.833
│ │
│ └── [Step 2b] "24 × (7-5)/... 不行,已经用了所有数"
│ → 终端节点,r = -1.0 (格式错误)
│
├── [Step 1b] "8 - 5 = 3,现在有 3, 3, 7"
│ │ → V(s) = 0.4
│ │
│ ├── [Step 2c] "3 × 7 = 21,加 3 = 24!"
│ │ → 终端节点,r = 1.0 ✓
│ │
│ └── [Step 2d] "3 + 3 = 6,6 × 7 = 42,太大了"
│ → 终端节点,r = 1 - 2×|24-42|/24 = -0.5
│
├── [Step 1c] "7 - 3 = 4,现在有 4, 5, 8"
│ │ → V(s) = 0.35
│ │
│ └── [Step 2e] "5 - 4 = 1, 不太有用..."
│ → 继续搜索...
│
└── [Step 1d] "5 + 7 = 12,现在有 3, 8, 12"
│ → V(s) = 0.45
│
└── [Step 2f] "12 × (8/3) = 32,不对"
→ 终端节点,r = 1 - 2×|24-32|/24 = 0.333
经过多次 MCTS 迭代后,Step 1b 的分支因为发现了正确答案(3×7+3=24),会获得大量访问和高价值。最终选择 Step 1b 开头的轨迹作为训练样本。
6.3 序列级先验的计算
在推理步级别做 pUCT,需要为每个候选推理步计算先验概率 P(s, a)。
做法:对每个候选序列,累加所有 token 的 log probability,然后对所有候选做 softmax:
4 个候选推理步:
Step A: 生成了 15 个 token, sum(logprobs) = -8.2
Step B: 生成了 20 个 token, sum(logprobs) = -12.5
Step C: 生成了 12 个 token, sum(logprobs) = -6.8
Step D: 生成了 18 个 token, sum(logprobs) = -15.1
直接用概率? exp(-8.2) = 0.000275, exp(-12.5) = 0.0000037, ...
→ 数值极小,难以区分
Softmax 归一化:
scores = [-8.2, -12.5, -6.8, -15.1]
exp(scores) = [0.000275, 3.7e-6, 0.00111, 2.7e-7]
sum = 0.001389
P = [0.198, 0.003, 0.799, 0.0002]
→ Step C 先验最高 (79.9%)
→ Step A 次之 (19.8%)
→ Step B 和 D 很低
用相对概率(softmax)而非原始概率,避免了长序列累积 logprob 趋近负无穷的数值问题。
七、并行搜索与 Virtual Loss
7.1 为什么要并行搜索
MCTS 的每次迭代是串行的:选择→扩展→模拟→回传。但我们有多块 GPU,如何利用并行性?
串行 MCTS: 1 个 worker,每秒 100 次迭代
→ 1000 次迭代需要 10 秒
并行 MCTS: 16 个 worker 共享同一棵树
→ 理论上 1000 次迭代只需 0.625 秒
→ 但直接并行有问题...
7.2 并行搜索的问题
如果 16 个 worker 同时开始搜索,它们会用相同的 UCT/pUCT 公式选择节点——结果所有 worker 都选了同一个分支!
16 个 worker 同时搜索:
Worker 1: 选择 → 分支 B (Q=0.7, 当前最优)
Worker 2: 选择 → 分支 B (同样的公式,同样的结果)
Worker 3: 选择 → 分支 B
...
Worker 16: 选择 → 分支 B
结果: 16 个 worker 全去了同一个分支,毫无多样性
7.3 Virtual Loss 解决方案
Virtual Loss 的思路:当一个 worker 开始探索某个节点时,临时给这个节点加一个”虚拟的失败”,降低它对后续 worker 的吸引力。
Worker 1 选择分支 B:
→ 临时给 B 加 virtual loss: N(B) += 1, W(B) += 0 (加一次访问但不加胜利)
→ B 的 Q 值从 0.70 下降到 W/(N+1) = 21/31 = 0.677
Worker 2 重新计算:
A: Q=0.60, UCT=0.992
B: Q=0.677, UCT=1.198 ← Q 降低了
C: Q=0.20, UCT=1.159
→ 可能选 B,也可能选 C(差距缩小了)
Worker 3, 4, ... 每选一次 B,B 的吸引力进一步下降
→ 最终不同 worker 会分散到不同分支
当 worker 完成搜索后,撤销 virtual loss,用真实结果更新统计。
这就像餐厅选座:如果每个人都涌向最好的位置,就太挤了。Virtual loss 相当于”有人坐了就扣分”,自动引导后来者分散到其他位置。
八、知识蒸馏基础
8.1 什么是知识蒸馏
知识蒸馏(Knowledge Distillation)最初由 Hinton 等人在 2015 年提出,核心思想非常简单:
让一个小模型(学生)模仿一个大模型(老师)的行为。
老师模型(大): "2+3=?" → 输出概率分布 [5: 0.85, 4: 0.05, 6: 0.05, 其他: 0.05]
学生模型(小): "2+3=?" → 输出概率分布 [5: 0.40, 4: 0.20, 6: 0.15, 其他: 0.25]
蒸馏目标: 让学生模型的输出分布逼近老师模型
→ 训练后学生: [5: 0.82, 4: 0.06, 6: 0.06, 其他: 0.06]
8.2 为什么不直接用标签训练
用标准答案(hard label)训练:
标签: "答案是 5" → one-hot 向量 [5: 1.0, 其他: 0.0]
→ 学生只知道"5 是对的",不知道"4 比 7 更接近正确答案"
用老师的输出分布(soft label)训练:
老师分布: [5: 0.85, 4: 0.05, 6: 0.05, 3: 0.03, 7: 0.02, ...]
→ 学生额外学到: "4 和 6 比其他答案更接近正确"
→ 这些"暗知识"(dark knowledge)帮助学生泛化得更好
8.3 蒸馏在搜索中的应用
在 AlphaZero 范式中,“老师”不是一个更大的模型,而是搜索增强后的同一个模型:
传统蒸馏: 大模型(老师)→ 小模型(学生)
搜索蒸馏: 模型 + 搜索(老师)→ 同一模型不带搜索(学生)
搜索蒸馏的目标:让模型不用搜索就能做出搜索级别的决策。
搜索前: π_θ("先算 8-5=3") = 0.25 → 需要 800 次 MCTS 迭代才知道这步最好
搜索后: π_θ("先算 8-5=3") = 0.65 → 一次 forward pass 就知道这步最好
8.4 SFT 蒸馏 vs RL 蒸馏
SFT 蒸馏(行为克隆):
- 把搜索找到的最优轨迹作为训练数据
- 用标准的 next-token prediction 训练
- 优点:简单稳定
- 缺点:只学”做什么”,不学”为什么”;没有负样本信号
RL 蒸馏(如 PPO):
- 搜索找到的轨迹带有奖励信号
- 用 RL 算法(PPO、GRPO 等)训练
- 优点:同时学习正面和负面信号;可以泛化到新情境
- 缺点:训练更复杂,需要仔细调参
SFT 蒸馏:
训练数据: [("问题", "最优推理轨迹")]
损失: -log P(最优轨迹 | 问题)
→ 模型学会模仿搜索找到的好轨迹
RL 蒸馏:
训练数据: [("问题", "推理轨迹", 奖励分数)]
损失: PPO objective(鼓励高奖励轨迹,抑制低奖励轨迹)
→ 模型学会区分好轨迹和差轨迹的特征
RL 蒸馏的优势可以用 Tree Search Distillation 论文中的发现来说明:Best-of-N(本质上是 SFT 蒸馏)效果最差,因为模型学会了”碰运气”而非”稳健推理”。PPO 蒸馏有负信号(差轨迹被惩罚),模型被迫学会系统性地推理。
九、搜索蒸馏 vs 标准 RL 的对比
9.1 训练信号的质量
这是搜索蒸馏与标准 RL(如 GRPO)最核心的差异。
GRPO 的训练信号:
对每个问题,独立采样 G 条轨迹:
y1: "先算 3+5=8, 8×7=56, 56-8=48" → r=0 (错)
y2: "先算 8-5=3, 3×7=21, 21+3=24" → r=1 (对!)
y3: "先算 7×5=35, 35-8=27, 27-3=24" → r=1 (对!)
y4: "先算 3×5=15, 15+8=23, 23+7=30" → r=0 (错)
→ 鼓励 y2 和 y3,抑制 y1 和 y4
→ 每条轨迹独立,信息没有复用
MCTS 的训练信号:
构建搜索树(共享中间推理步):
Step 1a: "8-5=3" → 多条轨迹验证了这步的价值
Step 1b: "3+5=8" → 多条轨迹验证了这步的价值
Step 1c: "7×5=35" → 初始看起来不错,但后续分支多数失败
最终选出的轨迹: "8-5=3 → 3×7=21 → 21+3=24"
→ 这不是碰运气找到的,而是系统性搜索验证过的最优路径
关键差异:MCTS 的训练样本经过结构化验证,质量更高。
9.2 计算资源的利用方式
GRPO: 花 N 份算力 → 独立采样 N 条轨迹 → N 个独立的信号
MCTS: 花 N 份算力 → 构建一棵搜索树 → 树中信息相互关联、相互验证
好比:
GRPO = 派 N 个人各自独立地探索一片森林
MCTS = 派 N 个人协同搜索,共享发现,有组织地覆盖更多区域
9.3 新的 Scaling 维度
GRPO 的 scaling 主要靠增加 group size(采样更多轨迹)。
MCTS 提供了额外的 scaling 旋钮:
维度 1: 并行 worker 数 → 搜索多样性
维度 2: MCTS 迭代数 → 搜索深度
维度 3: 候选数 K → 每步的分支宽度
维度 4: value head 质量 → 搜索引导准确度
这意味着在 model scale 和 data scale 之外,search scale 可能是第三个提升 LLM 能力的 scaling axis。
十、完整训练架构
把所有组件组合在一起:
┌─────────────────────────────────────────────────┐
│ 训练循环 │
│ │
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
│ │ 数据采样 │ → │ MCTS │ → │ 轨迹选择 │ │
│ │ 问题 x │ │ 搜索 │ │ (max visit)│ │
│ └──────────┘ └──────────┘ └──────────┘ │
│ │ ↓ │
│ ┌────┴────┐ ┌──────────┐ │
│ │Value Head│ │ 共享Buffer│ │
│ │ V(s) │ │ (Redis) │ │
│ └─────────┘ └────┬─────┘ │
│ ↓ │
│ ┌──────────────┐ │
│ │ PPO 训练 │ │
│ │ L_ppo + │ │
│ │ L_value + │ │
│ │ L_KL │ │
│ └──────┬───────┘ │
│ ↓ │
│ ┌──────────────┐ │
│ │ 权重同步 │ │
│ │ (每 8 步) │ │
│ └──────────────┘ │
│ │
└─────────────────────────────────────────────────┘
角色分工:
- Generator(6 GPU):负责 MCTS 搜索,生成候选推理步,运行 value head
- Trainer(2 GPU):从 buffer 拉取轨迹,做 PPO 更新
- Rust Worker:协调任务分配,管理 gRPC 通信
- Redis:轨迹 buffer(stream)+ 权重同步(pub/sub)
十一、开放问题与思考
11.1 规模效应
搜索蒸馏目前只在小模型(1.5B)和简单任务(Countdown)上验证。更大模型的基础策略已经很强,搜索的边际收益是否递减?
1.5B 模型: 基础策略弱 → 搜索发现很多模型不知道的好路径 → 收益大
70B 模型: 基础策略强 → 搜索很难找到模型还不知道的路径 → 收益可能小
但也有反论:
更大的模型 + 更好的 value head → 搜索引导更准 → 搜索效率更高
这是一个实证问题,目前没有定论
11.2 任务适配性
组合问题(如 Countdown)天然适合树搜索——需要探索多种组合。
顺序推理问题(如 GSM8K)可能不那么适合——正确的推理路径通常是线性的,树搜索的优势不明显。
适合树搜索的任务:
- 组合优化(凑数、排列)
- 多步规划(下棋、代码生成)
- 多解问题(创意写作、数学证明)
不太适合的任务:
- 事实问答("法国首都是哪里")
- 线性计算("123 × 456")
- 翻译(通常只有一个最佳答案)
11.3 与 Test-Time Compute 的关系
搜索蒸馏发生在训练时——用搜索产生好的训练信号。
Test-Time Compute(TTC)发生在推理时——推理时花更多计算得到更好答案。
两者可以组合:
阶段 1 (训练): MCTS 搜索 → PPO 蒸馏 → 得到更强的基础模型
阶段 2 (推理): 用更强的模型 + TTC(多次采样取最优)→ 更好的答案
蒸馏后模型 + TTC > 蒸馏后模型 > 原始模型 + TTC > 原始模型
本章总结
- MCTS 基础:选择→扩展→模拟→回传,智能地分配搜索资源
- UCT vs pUCT:pUCT 用策略先验引导搜索,对高分支因子空间至关重要
- AlphaZero 范式:搜索增强策略 → 蒸馏回模型 → 迭代提升
- 语言模型的挑战:动作空间巨大、粒度不匹配、value function 难学
- 推理步级搜索:在语义有意义的粒度上分支,避免搜索资源浪费
- 并行 MCTS:virtual loss 鼓励 worker 分散探索
- 知识蒸馏:soft label 包含”暗知识”;RL 蒸馏优于 SFT 蒸馏
- 搜索蒸馏 vs GRPO:结构化搜索信号 vs 无结构独立采样
- 新的 Scaling 维度:search scale 可能与 model scale、data scale 并列
现在可以去读论文解析了: → Tree Search Distillation:用 MCTS + PPO 蒸馏搜索策略到语言模型