Training Language Models to Self-Correct via Reinforcement Learning

#rl #reasoning

前言

之前有非常多的工作探索了 critic model 和 revision model,业界也有闹出乌龙的 reflection 和刚出的 o1;但大家得出的实验结果通常表明直接拿 revise data 做 sft 是不合适的,会拿到很多负结果。这是因为我们简单且错误地相信,mo’xing’neng’g偶自己识别 revision 的 pattern。在这片工作里,作者先细粒度分析已有方法对 revision 能力的改变,再提出一套 multi-stage 训练方法。

已有方法分析

Base Model 的 self-correction 很差,STaR 和 Pair-SFT (Generating sequences by learning to self-correct) 都能够提升这一能力,但仍不会带来正面收益

作者也用 i->c, c->i rate 和 edit distance 说明了这些结论

Takeaways:

SCoRe

作者认为,强化模型 “修正” 的能力能强化 generalization,这是合理的,但之前方法没有注意区分 “the best first attempt” 和 “the best correction attempt”。他们定义了 multi-turn RL baseline 公式

如果直接按公式训,attempt1 和 attempt2 效果都会上升,但二者紧密相关,也就是说模型其实没有提升 correction 能力,所以作者进一步区分了第一步和第二步

第一步 保证第一轮生成稳定

他们用非常大的系数,保证生成 first attempt 的分布稳定,这里 β2\beta_2 非常大

image-20240923100434933

第二步 一起优化

这里带着 r1r_1 优化的原因是,作者不希望第一轮生成效果变坏

但这并没有体现出对 correction 能力的强调。因为要保证第一轮生成稳定,作者修改了 r2r_2

结果很好。