Approximating KL Divergence(VeRL中low_var_kl)

张开发
2026/4/9 22:28:11 15 分钟阅读

分享文章

Approximating KL Divergence(VeRL中low_var_kl)
实现非常简单为什么用这种形式减少方差还是保持变差不变ifkl_penaltylow_var_kl:klref_logprob-logprob ratiotorch.exp(kl)kld(ratio-kl-1).contiguous()returntorch.clamp(kld,min-10,max10)ppo_kl verl_F.masked_mean(-negative_approx_kl, response_mask) 为什么和sum[P(i)ln(P(i)/Q(i))]的形式不一样ln前的P(i)去哪了这是一个非常精准的观察您指出的问题触及了PPO中KL散度蒙特卡洛估计的核心机制。代码中看似“缺失”的P(i)P(i)P(i)即旧策略概率πold(a∣s)\pi_{\text{old}}(a|s)πold​(a∣s)其实通过采样过程隐式体现了而非显式相乘。以下是关键解析1. 代码逻辑还原negative_approx_kllog_prob-old_log_prob# log(π_new) - log(π_old)ppo_klmasked_mean(-negative_approx_kl)# mean[ log(π_old) - log(π_new) ]# mean[ log(π_old / π_new) ]这等价于计算1N∑i1Nlog⁡(πold(ai∣si)πnew(ai∣si))\frac{1}{N}\sum_{i1}^N \log\left(\frac{\pi_{\text{old}}(a_i|s_i)}{\pi_{\text{new}}(a_i|s_i)}\right)N1​i1∑N​log(πnew​(ai​∣si​)πold​(ai​∣si​)​)2. 为什么不需要显式乘P(i)P(i)P(i)KL散度定义DKL(πold∥πnew)Ea∼πold[log⁡πold(a∣s)πnew(a∣s)]D_{KL}(\pi_{\text{old}} \parallel \pi_{\text{new}}) \mathbb{E}_{a \sim \pi_{\text{old}}} \left[ \log \frac{\pi_{\text{old}}(a|s)}{\pi_{\text{new}}(a|s)} \right]DKL​(πold​∥πnew​)Ea∼πold​​[logπnew​(a∣s)πold​(a∣s)​]蒙特卡洛估计原理由于PPO中动作轨迹是从旧策略πold\pi_{\text{old}}πold​采样得到的即ai∼πold(⋅∣si)a_i \sim \pi_{\text{old}}(\cdot|s_i)ai​∼πold​(⋅∣si​)样本平均1N∑log⁡(πold/πnew)\frac{1}{N}\sum \log(\pi_{\text{old}}/\pi_{\text{new}})N1​∑log(πold​/πnew​)本身就是对期望Ea∼πold[⋯ ]\mathbb{E}_{a \sim \pi_{\text{old}}}[\cdots]Ea∼πold​​[⋯]的无偏估计。采样分布πold\pi_{\text{old}}πold​已隐含了权重P(i)P(i)P(i)无需再显式乘πold(ai∣si)\pi_{\text{old}}(a_i|s_i)πold​(ai​∣si​)。3. 对比精确计算为何不采用若显式计算完整KL散度需遍历整个动作空间DKL∑a∈Aπold(a∣s)log⁡πold(a∣s)πnew(a∣s)D_{KL} \sum_{a \in \mathcal{A}} \pi_{\text{old}}(a|s) \log \frac{\pi_{\text{old}}(a|s)}{\pi_{\text{new}}(a|s)}DKL​a∈A∑​πold​(a∣s)logπnew​(a∣s)πold​(a∣s)​问题在语言模型等场景中动作空间A\mathcal{A}A是整个词表数万维遍历计算开销极大。PPO的工程选择用单次采样轨迹上的样本平均近似KL计算高效且与策略梯度采样逻辑一致。4. 方向性澄清关键代码计算的是DKL(πold∥πnew)D_{KL}(\pi_{\text{old}} \parallel \pi_{\text{new}})DKL​(πold​∥πnew​)前向KL因采样自πold\pi_{\text{old}}πold​。知识库[6][8]提到大模型常用反向KLDKL(πnew∥πref)D_{KL}(\pi_{\text{new}} \parallel \pi_{\text{ref}})DKL​(πnew​∥πref​)参考模型约束此时需从πnew\pi_{\text{new}}πnew​采样。但本代码片段针对PPO内部新旧策略对比方向由采样源决定。知识库[9]指出PPO监控KL爆炸时关注的是新策略πθ\pi_{\theta}πθ​与旧策略πθold\pi_{\theta_{\text{old}}}πθold​​的偏离此处计算逻辑完全匹配。总结代码中“缺失”的P(i)P(i)P(i)并非遗漏而是蒙特卡洛估计的天然特性采样分布πold\pi_{\text{old}}πold​作为概率权重已通过采样过程嵌入。这是强化学习中高效估计期望值的标准做法在保证理论正确性的同时极大降低计算复杂度也是PPO能在大模型训练中实用的关键设计之一。KL散度和F散度之间的关系f是凸函数期望Eq[r]1E_q[r]1Eq​[r]1f散度是非负的KL散度也就是非负的关于KL和RKL在f散度下的不同表达这也是为啥对于KL(p||q)对应xf(x)的原因KL(p∣∣q)Df(p∣∣q)∑xq(x)p(x)q(x)logp(x)q(x)Ex∼q(rlogr) KL(p||q)D_f(p||q)\sum_{x}q(x)\frac{p(x)}{q(x)}log\frac{p(x)}{q(x)}E_{x \sim q}(rlogr)KL(p∣∣q)Df​(p∣∣q)x∑​q(x)q(x)p(x)​logq(x)p(x)​Ex∼q​(rlogr)关于Bregman散度Bregman散度实际上是曲线和它的垂直切面距离不管是KL(p||q)还是KL(q||p)都是通过转换成Bregman散度来减少方差详细解释来自http://joschu.net/blog/kl-approx.htmlk1 in rewards 等价于 k2 in lossYoshua Bengio的一篇A Comedy of Estimators: On KL Regulations in RL Training of LLMs被ICRL 2026给拒了https://openreview.net/forum?idMkLHbwSMP3但是另外一篇RETHINKING KL REGULARIZATION IN RLHF: FROM VALUE ESTIMATION TO GRADIENT OPTIMIZATION也被ICRL 2026给拒了https://openreview.net/forum?idkeCnsHtIONOn a few pitfalls in KL divergence gradient estimation for RL 这篇发的更早一些这几篇的观点都是从梯度角度分析KL

更多文章