LLM - 理解 DeepSeek 的 GPRO (分组相对策略优化) 公式与源码 教程(2)
欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/145640762
GPRO,即 Group Relative Policy Optimization,分组相对的策略优化,是 PPO(Proximal Policy Optimization, 近端策略优化) 的优化版本,省略优化 评论家模型(Critic Model),用于估计价值(Value Function Model),降低模型训练的资源消耗。
GRPO 目标的工作原理如下:
- 为查询生成一组响应。
- 根据预定义的标准(例如准确性、格式),计算每个响应的奖励。
- 比较组内的反应以计算他们的相对优势。
- 更新策略以支持具有更高优势的响应,剪裁(clip)确保的稳定性。
- 规范更新以防止模型偏离基线太远。
GRPO 有效的原因:
- 无需评论:GRPO 依靠群体比较,避免对于单独评估者的需求,从而降低了计算成本。
- 稳定学习:剪裁(clip) 和 KL 正则化确保模型稳步改进,不会出现剧烈波动。
- 高效训练:通过关注相对性能,GRPO 非常适合推理等绝对评分困难的任务。
在 DeepSeekMath (2024.4) 中,使用 GPRO 代替 PPO。

回顾一下 PPO 模型的公式与框架,PPO 是先训练 奖励模型(RM),通过强化学习策略,将奖励模型的能力,学习到大语言模型中,同时,注意模型的输出符合之前的预期,不要偏离过远(KL Divergence)。即:
- RM(Reward Model, 奖励模型): m a x r ϕ { E ( x , y w i n , y l o s s ) ∼ D [ l o g σ ( r ϕ ( x , y w i n ) − r ϕ ( x , y l o s s ) ) ] } \underset{r_{\phi}}{max} \{ {E_{(x,y_{win},y_{loss}) \sim D}}[log \ \sigma(r_{\phi}(x,y_{win}) - r_{\phi}(x,y_{loss}))] \} rϕmax{E(x,ywin,yloss)∼D[log σ(rϕ(x,ywin)−rϕ(x,yloss))]}
- PPO(Proximal Policy Optimization, 近端策略优化): m a x π θ { E x ∼ D , y ∼ π θ ( y ∣ x ) [ r ϕ ( x , y ) ] − β D K L [ π θ ( y ∣ x ) ∣ ∣ π r e f ( y ∣ x ) ] } \underset{\pi_{\theta}}{max} \{ E_{x \sim D,y \sim \pi_{\theta}(y|x)}[r_{\phi}(x,y)] - \beta D_{KL}[\pi_{\theta}(y|x) || \pi_{ref}(y|x)] \} πθmax{Ex∼D,y∼πθ(y∣x)[rϕ(x,y)]−βDKL[πθ(y∣x)∣∣πref(y∣x)]}
- KL 散度(KL Divergence):
D K L [ π θ ( y ∣ x ) ∣ ∣ π r e f ( y ∣ x ) ] = π r e f ( y ∣ x ) l o g π r e f ( y ∣ x ) π θ ( y ∣ x ) = π r e f ( y ∣ x ) ( l o g π r e f ( y ∣ x ) − l o g π θ ( y ∣ x ) ) \begin{align} D_{KL}[\pi_{\theta}(y|x) || \pi_{ref}(y|x)] &= \pi_{ref}(y|x) \ log{\frac{\pi_{ref}(y|x)}{\pi_{\theta}(y|x)}} \\ &= \pi_{ref}(y|x) (log \pi_{ref}(y|x) - log\pi_{\theta}(y|x)) \end{align} DKL[πθ(y∣x)∣∣πref(y∣x)]=πref(y∣x) logπθ(y∣x)πref(y∣x)=πref(y∣x)(logπref(y∣x)−logπθ(y∣x))
其中,Actor 和 Critic 损失函数如下:
a [ i , j ] = r e t u r n s [ i , j ] − v a l u e s [ i , j ] L o s s a c t o r = − 1 M N ∑ i = 1 M ∑ j = 1 N a [ i , j ] × e x p ( l o g _ p r o b [ i , j ] − o l d _ l o g _ p r o b [ i , j ] ) L o s s c r i t i c = 1 2 M N ∑ i = 1 M ∑ j = 1 N ( v a l u e s [ i , j ] − r e t u r n s [ i , j ] ) 2 L o s s = L o s s a c t o r + 0.1 ∗ L o s s c r i t i c \begin{align} a[i,j] &= returns[i,j] - values[i,j] \\ Loss_{actor} &= -\frac{1}{MN} \sum_{i=1}^{M} \sum_{j=1}^{N} a[i,j] \times exp(log\_prob[i,j]-old\_log\_prob[i,j]) \\ Loss_{critic} &= \frac{1}{2MN} \sum_{i=1}^{M} \sum_{j=1}^{N} (values[i,j] - returns[i,j])^{2} \\ Loss & = Loss_{actor} + 0.1*Loss_{critic} \end{align} a[i,j]LossactorLosscriticLoss=returns[i,j]−values[i,j]=−MN1i=1∑Mj=1∑Na[i,j]×exp(log_prob[i,j]−old_log_prob[i,j])=2MN1i=1∑Mj=1∑N(values[i,j]−returns[i,j])2=Lossactor+0.1∗Losscritic
PPO 的奖励(Reward 计算),一般而言,超参数 β = 0.1 \beta=0.1 β=0.1:
r t = r ψ ( q , o ≤ t ) − β l o g ( π θ ( o t ∣ q , o < t ) π r e f ( o t ∣ q , o < t ) ) r_{t} = r_{\psi}(q,o_{\leq t}) - \beta log(\frac{\pi_{\theta}(o_{t}|q,o_{<t})}{\pi_{ref}(o_{t}|q,o_{<t})}) rt=rψ(q,o≤t)−βlog(πref(ot∣q,o<t)πθ(ot∣q,o<t))
在 PPO 中使用的价值函数(Critic Model),通常与策略模型(Policy Model)大小相当,带来内存和计算负担。在强化学习训练中,价值函数作为基线,以减少优势函数计算中的方差。然而,在 大语言模型(LLM) 的场景中,只有最后一个 Token 被奖励模型赋予奖励分数,使训练一个在每个标记处都准确的价值函数,变得复杂。GRPO 无需像 PPO 那样,使用额外的近似价值函数,而是使用同一问题产生的多个采样输出的平均奖励,作为基线。
GRPO 使用基于 组相对(Group Relative) 的优势计算方式,与奖励模型比较特性一致,因为奖励模型通常是在同一问题上不同输出之间的比较数据集上进行训练的。同时,GRPO 没有在 奖励(Reward) 中加入 KL 惩罚,而是直接将训练策略与参考策略之间的KL散度添加到损失函数中,从而避免了在计算优势时增加复杂性。
GPRO 的公式, Q Q Q 表示 Query,即输入的问题,采样出问题 q q q,推理大模型,输出 G G G 个输出 o i o_{i} oi:
J G R P O ( θ ) = E [ q ∼ P ( Q ) , { o i } i = 1 G ∼ π θ o l d ( O ∣ q ) ] L o s s = 1 G ∑ i = 1 G ( m i n ( ( π θ ( o i ∣ q ) π θ o l d ( o i ∣ q ) ) A i , c l i p ( π θ ( o i ∣ q ) π θ o l d ( o i ∣ q ) , 1 − ϵ , 1 + ϵ ) A i − β D K L ( π θ ∣ ∣ π r e f ) ) D K L ( π θ ∣ ∣ π r e f ) = π r e f ( o i ∣ q ) π θ ( o i ∣ q ) − l o g π r e f ( o i ∣ q ) π θ ( o i ∣ q ) − 1 A i = r i − m e a n ( { r 1 , r 2 , … , r G } ) s t d ( { r 1 , r 2 , … , r G } ) \begin{align} J_{GRPO}(\theta) &= \mathbb{E}[q \sim P(Q), \{{o_{i}}\}_{i=1}^{G} \sim \pi_{\theta_{old}}(O|q)] \\ Loss &= \frac{1}{G}\sum_{i=1}^{G}(min((\frac{\pi_{\theta}(o_{i}|q)}{\pi_{\theta_{old}}(o_{i}|q)})A_{i}, clip(\frac{\pi_{\theta}(o_{i}|q)}{\pi_{\theta_{old}}(o_{i}|q)},1-\epsilon,1+\epsilon)A_{i}-\beta \mathbb{D}_{KL}(\pi_{\theta}||\pi_{ref})) \\ \mathbb{D}_{KL}(\pi_{\theta}||\pi_{ref}) &= \frac{\pi_{ref}(o_{i}|q)}{\pi_{\theta}(o_{i}|q)} - log\frac{\pi_{ref}(o_{i}|q)}{\pi_{\theta}(o_{i}|q)} - 1 \\ A_{i} &= \frac{r_{i}-mean(\{r_{1},r_{2},\ldots,r_{G}\})}{std(\{r_{1},r_{2},\ldots,r_{G}\})} \end{align} JGRPO(θ)LossDKL(πθ∣∣πref)Ai=E[q∼P(Q),{oi}i=1G∼πθold(O∣q)]=G1i=1∑G(min((πθold(oi∣q)πθ(oi∣q))Ai,clip(πθold(oi∣q)πθ(oi∣q),1−ϵ,1+ϵ)Ai−βDKL(πθ∣∣πref))=πθ(oi∣q)πref(oi∣q)−logπθ(oi∣q)πref(oi∣q)−1=std({r1,r2,…,rG})ri−mean({r1,r2,…,rG})
GRPO 的 KL 散度,使用蒙特卡洛(Monte-Carlo) 近似计算 KL散度(Kullback-Leibler Divergence),结果始终为正数。
参考源码,TRL - GRPO:
# Advantages
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
# KL 散度
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
# 期望
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
# 联合 loss
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
# mask loss
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
GRPO 的 训练源码:trl/trainer/grpo_trainer.py

PPO 的伪码流程:
policy_model = load_model()
ref_model = policy_model.copy() # 不更新
critic_model = load_reward_model(only_last=False)
reward_model = critic_mode.copy() # 不更新for i in steps:# 1. 采样阶段prompts = sample_prompt()# old_log_probs[i][j](from policy_model), old_values[i][j](from critic_model)responses, old_log_probs, old_values = respond(policy_model, critic_model, prompts)# 2. 反馈阶段scores = reward_model(prompts, responses)# ref_log_probs[i][j](from ref_model)ref_log_probs = analyze_responses(ref_model, prompts, responses) # ref logps# rewards[i][j] = scores[i] - (old_log_probs[i][j] - ref_log_prob[i][j])rewards = reward_func(scores, old_log_probs, ref_log_probs) # 奖励计算# advantages[i][j] = rewards[i][j] - old_values[i][j] advantages = advantage_func(rewards, old_values) # 奖励(r)-价值(v)=优势(a)# 3. 学习阶段for j in ppo_epochs: # 多次更新学习,逐渐靠近奖励log_probs = analyze_responses(policy_model, prompts, responses)values = analyze_responses(critic_model, prompts, responses)# 更新 actor(policy) 模型,学习更新的差异,advantages[i][j]越大,强化动作actor_loss = actor_loss_func(advantages, old_log_probs, log_probs) critic_loss = critic_loss_func(rewards, values) # 更新 critic 模型loss = actor_loss + 0.1 * critic_loss # 更新train(loss, policy_model.parameters(), critic_model.parameters()) # 参数
参考 知乎 - 图解大模型RLHF系列之:人人都能看懂的PPO原理与源码解读
KL 散度的实现,如下:
import torch
import torch.nn.functional as F
# 假设我们有两个概率分布 P 和 Q
P = torch.tensor([0.1, 0.2, 0.7]) # 参考的、真实的
Q = torch.tensor([0.2, 0.3, 0.5]) # 模型生成的
# 计算 Q 的对数概率
log_Q = torch.log(Q)
# 使用 PyTorch 的 kl_div 函数计算 KL 散度
kl_divergence = F.kl_div(log_Q, P, reduction='sum') # 注意先Q后P
print(f"KL Div (PyTorch): {kl_divergence}")
log_P = torch.log(P)
kl_elementwise = P * (log_P - log_Q)
# 对所有元素求和,得到 KL 散度
kl_divergence = torch.sum(kl_elementwise)
参考:
- 知乎 - GRPO: Group Relative Policy Optimization
- GitHub - GRPO Trainer
- Medium - The Math Behind DeepSeek: A Deep Dive into Group Relative Policy Optimization (GRPO)
相关文章:
LLM - 理解 DeepSeek 的 GPRO (分组相对策略优化) 公式与源码 教程(2)
欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/145640762 GPRO,即 Group Relative Policy Optimization,分组相对的策略优化,是 PPO(Proximal Policy Optimiz…...
Github 2025-02-14 Java开源项目日报 Top10
根据Github Trendings的统计,今日(2025-02-14统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Java项目10C#项目1Guava: 谷歌Java核心库 创建周期:3725 天开发语言:Java协议类型:Apache License 2.0Star数量:49867 个Fork数量:10822 次…...
DeepSeek赋能制造业:图表可视化从入门到精通
一、企业数据可视化之困 在数字化浪潮席卷全球的当下,商贸流通企业作为经济活动的关键枢纽,每天都在与海量数据打交道。从商品的采购、库存管理,到销售渠道的拓展、客户关系的维护,各个环节都源源不断地产生数据。这些数据犹如一座蕴含巨大价值的宝藏,然而,如何挖掘并利用…...
Python爬虫技术
Python爬虫技术凭借其高效便捷的特性,已成为数据采集领域的主流工具。以下从技术优势、核心实现、工具框架、反爬策略及注意事项等方面进行系统阐述: 一、Python爬虫的核心优势 语法简洁与开发效率高 Python的语法简洁易读,配合丰富的第三方库…...
C++Primer学习(4.6成员访问运算符)
4.6成员访问运算符 点运算符和箭头运算符都可用于访问成员,其中,点运算符获取类对象的一个成员;箭头运算符与点运算符有关,表达式 ptr->mem等价于(* ptr).mem: string sl"a string",*p &s1; auto ns1.size();//运行string对…...
c++14之std::make_unique
基础介绍 虽然在c11版本std::unique_ptr<T>已经引入,但是在c14版本引入之前,std::unique_ptr<T>的创建还是通过new操作符来完成的。在c14版本已经引入了类似make_shared的std::make_unique,目的是提供更加安全的方法创建std::un…...
服务器linux操作系统安全加固
一、系统更新与补丁管理 更新系统sudo yum update -y # 更新所有软件包 sudo yum install epel-release -y # 安装EPEL扩展源启用自动安全更新sudo yum install yum-cron -y sudo systemctl enable yum-cron sudo systemctl start yum-cron配置 /etc/yum/yum-cron.con…...
原生Three.js 和 Cesium.js 案例 。 智慧城市 数字孪生常用功能列表
对于大多数的开发者来言,看了很多文档可能遇见不到什么有用的,就算有用从文档上看,把代码复制到自己的本地大多数也是不能用的,非常浪费时间和学习成本, 尤其是three.js , cesium.js 这种难度较高ÿ…...
Node.js中Express框架使用指南:从入门到企业级实践
目录 一、Express快速入门 1. 项目初始化 2. 基础服务搭建 3. 添加热更新 二、核心功能详解 1. 路由系统 动态路由参数 路由模块化 2. 中间件机制 自定义中间件 常用官方中间件 3. 模板引擎集成 三、企业级最佳实践 1. 项目结构规范 2. 错误处理方案 3. 安全防护…...
spring 学习 (注解)
目录 前言 常用的注解 须知 1 Conponent注解 demo(案例) 2 ControllerServiceRepository demo(案例) 3 ScopeLazyPostConstructPreDestroy demo(案例) 4 ValueAutowiredQualifierResource demo(案例) 5 Co…...
计算机等级考试——计算机三级——网络技术部分
计算机三级——网络技术部分 一、外部网关协议BGP考点二、IPS入侵防护系统考点三、OSPF协议考点四、弹性分组环——RPR技术 一、外部网关协议BGP考点 高频考点,中考次数:25次 这类知识采用背诵的方式,可以更快速地备考。 BGP是边界网关协议&…...
新版电脑通过wepe安装系统
官方下载链接 WIN10下载 WIN11下载 微PE 启动盘制作 1:选择启动盘的设备 2:选择对应的U盘设备,点击安装就可以,建议大于8g 3:在上方链接下载需要安装的程序包,放入启动盘,按需 更新系统 …...
oracle中decode怎么转换成pg
对于 PostgreSQL 中的 Oracle DECODE 函数,可以使用 CASE 表达式或联合。CASE 表达式根据条件返回第一个匹配的结果,语法为:CASE WHEN 条件 THEN 结果 ELSE 结果 END。联合通过 UNION ALL 操作符组合多个 SELECT 语句,返回一个包含…...
【NLP】循环神经网络RNN
目录 一、词嵌入层 二、循环网络层 2.1 RNN网络原理 2.2 Pytorch RNN API 自然语言处理(Nature language Processing,NLP)研究的主要是通过计算机算法来理解自然语言。对于自然语言来说,处理的数据主要就是人类的语言…...
Linux嵌入式完整镜像烧写到SD卡中的方法(包括对SD卡的介绍)
前言 本篇博文是博文https://blog.csdn.net/wenhao_ir/article/details/145547974 的分支,在本篇博文里我们主要是完成将镜像文件imx-image-full-imx6ull14x14evk-20201209093926.rootfs.wic烧写到SD卡中。 SD卡的介绍 SD卡(Secure Digital卡…...
vscode怎么更新github代码
vscode怎么更新github代码 打开终端: 在 VS Code 中,使用快捷键 Ctrl (Mac 上是 Cmd) 打开终端。 导航到项目目录: 确保你当前所在的终端目录是你的项目目录。如果不是,可以使用 cd 命令导航到项目目录,例如…...
回顾Golang的Channel与Select第二篇
深入掌握Go Channel与Select:从原理到生产级实践 一、Channel基础:不只是数据管道 1.1 通道的完整生命周期(可运行示例) package mainimport ("fmt""time" )func main() {// 创建缓冲通道ch : make(chan i…...
基于mediapipe深度学习的手势数字识别系统python源码
目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 (完整程序运行后无水印) 2.算法运行软件版本 程序运行配置环境: 人工智能算法python程序运行环境安装步骤整理-CSDN博客 3.部分核心…...
JS实现大文件切片上传以及断点续传
切片上传的原理是: 1.因为file对象的基类是blob,所以可以使用slice分割 2.将从input中获取的file对象使用slice进行分割,每5M一片 3.分别上传各个切片,等待切片上传完通知服务端合并(或者传每一片时把切片总数量也传…...
AI编程01-生成前/后端接口对表-豆包(或Deepseek+WPS的AI
前言: 做过全栈的工程师知道,如果一个APP的项目分别是前端/后端两个团队开发的话,那么原型设计之后,通过接口文档进行开发对接是非常必要的。 传统的方法是,大家一起定义一个接口文档,然后,前端和后端的工程师进行为何,现在AI的时代,是不是通过AI能协助呢,显然可以…...
1.6T光模块将成AI数据中心主流
2026年光通信模组的发展核心驱动力来自AI算力集群对超高带宽和极致能效的迫切需求,其技术演进呈现出高速率、高集成、低功耗、新架构的鲜明特征。光互连技术正从传统可插拔形态向更紧密的共封装光学(CPO) 和线性驱动可插拔光学(LP…...
HsMod炉石传说插件终极指南:55项功能完全解锁
HsMod炉石传说插件终极指南:55项功能完全解锁 【免费下载链接】HsMod Hearthstone Modification Based on BepInEx 项目地址: https://gitcode.com/GitHub_Trending/hs/HsMod HsMod是基于BepInEx框架开发的炉石传说多功能增强插件,为玩家提供游戏…...
FFmpeg GUI:3分钟搞定音视频处理,告别复杂命令行的图形化神器
FFmpeg GUI:3分钟搞定音视频处理,告别复杂命令行的图形化神器 【免费下载链接】ffmpegGUI ffmpeg GUI 项目地址: https://gitcode.com/gh_mirrors/ff/ffmpegGUI 还在为FFmpeg复杂的命令行参数而头疼吗?FFmpeg GUI为你带来了革命性的解…...
如何快速下载高解析无损音乐:Qobuz-DL完整指南
如何快速下载高解析无损音乐:Qobuz-DL完整指南 【免费下载链接】qobuz-dl A complete Lossless and Hi-Res music downloader for Qobuz 项目地址: https://gitcode.com/gh_mirrors/qo/qobuz-dl 想要构建个人高品质音乐库吗?Qobuz-DL这款开源音频…...
从人眼到算法:TV Line分辨率检测的实践与演进
1. TV Line检测技术的本质与演进 第一次接触TV Line检测是在2013年,当时我负责一款行车记录仪的摄像头模组验收。供应商提供的测试报告显示"分辨率达到1000线",但实际拍摄效果却模糊不清。这个矛盾让我开始深入研究TV Line检测的本质。 TV L…...
Python-TGPT:统一接口调用多AI模型,快速构建智能应用
1. 项目概述:一个让Python与AI对话的轻量级桥梁 如果你正在寻找一个能让你在Python脚本里快速、简单地调用主流大语言模型(LLM)API的工具,而不是被各种官方SDK复杂的初始化、错误处理和流式输出搞得头大,那么 Simatw…...
SITS大会技术社区交流活动实战指南(附2024最新人脉激活话术库)
更多请点击: https://intelliparadigm.com 第一章:SITS大会技术社区交流活动 SITS(Software Innovation & Technology Summit)大会作为国内聚焦开源协作与工程实践的年度技术盛会,其技术社区交流活动以“共建、共…...
Memgentic:基于遗传算法的智能内存管理优化实践
1. 项目概述:Memgentic是什么,以及它为何值得关注最近在开源社区里,一个名为“Memgentic”的项目引起了我的注意。这个项目由开发者Chariton-kyp创建,名字本身就很有意思,是“Memory”(记忆)和“…...
别再死记硬背了!用Python+NumPy手把手带你理解LTI系统的零极点与频率响应
用PythonNumPy实战解析LTI系统的零极点与频率响应 数字信号处理的理论常常让初学者感到抽象难懂,尤其是当教科书堆满数学公式时。但如果我们换一种方式——用代码和可视化来探索这些概念,一切突然变得清晰起来。本文将带你用Python和NumPy库,…...
告别转换失败!深度解析Allegro PCB导入PADS报错的5个常见原因及解决方法
Allegro转PADS报错全攻略:从底层原理到精准排错 最近在开源硬件社区看到一个典型案例:某团队将Allegro设计的六层工业控制板导入PADS时,反复出现"Allegro未做好迁移准备"的报错,导致项目延期两周。这让我想起五年前第一…...
