Esc
输入关键词开始搜索
AI Research

树搜索与搜索蒸馏:从 AlphaZero 到 LLM 推理

树搜索与搜索蒸馏:从 AlphaZero 到 LLM 推理

上一章讲了 RL 后训练(PPO、GRPO、OAPL),核心思路是采样 rollout → 打分 → 更新策略。这一章讲另一条路线:用树搜索在推理时找到更好的答案,再把搜索结果蒸馏回模型权重,让模型越来越强。

这条路线的鼻祖是 AlphaZero——DeepMind 用它在围棋上击败了人类世界冠军。现在人们正在尝试把同样的思路迁移到语言模型推理中。


一、什么是树搜索

1.1 从穷举到智能搜索

假设你在下棋。当前轮到你走,你有 20 种合法走法。对于每种走法,对手又有 20 种回应,然后你又有 20 种…

当前局面
├── 走法 A
│   ├── 对手回应 A1
│   │   ├── 你的走法 A1a → ...
│   │   └── 你的走法 A1b → ...
│   └── 对手回应 A2
│       └── ...
├── 走法 B
│   ├── ...
│   └── ...
└── 走法 C
    └── ...

这就是一棵博弈树(game tree)。穷举搜索所有分支可以找到最优走法,但问题是分支太多了:

国际象棋:平均每步 35 种走法,一盘棋约 80 步
  → 搜索空间 ≈ 35^80 ≈ 10^123

围棋:平均每步 250 种走法,一盘棋约 150 步
  → 搜索空间 ≈ 250^150 ≈ 10^360

宇宙中原子总数才约 108010^{80}。穷举是不可能的,我们需要智能搜索——把有限的计算资源集中在最有潜力的方向上。

1.2 树搜索的基本思想

智能树搜索的核心是两个能力:

  1. 评估:能估计一个中间状态的好坏(不需要搜到终局)
  2. 选择:能决定优先搜索哪些分支(平衡探索与利用)

有了这两个能力,搜索过程就变成了:

重复很多次:
  1. 从根节点出发,选择一条路径走到叶节点
  2. 评估这个叶节点的价值
  3. 把评估结果回传到路径上的所有节点

搜索次数越多,对每个分支的价值估计就越准,最终选出的走法就越好。


二、蒙特卡洛树搜索(MCTS)

MCTS 是目前最成功的树搜索算法,它用随机模拟来评估节点价值。

2.1 四个步骤

每一次 MCTS 迭代包含四步:

┌─────────┐    ┌──────────┐    ┌──────────┐    ┌──────────┐
│  选择    │ →  │  扩展     │ →  │  模拟     │ →  │  回传     │
│ Select  │    │ Expand   │    │ Simulate │    │ Backprop │
└─────────┘    └──────────┘    └──────────┘    └──────────┘

Step 1: 选择(Select)

从根节点开始,沿着树往下走。在每个节点,用一个公式选择最有价值的子节点(稍后详细讲这个公式)。一直走到一个还没完全展开的节点。

Step 2: 扩展(Expand)

在选中的节点上,创建一个新的子节点(尝试一个还没试过的动作)。

Step 3: 模拟(Simulate)

从新节点开始,用某种策略(比如随机走)一路走到游戏结束,得到一个结果(赢/输/平)。

Step 4: 回传(Backpropagate)

把模拟结果沿路径回传到根节点,更新路径上每个节点的统计信息:

  • N(s):这个节点被访问了多少次
  • W(s):这个节点累计获得的价值总和
  • Q(s) = W(s) / N(s):平均价值

2.2 数值走一遍

假设一个简单的游戏树,已经搜索了一段时间:

根节点 (N=100)
├── A (N=60, W=36, Q=0.60)   ← 访问 60 次,赢了 36 次
├── B (N=30, W=21, Q=0.70)   ← 访问 30 次,赢了 21 次
└── C (N=10, W=2,  Q=0.20)   ← 访问 10 次,赢了 2 次

直觉上:

  • A 被访问最多,胜率 60%
  • B 被访问较少,但胜率最高 70%
  • C 被访问最少,胜率最低 20%

下一次搜索应该选谁?如果只看 Q 值,永远选 B;如果只看访问次数少的,永远选 C。我们需要一个平衡探索和利用的公式。


三、UCT 与 pUCT:搜索的核心公式

3.1 UCT(Upper Confidence Bound for Trees)

UCT 公式决定了在每个节点选哪个子节点:

UCT(s,a)=Q(s,a)+clnN(s)N(s,a)\text{UCT}(s, a) = Q(s, a) + c \cdot \sqrt{\frac{\ln N(s)}{N(s, a)}}
  • Q(s,a)Q(s, a):选择动作 a 后的平均价值(利用项)
  • cc:探索常数,控制探索强度
  • N(s)N(s):父节点的访问次数
  • N(s,a)N(s, a):子节点的访问次数

第二项是探索奖励:访问次数少的子节点获得更高的探索奖励。

用上面的例子(取 c = 1.414):

UCT(A) = 0.60 + 1.414 * √(ln(100)/60) = 0.60 + 1.414 * √(4.605/60) = 0.60 + 0.392 = 0.992
UCT(B) = 0.70 + 1.414 * √(ln(100)/30) = 0.70 + 1.414 * √(4.605/30) = 0.70 + 0.554 = 1.254
UCT(C) = 0.20 + 1.414 * √(ln(100)/10) = 0.20 + 1.414 * √(4.605/10) = 0.20 + 0.959 = 1.159

结果:选 B(UCT = 1.254 最大)。B 兼具高胜率和相对较少的访问次数,值得继续探索。

3.2 pUCT(Predictor + UCT)

pUCT 是 AlphaZero 使用的变体,关键区别是引入了策略先验 P(s, a):

pUCT(s,a)=Q(s,a)+cpuctP(s,a)N(s)1+N(s,a)\text{pUCT}(s, a) = Q(s, a) + c_{puct} \cdot P(s, a) \cdot \frac{\sqrt{N(s)}}{1 + N(s, a)}
  • P(s,a)P(s, a):策略网络对动作 a 的先验概率

对比 UCT 和 pUCT 的探索项:

UCT 探索项:  c * √(ln(N_parent) / N_child)
  → 只看访问次数,不管动作本身好不好

pUCT 探索项: c * P(s,a) * √(N_parent) / (1 + N_child)
  → 用策略先验 P(s,a) 加权,好动作获得更多探索机会

3.3 为什么 pUCT 对语言模型至关重要

在棋类中,合法走法通常几十到几百个,UCT 也能工作——反正每个动作都能被探索到。

但在语言模型中,即使在推理步级别,候选数也可以很多。更关键的是,模型的策略(logprobs)本身就包含了大量关于哪些推理步更合理的信息

UCT 在语言模型中的问题:
  候选推理步: "先算 3+5=8" (概率 0.4), "先算 1×2=2" (概率 0.3),
               "把所有数加起来" (概率 0.01), "随便试一个" (概率 0.001)

  UCT: 所有未访问的候选都有相同的探索优先级
  → 大量搜索资源浪费在 "随便试一个" 这类低质量候选上

pUCT 的解决:
  P(s,a) = softmax(logprobs)
  → "先算 3+5=8" 的探索优先级远高于 "随便试一个"
  → 搜索资源集中在有潜力的方向上

DeepSeek-R1 团队报告 MCTS 效果不佳,一个可能原因就是他们用了 UCT 而非 pUCT,在高分支因子的语言空间中搜索效率太低。


四、AlphaZero 的搜索→蒸馏范式

4.1 AlphaZero 的训练循环

AlphaZero 的训练过程非常优雅:

初始化: 随机策略网络 π_θ + 随机价值网络 V_θ

重复:
  1. 自我对弈:
     - 用 MCTS + π_θ + V_θ 下很多盘棋
     - 每步选择: MCTS 搜索 800 次迭代 → 按访问计数选择走法
     - 记录每步的 (棋盘状态, MCTS 搜索策略, 最终胜负)

  2. 训练(蒸馏):
     - 策略目标: 让 π_θ(s) 逼近 MCTS 搜索策略(通常是 visit count 的归一化分布)
     - 价值目标: 让 V_θ(s) 逼近实际胜负结果
     - 损失 = 策略交叉熵 + 价值均方误差

  3. 更新后重复

核心 insight:搜索策略比原始策略强,蒸馏让模型学会”直觉性地”做出搜索级别的决策

4.2 为什么蒸馏有效

直觉上理解:

没搜索时的 π_θ: "根据经验,A 走法 40%,B 走法 35%,C 走法 25%"
                 → 仅凭"直觉"(一次 forward pass)

搜索增强后:     "搜索了 800 次后发现,A 走法其实只有 20%,
                 B 走法最好占 60%,C 走法 20%"
                 → "深思熟虑"后的判断

蒸馏:           让 π_θ 学会直接输出 [20%, 60%, 20%]
                → 把"深思熟虑"的结论变成新的"直觉"

每轮迭代后,模型的”直觉”越来越接近”深思熟虑”的结果,搜索就能在这个更强的直觉基础上探索更深层的策略——良性循环。

4.3 数值理解蒸馏效果

训练前:
  π_θ(A) = 0.4,  MCTS(A) = 0.2  → 交叉熵损失推动 π_θ(A) 下降
  π_θ(B) = 0.35, MCTS(B) = 0.6  → 交叉熵损失推动 π_θ(B) 上升
  π_θ(C) = 0.25, MCTS(C) = 0.2  → 交叉熵损失推动 π_θ(C) 下降

训练后:
  π_θ(A) ≈ 0.22, π_θ(B) ≈ 0.58, π_θ(C) ≈ 0.20
  → 不用搜索也能给出接近搜索结果的判断

五、从棋类到语言模型:关键挑战

5.1 动作空间的差异

这是最根本的差异。棋类的动作空间有限且语义清晰,语言模型的动作空间巨大且充满冗余。

国际象棋某局面: 可选走法 = [e4, d4, Nf3, c4, ...]  → 约 30 个,每个语义独立
围棋某局面:    可选走法 = [A1, A2, ..., T19]        → 最多 361 个

语言模型某推理步:
  token 级别: 词表 ≈ 100,000 个 token
  推理步级别: 可能的下一步推理 = 无限(但可以采样 K 个候选)

5.2 蒸馏的粒度不匹配

在 AlphaZero 中,搜索和策略的粒度一致——都是”走法”级别,可以直接计算 KL 散度来蒸馏。

在语言模型中:

  • 搜索在推理步级别进行
  • 模型策略在 token 级别定义

两者粒度不同,无法直接蒸馏。

AlphaZero 蒸馏:
  搜索策略: [A: 60%, B: 30%, C: 10%]   ← 走法级别
  模型策略: [A: 40%, B: 35%, C: 25%]   ← 走法级别
  → 直接计算 KL 散度,用交叉熵训练

语言模型蒸馏的困境:
  搜索策略: [推理步X: 50%, 推理步Y: 30%, 推理步Z: 20%]  ← 推理步级别
  模型策略: [token1: 5%, token2: 3%, ...]               ← token 级别
  → 粒度不匹配,不能直接比较

解决方案(Tree Search Distillation 的做法):不蒸馏搜索策略分布,而是选出最优轨迹,用 RL(PPO)训练模型在这些轨迹上获得高回报

5.3 Value Function 的困难

AlphaZero 的 value function 预测胜负概率,输出范围固定在 [-1, 1],而且棋局的价值随着走子逐步确定,学习信号清晰。

语言模型推理中,中间推理步的”价值”很难定义:

问题: "用 3, 5, 7, 8 凑出 24"

推理步 1a: "先算 8-5=3"     → 价值?此时不知道能不能最终凑出 24
推理步 1b: "先算 3×8=24"    → 价值?看起来很有希望(已经得到 24 了)
推理步 1c: "先算 7+5=12"    → 价值?还需要几步才知道

实际上 1b 看似好,但 24=3×8 用掉了 3 和 8,
剩下 5 和 7,还需要保证它们不影响结果。
这取决于具体的游戏规则。

解决方案:训练一个 value head(MLP + tanh),随 RL 训练同步更新。初始估计不准,但随着训练积累数据会逐渐改善。


六、推理步级别的 MCTS

6.1 为什么不在 token 级别搜索

直观理解:

token 级别搜索:
  "先" → 搜索下一个 token
    ├── "算" (概率 0.3)
    ├── "计" (概率 0.25)
    ├── "把" (概率 0.15)
    └── "用" (概率 0.1)

  "先算" → 搜索下一个 token
    ├── "3" (概率 0.4)
    ├── "8" (概率 0.3)
    └── ...

问题:
  1. "先算" 和 "先计算" 导向几乎相同的推理,搜索资源被浪费
  2. 搜索树极其庞大(每层 100K 分支)
  3. 大多数 token 选择对推理路径没有实质影响
推理步级别搜索:
  根节点 → 搜索下一个完整推理步
    ├── "先算 3×8=24,但还剩 5 和 7 需要处理"
    ├── "先算 8-5=3,得到两个 3 和一个 7"
    ├── "先算 7-3=4,然后考虑 5 和 8"
    └── "先算 5+7=12,然后考虑 3 和 8"

优势:
  1. 每个分支语义差异显著
  2. 搜索树规模可控(每层 K=4 个候选)
  3. 搜索资源全部用于探索不同的推理路径

6.2 搜索过程详解

以 Countdown 任务为例:用 3, 5, 7, 8 凑出目标 24。

根节点: "用 3, 5, 7, 8 凑出 24"

├── [Step 1a] "8 × 3 = 24。但我需要用所有四个数..."
│   │  → V(s) = 0.3 (value head 估计)
│   │
│   ├── [Step 2a] "24 + 5 - 7 = 22,不对"
│   │   → 终端节点,r = 1 - 2×|24-22|/24 = 0.833
│   │
│   └── [Step 2b] "24 × (7-5)/... 不行,已经用了所有数"
│       → 终端节点,r = -1.0 (格式错误)

├── [Step 1b] "8 - 5 = 3,现在有 3, 3, 7"
│   │  → V(s) = 0.4
│   │
│   ├── [Step 2c] "3 × 7 = 21,加 3 = 24!"
│   │   → 终端节点,r = 1.0 ✓
│   │
│   └── [Step 2d] "3 + 3 = 6,6 × 7 = 42,太大了"
│       → 终端节点,r = 1 - 2×|24-42|/24 = -0.5

├── [Step 1c] "7 - 3 = 4,现在有 4, 5, 8"
│   │  → V(s) = 0.35
│   │
│   └── [Step 2e] "5 - 4 = 1, 不太有用..."
│       → 继续搜索...

└── [Step 1d] "5 + 7 = 12,现在有 3, 8, 12"
    │  → V(s) = 0.45

    └── [Step 2f] "12 × (8/3) = 32,不对"
        → 终端节点,r = 1 - 2×|24-32|/24 = 0.333

经过多次 MCTS 迭代后,Step 1b 的分支因为发现了正确答案(3×7+3=24),会获得大量访问和高价值。最终选择 Step 1b 开头的轨迹作为训练样本。

6.3 序列级先验的计算

在推理步级别做 pUCT,需要为每个候选推理步计算先验概率 P(s, a)。

做法:对每个候选序列,累加所有 token 的 log probability,然后对所有候选做 softmax:

4 个候选推理步:
  Step A: 生成了 15 个 token, sum(logprobs) = -8.2
  Step B: 生成了 20 个 token, sum(logprobs) = -12.5
  Step C: 生成了 12 个 token, sum(logprobs) = -6.8
  Step D: 生成了 18 个 token, sum(logprobs) = -15.1

直接用概率? exp(-8.2) = 0.000275, exp(-12.5) = 0.0000037, ...
  → 数值极小,难以区分

Softmax 归一化:
  scores = [-8.2, -12.5, -6.8, -15.1]
  exp(scores) = [0.000275, 3.7e-6, 0.00111, 2.7e-7]
  sum = 0.001389
  P = [0.198, 0.003, 0.799, 0.0002]

  → Step C 先验最高 (79.9%)
  → Step A 次之 (19.8%)
  → Step B 和 D 很低

用相对概率(softmax)而非原始概率,避免了长序列累积 logprob 趋近负无穷的数值问题。


七、并行搜索与 Virtual Loss

7.1 为什么要并行搜索

MCTS 的每次迭代是串行的:选择→扩展→模拟→回传。但我们有多块 GPU,如何利用并行性?

串行 MCTS: 1 个 worker,每秒 100 次迭代
  → 1000 次迭代需要 10 秒

并行 MCTS: 16 个 worker 共享同一棵树
  → 理论上 1000 次迭代只需 0.625 秒
  → 但直接并行有问题...

7.2 并行搜索的问题

如果 16 个 worker 同时开始搜索,它们会用相同的 UCT/pUCT 公式选择节点——结果所有 worker 都选了同一个分支!

16 个 worker 同时搜索:
  Worker 1: 选择 → 分支 B (Q=0.7, 当前最优)
  Worker 2: 选择 → 分支 B (同样的公式,同样的结果)
  Worker 3: 选择 → 分支 B
  ...
  Worker 16: 选择 → 分支 B

结果: 16 个 worker 全去了同一个分支,毫无多样性

7.3 Virtual Loss 解决方案

Virtual Loss 的思路:当一个 worker 开始探索某个节点时,临时给这个节点加一个”虚拟的失败”,降低它对后续 worker 的吸引力。

Worker 1 选择分支 B:
  → 临时给 B 加 virtual loss: N(B) += 1, W(B) += 0 (加一次访问但不加胜利)
  → B 的 Q 值从 0.70 下降到 W/(N+1) = 21/31 = 0.677

Worker 2 重新计算:
  A: Q=0.60, UCT=0.992
  B: Q=0.677, UCT=1.198  ← Q 降低了
  C: Q=0.20, UCT=1.159

  → 可能选 B,也可能选 C(差距缩小了)

Worker 3, 4, ... 每选一次 B,B 的吸引力进一步下降
  → 最终不同 worker 会分散到不同分支

当 worker 完成搜索后,撤销 virtual loss,用真实结果更新统计。

这就像餐厅选座:如果每个人都涌向最好的位置,就太挤了。Virtual loss 相当于”有人坐了就扣分”,自动引导后来者分散到其他位置。


八、知识蒸馏基础

8.1 什么是知识蒸馏

知识蒸馏(Knowledge Distillation)最初由 Hinton 等人在 2015 年提出,核心思想非常简单:

让一个小模型(学生)模仿一个大模型(老师)的行为。

老师模型(大): "2+3=?" → 输出概率分布 [5: 0.85, 4: 0.05, 6: 0.05, 其他: 0.05]
学生模型(小): "2+3=?" → 输出概率分布 [5: 0.40, 4: 0.20, 6: 0.15, 其他: 0.25]

蒸馏目标: 让学生模型的输出分布逼近老师模型
  → 训练后学生: [5: 0.82, 4: 0.06, 6: 0.06, 其他: 0.06]

8.2 为什么不直接用标签训练

用标准答案(hard label)训练:

标签: "答案是 5" → one-hot 向量 [5: 1.0, 其他: 0.0]
  → 学生只知道"5 是对的",不知道"4 比 7 更接近正确答案"

用老师的输出分布(soft label)训练:

老师分布: [5: 0.85, 4: 0.05, 6: 0.05, 3: 0.03, 7: 0.02, ...]
  → 学生额外学到: "4 和 6 比其他答案更接近正确"
  → 这些"暗知识"(dark knowledge)帮助学生泛化得更好

8.3 蒸馏在搜索中的应用

在 AlphaZero 范式中,“老师”不是一个更大的模型,而是搜索增强后的同一个模型

传统蒸馏:   大模型(老师)→ 小模型(学生)
搜索蒸馏:   模型 + 搜索(老师)→ 同一模型不带搜索(学生)

搜索蒸馏的目标:让模型不用搜索就能做出搜索级别的决策。

搜索前: π_θ("先算 8-5=3") = 0.25  → 需要 800 次 MCTS 迭代才知道这步最好
搜索后: π_θ("先算 8-5=3") = 0.65  → 一次 forward pass 就知道这步最好

8.4 SFT 蒸馏 vs RL 蒸馏

SFT 蒸馏(行为克隆):

  • 把搜索找到的最优轨迹作为训练数据
  • 用标准的 next-token prediction 训练
  • 优点:简单稳定
  • 缺点:只学”做什么”,不学”为什么”;没有负样本信号

RL 蒸馏(如 PPO):

  • 搜索找到的轨迹带有奖励信号
  • 用 RL 算法(PPO、GRPO 等)训练
  • 优点:同时学习正面和负面信号;可以泛化到新情境
  • 缺点:训练更复杂,需要仔细调参
SFT 蒸馏:
  训练数据: [("问题", "最优推理轨迹")]
  损失: -log P(最优轨迹 | 问题)
  → 模型学会模仿搜索找到的好轨迹

RL 蒸馏:
  训练数据: [("问题", "推理轨迹", 奖励分数)]
  损失: PPO objective(鼓励高奖励轨迹,抑制低奖励轨迹)
  → 模型学会区分好轨迹和差轨迹的特征

RL 蒸馏的优势可以用 Tree Search Distillation 论文中的发现来说明:Best-of-N(本质上是 SFT 蒸馏)效果最差,因为模型学会了”碰运气”而非”稳健推理”。PPO 蒸馏有负信号(差轨迹被惩罚),模型被迫学会系统性地推理。


九、搜索蒸馏 vs 标准 RL 的对比

9.1 训练信号的质量

这是搜索蒸馏与标准 RL(如 GRPO)最核心的差异。

GRPO 的训练信号:

对每个问题,独立采样 G 条轨迹:
  y1: "先算 3+5=8, 8×7=56, 56-8=48" → r=0 (错)
  y2: "先算 8-5=3, 3×7=21, 21+3=24" → r=1 (对!)
  y3: "先算 7×5=35, 35-8=27, 27-3=24" → r=1 (对!)
  y4: "先算 3×5=15, 15+8=23, 23+7=30" → r=0 (错)

  → 鼓励 y2 和 y3,抑制 y1 和 y4
  → 每条轨迹独立,信息没有复用

MCTS 的训练信号:

构建搜索树(共享中间推理步):
  Step 1a: "8-5=3" → 多条轨迹验证了这步的价值
  Step 1b: "3+5=8" → 多条轨迹验证了这步的价值
  Step 1c: "7×5=35" → 初始看起来不错,但后续分支多数失败

  最终选出的轨迹: "8-5=3 → 3×7=21 → 21+3=24"
  → 这不是碰运气找到的,而是系统性搜索验证过的最优路径

关键差异:MCTS 的训练样本经过结构化验证,质量更高。

9.2 计算资源的利用方式

GRPO: 花 N 份算力 → 独立采样 N 条轨迹 → N 个独立的信号
MCTS: 花 N 份算力 → 构建一棵搜索树 → 树中信息相互关联、相互验证

好比:
  GRPO = 派 N 个人各自独立地探索一片森林
  MCTS = 派 N 个人协同搜索,共享发现,有组织地覆盖更多区域

9.3 新的 Scaling 维度

GRPO 的 scaling 主要靠增加 group size(采样更多轨迹)。

MCTS 提供了额外的 scaling 旋钮:

维度 1: 并行 worker 数 → 搜索多样性
维度 2: MCTS 迭代数   → 搜索深度
维度 3: 候选数 K      → 每步的分支宽度
维度 4: value head 质量 → 搜索引导准确度

这意味着在 model scale 和 data scale 之外,search scale 可能是第三个提升 LLM 能力的 scaling axis。


十、完整训练架构

把所有组件组合在一起:

┌─────────────────────────────────────────────────┐
│                    训练循环                       │
│                                                   │
│  ┌──────────┐    ┌──────────┐    ┌──────────┐   │
│  │ 数据采样  │ →  │  MCTS    │ →  │ 轨迹选择  │   │
│  │ 问题 x   │    │  搜索    │    │ (max visit)│   │
│  └──────────┘    └──────────┘    └──────────┘   │
│                       │                ↓          │
│                  ┌────┴────┐    ┌──────────┐     │
│                  │Value Head│    │ 共享Buffer│     │
│                  │  V(s)   │    │ (Redis)  │     │
│                  └─────────┘    └────┬─────┘     │
│                                      ↓           │
│                              ┌──────────────┐    │
│                              │  PPO 训练     │    │
│                              │  L_ppo +      │    │
│                              │  L_value +    │    │
│                              │  L_KL         │    │
│                              └──────┬───────┘    │
│                                     ↓            │
│                              ┌──────────────┐    │
│                              │ 权重同步      │    │
│                              │ (每 8 步)     │    │
│                              └──────────────┘    │
│                                                   │
└─────────────────────────────────────────────────┘

角色分工:

  • Generator(6 GPU):负责 MCTS 搜索,生成候选推理步,运行 value head
  • Trainer(2 GPU):从 buffer 拉取轨迹,做 PPO 更新
  • Rust Worker:协调任务分配,管理 gRPC 通信
  • Redis:轨迹 buffer(stream)+ 权重同步(pub/sub)

十一、开放问题与思考

11.1 规模效应

搜索蒸馏目前只在小模型(1.5B)和简单任务(Countdown)上验证。更大模型的基础策略已经很强,搜索的边际收益是否递减?

1.5B 模型: 基础策略弱 → 搜索发现很多模型不知道的好路径 → 收益大
70B 模型:  基础策略强 → 搜索很难找到模型还不知道的路径 → 收益可能小

但也有反论:
  更大的模型 + 更好的 value head → 搜索引导更准 → 搜索效率更高
  这是一个实证问题,目前没有定论

11.2 任务适配性

组合问题(如 Countdown)天然适合树搜索——需要探索多种组合。

顺序推理问题(如 GSM8K)可能不那么适合——正确的推理路径通常是线性的,树搜索的优势不明显。

适合树搜索的任务:
  - 组合优化(凑数、排列)
  - 多步规划(下棋、代码生成)
  - 多解问题(创意写作、数学证明)

不太适合的任务:
  - 事实问答("法国首都是哪里")
  - 线性计算("123 × 456")
  - 翻译(通常只有一个最佳答案)

11.3 与 Test-Time Compute 的关系

搜索蒸馏发生在训练时——用搜索产生好的训练信号。

Test-Time Compute(TTC)发生在推理时——推理时花更多计算得到更好答案。

两者可以组合:

阶段 1 (训练): MCTS 搜索 → PPO 蒸馏 → 得到更强的基础模型
阶段 2 (推理): 用更强的模型 + TTC(多次采样取最优)→ 更好的答案

  蒸馏后模型 + TTC > 蒸馏后模型 > 原始模型 + TTC > 原始模型

本章总结

  1. MCTS 基础:选择→扩展→模拟→回传,智能地分配搜索资源
  2. UCT vs pUCT:pUCT 用策略先验引导搜索,对高分支因子空间至关重要
  3. AlphaZero 范式:搜索增强策略 → 蒸馏回模型 → 迭代提升
  4. 语言模型的挑战:动作空间巨大、粒度不匹配、value function 难学
  5. 推理步级搜索:在语义有意义的粒度上分支,避免搜索资源浪费
  6. 并行 MCTS:virtual loss 鼓励 worker 分散探索
  7. 知识蒸馏:soft label 包含”暗知识”;RL 蒸馏优于 SFT 蒸馏
  8. 搜索蒸馏 vs GRPO:结构化搜索信号 vs 无结构独立采样
  9. 新的 Scaling 维度:search scale 可能与 model scale、data scale 并列

现在可以去读论文解析了:Tree Search Distillation:用 MCTS + PPO 蒸馏搜索策略到语言模型