LLM - 理解 DeepSeek 的 GPRO (分组相对策略优化) 公式与源码 教程(2)
欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/145640762
GPRO,即 Group Relative Policy Optimization,分组相对的策略优化,是 PPO(Proximal Policy Optimization, 近端策略优化) 的优化版本,省略优化 评论家模型(Critic Model),用于估计价值(Value Function Model),降低模型训练的资源消耗。
GRPO 目标的工作原理如下:
- 为查询生成一组响应。
- 根据预定义的标准(例如准确性、格式),计算每个响应的奖励。
- 比较组内的反应以计算他们的相对优势。
- 更新策略以支持具有更高优势的响应,剪裁(clip)确保的稳定性。
- 规范更新以防止模型偏离基线太远。
GRPO 有效的原因:
- 无需评论:GRPO 依靠群体比较,避免对于单独评估者的需求,从而降低了计算成本。
- 稳定学习:剪裁(clip) 和 KL 正则化确保模型稳步改进,不会出现剧烈波动。
- 高效训练:通过关注相对性能,GRPO 非常适合推理等绝对评分困难的任务。
在 DeepSeekMath (2024.4) 中,使用 GPRO 代替 PPO。
回顾一下 PPO 模型的公式与框架,PPO 是先训练 奖励模型(RM),通过强化学习策略,将奖励模型的能力,学习到大语言模型中,同时,注意模型的输出符合之前的预期,不要偏离过远(KL Divergence)。即:
- RM(Reward Model, 奖励模型): m a x r ϕ { E ( x , y w i n , y l o s s ) ∼ D [ l o g σ ( r ϕ ( x , y w i n ) − r ϕ ( x , y l o s s ) ) ] } \underset{r_{\phi}}{max} \{ {E_{(x,y_{win},y_{loss}) \sim D}}[log \ \sigma(r_{\phi}(x,y_{win}) - r_{\phi}(x,y_{loss}))] \} rϕmax{E(x,ywin,yloss)∼D[log σ(rϕ(x,ywin)−rϕ(x,yloss))]}
- PPO(Proximal Policy Optimization, 近端策略优化): m a x π θ { E x ∼ D , y ∼ π θ ( y ∣ x ) [ r ϕ ( x , y ) ] − β D K L [ π θ ( y ∣ x ) ∣ ∣ π r e f ( y ∣ x ) ] } \underset{\pi_{\theta}}{max} \{ E_{x \sim D,y \sim \pi_{\theta}(y|x)}[r_{\phi}(x,y)] - \beta D_{KL}[\pi_{\theta}(y|x) || \pi_{ref}(y|x)] \} πθmax{Ex∼D,y∼πθ(y∣x)[rϕ(x,y)]−βDKL[πθ(y∣x)∣∣πref(y∣x)]}
- KL 散度(KL Divergence):
D K L [ π θ ( y ∣ x ) ∣ ∣ π r e f ( y ∣ x ) ] = π r e f ( y ∣ x ) l o g π r e f ( y ∣ x ) π θ ( y ∣ x ) = π r e f ( y ∣ x ) ( l o g π r e f ( y ∣ x ) − l o g π θ ( y ∣ x ) ) \begin{align} D_{KL}[\pi_{\theta}(y|x) || \pi_{ref}(y|x)] &= \pi_{ref}(y|x) \ log{\frac{\pi_{ref}(y|x)}{\pi_{\theta}(y|x)}} \\ &= \pi_{ref}(y|x) (log \pi_{ref}(y|x) - log\pi_{\theta}(y|x)) \end{align} DKL[πθ(y∣x)∣∣πref(y∣x)]=πref(y∣x) logπθ(y∣x)πref(y∣x)=πref(y∣x)(logπref(y∣x)−logπθ(y∣x))
其中,Actor 和 Critic 损失函数如下:
a [ i , j ] = r e t u r n s [ i , j ] − v a l u e s [ i , j ] L o s s a c t o r = − 1 M N ∑ i = 1 M ∑ j = 1 N a [ i , j ] × e x p ( l o g _ p r o b [ i , j ] − o l d _ l o g _ p r o b [ i , j ] ) L o s s c r i t i c = 1 2 M N ∑ i = 1 M ∑ j = 1 N ( v a l u e s [ i , j ] − r e t u r n s [ i , j ] ) 2 L o s s = L o s s a c t o r + 0.1 ∗ L o s s c r i t i c \begin{align} a[i,j] &= returns[i,j] - values[i,j] \\ Loss_{actor} &= -\frac{1}{MN} \sum_{i=1}^{M} \sum_{j=1}^{N} a[i,j] \times exp(log\_prob[i,j]-old\_log\_prob[i,j]) \\ Loss_{critic} &= \frac{1}{2MN} \sum_{i=1}^{M} \sum_{j=1}^{N} (values[i,j] - returns[i,j])^{2} \\ Loss & = Loss_{actor} + 0.1*Loss_{critic} \end{align} a[i,j]LossactorLosscriticLoss=returns[i,j]−values[i,j]=−MN1i=1∑Mj=1∑Na[i,j]×exp(log_prob[i,j]−old_log_prob[i,j])=2MN1i=1∑Mj=1∑N(values[i,j]−returns[i,j])2=Lossactor+0.1∗Losscritic
PPO 的奖励(Reward 计算),一般而言,超参数 β = 0.1 \beta=0.1 β=0.1:
r t = r ψ ( q , o ≤ t ) − β l o g ( π θ ( o t ∣ q , o < t ) π r e f ( o t ∣ q , o < t ) ) r_{t} = r_{\psi}(q,o_{\leq t}) - \beta log(\frac{\pi_{\theta}(o_{t}|q,o_{<t})}{\pi_{ref}(o_{t}|q,o_{<t})}) rt=rψ(q,o≤t)−βlog(πref(ot∣q,o<t)πθ(ot∣q,o<t))
在 PPO 中使用的价值函数(Critic Model),通常与策略模型(Policy Model)大小相当,带来内存和计算负担。在强化学习训练中,价值函数作为基线,以减少优势函数计算中的方差。然而,在 大语言模型(LLM) 的场景中,只有最后一个 Token 被奖励模型赋予奖励分数,使训练一个在每个标记处都准确的价值函数,变得复杂。GRPO 无需像 PPO 那样,使用额外的近似价值函数,而是使用同一问题产生的多个采样输出的平均奖励,作为基线。
GRPO 使用基于 组相对(Group Relative) 的优势计算方式,与奖励模型比较特性一致,因为奖励模型通常是在同一问题上不同输出之间的比较数据集上进行训练的。同时,GRPO 没有在 奖励(Reward) 中加入 KL 惩罚,而是直接将训练策略与参考策略之间的KL散度添加到损失函数中,从而避免了在计算优势时增加复杂性。
GPRO 的公式, Q Q Q 表示 Query,即输入的问题,采样出问题 q q q,推理大模型,输出 G G G 个输出 o i o_{i} oi:
J G R P O ( θ ) = E [ q ∼ P ( Q ) , { o i } i = 1 G ∼ π θ o l d ( O ∣ q ) ] L o s s = 1 G ∑ i = 1 G ( m i n ( ( π θ ( o i ∣ q ) π θ o l d ( o i ∣ q ) ) A i , c l i p ( π θ ( o i ∣ q ) π θ o l d ( o i ∣ q ) , 1 − ϵ , 1 + ϵ ) A i − β D K L ( π θ ∣ ∣ π r e f ) ) D K L ( π θ ∣ ∣ π r e f ) = π r e f ( o i ∣ q ) π θ ( o i ∣ q ) − l o g π r e f ( o i ∣ q ) π θ ( o i ∣ q ) − 1 A i = r i − m e a n ( { r 1 , r 2 , … , r G } ) s t d ( { r 1 , r 2 , … , r G } ) \begin{align} J_{GRPO}(\theta) &= \mathbb{E}[q \sim P(Q), \{{o_{i}}\}_{i=1}^{G} \sim \pi_{\theta_{old}}(O|q)] \\ Loss &= \frac{1}{G}\sum_{i=1}^{G}(min((\frac{\pi_{\theta}(o_{i}|q)}{\pi_{\theta_{old}}(o_{i}|q)})A_{i}, clip(\frac{\pi_{\theta}(o_{i}|q)}{\pi_{\theta_{old}}(o_{i}|q)},1-\epsilon,1+\epsilon)A_{i}-\beta \mathbb{D}_{KL}(\pi_{\theta}||\pi_{ref})) \\ \mathbb{D}_{KL}(\pi_{\theta}||\pi_{ref}) &= \frac{\pi_{ref}(o_{i}|q)}{\pi_{\theta}(o_{i}|q)} - log\frac{\pi_{ref}(o_{i}|q)}{\pi_{\theta}(o_{i}|q)} - 1 \\ A_{i} &= \frac{r_{i}-mean(\{r_{1},r_{2},\ldots,r_{G}\})}{std(\{r_{1},r_{2},\ldots,r_{G}\})} \end{align} JGRPO(θ)LossDKL(πθ∣∣πref)Ai=E[q∼P(Q),{oi}i=1G∼πθold(O∣q)]=G1i=1∑G(min((πθold(oi∣q)πθ(oi∣q))Ai,clip(πθold(oi∣q)πθ(oi∣q),1−ϵ,1+ϵ)Ai−βDKL(πθ∣∣πref))=πθ(oi∣q)πref(oi∣q)−logπθ(oi∣q)πref(oi∣q)−1=std({r1,r2,…,rG})ri−mean({r1,r2,…,rG})
GRPO 的 KL 散度,使用蒙特卡洛(Monte-Carlo) 近似计算 KL散度(Kullback-Leibler Divergence),结果始终为正数。
参考源码,TRL - GRPO:
# Advantages
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
# KL 散度
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
# 期望
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
# 联合 loss
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
# mask loss
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
GRPO 的 训练源码:trl/trainer/grpo_trainer.py
PPO 的伪码流程:
policy_model = load_model()
ref_model = policy_model.copy() # 不更新
critic_model = load_reward_model(only_last=False)
reward_model = critic_mode.copy() # 不更新for i in steps:# 1. 采样阶段prompts = sample_prompt()# old_log_probs[i][j](from policy_model), old_values[i][j](from critic_model)responses, old_log_probs, old_values = respond(policy_model, critic_model, prompts)# 2. 反馈阶段scores = reward_model(prompts, responses)# ref_log_probs[i][j](from ref_model)ref_log_probs = analyze_responses(ref_model, prompts, responses) # ref logps# rewards[i][j] = scores[i] - (old_log_probs[i][j] - ref_log_prob[i][j])rewards = reward_func(scores, old_log_probs, ref_log_probs) # 奖励计算# advantages[i][j] = rewards[i][j] - old_values[i][j] advantages = advantage_func(rewards, old_values) # 奖励(r)-价值(v)=优势(a)# 3. 学习阶段for j in ppo_epochs: # 多次更新学习,逐渐靠近奖励log_probs = analyze_responses(policy_model, prompts, responses)values = analyze_responses(critic_model, prompts, responses)# 更新 actor(policy) 模型,学习更新的差异,advantages[i][j]越大,强化动作actor_loss = actor_loss_func(advantages, old_log_probs, log_probs) critic_loss = critic_loss_func(rewards, values) # 更新 critic 模型loss = actor_loss + 0.1 * critic_loss # 更新train(loss, policy_model.parameters(), critic_model.parameters()) # 参数
参考 知乎 - 图解大模型RLHF系列之:人人都能看懂的PPO原理与源码解读
KL 散度的实现,如下:
import torch
import torch.nn.functional as F
# 假设我们有两个概率分布 P 和 Q
P = torch.tensor([0.1, 0.2, 0.7]) # 参考的、真实的
Q = torch.tensor([0.2, 0.3, 0.5]) # 模型生成的
# 计算 Q 的对数概率
log_Q = torch.log(Q)
# 使用 PyTorch 的 kl_div 函数计算 KL 散度
kl_divergence = F.kl_div(log_Q, P, reduction='sum') # 注意先Q后P
print(f"KL Div (PyTorch): {kl_divergence}")
log_P = torch.log(P)
kl_elementwise = P * (log_P - log_Q)
# 对所有元素求和,得到 KL 散度
kl_divergence = torch.sum(kl_elementwise)
参考:
- 知乎 - GRPO: Group Relative Policy Optimization
- GitHub - GRPO Trainer
- Medium - The Math Behind DeepSeek: A Deep Dive into Group Relative Policy Optimization (GRPO)
相关文章:

LLM - 理解 DeepSeek 的 GPRO (分组相对策略优化) 公式与源码 教程(2)
欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/145640762 GPRO,即 Group Relative Policy Optimization,分组相对的策略优化,是 PPO(Proximal Policy Optimiz…...
Github 2025-02-14 Java开源项目日报 Top10
根据Github Trendings的统计,今日(2025-02-14统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Java项目10C#项目1Guava: 谷歌Java核心库 创建周期:3725 天开发语言:Java协议类型:Apache License 2.0Star数量:49867 个Fork数量:10822 次…...
DeepSeek赋能制造业:图表可视化从入门到精通
一、企业数据可视化之困 在数字化浪潮席卷全球的当下,商贸流通企业作为经济活动的关键枢纽,每天都在与海量数据打交道。从商品的采购、库存管理,到销售渠道的拓展、客户关系的维护,各个环节都源源不断地产生数据。这些数据犹如一座蕴含巨大价值的宝藏,然而,如何挖掘并利用…...
Python爬虫技术
Python爬虫技术凭借其高效便捷的特性,已成为数据采集领域的主流工具。以下从技术优势、核心实现、工具框架、反爬策略及注意事项等方面进行系统阐述: 一、Python爬虫的核心优势 语法简洁与开发效率高 Python的语法简洁易读,配合丰富的第三方库…...
C++Primer学习(4.6成员访问运算符)
4.6成员访问运算符 点运算符和箭头运算符都可用于访问成员,其中,点运算符获取类对象的一个成员;箭头运算符与点运算符有关,表达式 ptr->mem等价于(* ptr).mem: string sl"a string",*p &s1; auto ns1.size();//运行string对…...
c++14之std::make_unique
基础介绍 虽然在c11版本std::unique_ptr<T>已经引入,但是在c14版本引入之前,std::unique_ptr<T>的创建还是通过new操作符来完成的。在c14版本已经引入了类似make_shared的std::make_unique,目的是提供更加安全的方法创建std::un…...
服务器linux操作系统安全加固
一、系统更新与补丁管理 更新系统sudo yum update -y # 更新所有软件包 sudo yum install epel-release -y # 安装EPEL扩展源启用自动安全更新sudo yum install yum-cron -y sudo systemctl enable yum-cron sudo systemctl start yum-cron配置 /etc/yum/yum-cron.con…...

原生Three.js 和 Cesium.js 案例 。 智慧城市 数字孪生常用功能列表
对于大多数的开发者来言,看了很多文档可能遇见不到什么有用的,就算有用从文档上看,把代码复制到自己的本地大多数也是不能用的,非常浪费时间和学习成本, 尤其是three.js , cesium.js 这种难度较高ÿ…...
Node.js中Express框架使用指南:从入门到企业级实践
目录 一、Express快速入门 1. 项目初始化 2. 基础服务搭建 3. 添加热更新 二、核心功能详解 1. 路由系统 动态路由参数 路由模块化 2. 中间件机制 自定义中间件 常用官方中间件 3. 模板引擎集成 三、企业级最佳实践 1. 项目结构规范 2. 错误处理方案 3. 安全防护…...

spring 学习 (注解)
目录 前言 常用的注解 须知 1 Conponent注解 demo(案例) 2 ControllerServiceRepository demo(案例) 3 ScopeLazyPostConstructPreDestroy demo(案例) 4 ValueAutowiredQualifierResource demo(案例) 5 Co…...
计算机等级考试——计算机三级——网络技术部分
计算机三级——网络技术部分 一、外部网关协议BGP考点二、IPS入侵防护系统考点三、OSPF协议考点四、弹性分组环——RPR技术 一、外部网关协议BGP考点 高频考点,中考次数:25次 这类知识采用背诵的方式,可以更快速地备考。 BGP是边界网关协议&…...

新版电脑通过wepe安装系统
官方下载链接 WIN10下载 WIN11下载 微PE 启动盘制作 1:选择启动盘的设备 2:选择对应的U盘设备,点击安装就可以,建议大于8g 3:在上方链接下载需要安装的程序包,放入启动盘,按需 更新系统 …...
oracle中decode怎么转换成pg
对于 PostgreSQL 中的 Oracle DECODE 函数,可以使用 CASE 表达式或联合。CASE 表达式根据条件返回第一个匹配的结果,语法为:CASE WHEN 条件 THEN 结果 ELSE 结果 END。联合通过 UNION ALL 操作符组合多个 SELECT 语句,返回一个包含…...
【NLP】循环神经网络RNN
目录 一、词嵌入层 二、循环网络层 2.1 RNN网络原理 2.2 Pytorch RNN API 自然语言处理(Nature language Processing,NLP)研究的主要是通过计算机算法来理解自然语言。对于自然语言来说,处理的数据主要就是人类的语言…...

Linux嵌入式完整镜像烧写到SD卡中的方法(包括对SD卡的介绍)
前言 本篇博文是博文https://blog.csdn.net/wenhao_ir/article/details/145547974 的分支,在本篇博文里我们主要是完成将镜像文件imx-image-full-imx6ull14x14evk-20201209093926.rootfs.wic烧写到SD卡中。 SD卡的介绍 SD卡(Secure Digital卡…...
vscode怎么更新github代码
vscode怎么更新github代码 打开终端: 在 VS Code 中,使用快捷键 Ctrl (Mac 上是 Cmd) 打开终端。 导航到项目目录: 确保你当前所在的终端目录是你的项目目录。如果不是,可以使用 cd 命令导航到项目目录,例如…...
回顾Golang的Channel与Select第二篇
深入掌握Go Channel与Select:从原理到生产级实践 一、Channel基础:不只是数据管道 1.1 通道的完整生命周期(可运行示例) package mainimport ("fmt""time" )func main() {// 创建缓冲通道ch : make(chan i…...

基于mediapipe深度学习的手势数字识别系统python源码
目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 (完整程序运行后无水印) 2.算法运行软件版本 程序运行配置环境: 人工智能算法python程序运行环境安装步骤整理-CSDN博客 3.部分核心…...

JS实现大文件切片上传以及断点续传
切片上传的原理是: 1.因为file对象的基类是blob,所以可以使用slice分割 2.将从input中获取的file对象使用slice进行分割,每5M一片 3.分别上传各个切片,等待切片上传完通知服务端合并(或者传每一片时把切片总数量也传…...

AI编程01-生成前/后端接口对表-豆包(或Deepseek+WPS的AI
前言: 做过全栈的工程师知道,如果一个APP的项目分别是前端/后端两个团队开发的话,那么原型设计之后,通过接口文档进行开发对接是非常必要的。 传统的方法是,大家一起定义一个接口文档,然后,前端和后端的工程师进行为何,现在AI的时代,是不是通过AI能协助呢,显然可以…...

深入剖析AI大模型:大模型时代的 Prompt 工程全解析
今天聊的内容,我认为是AI开发里面非常重要的内容。它在AI开发里无处不在,当你对 AI 助手说 "用李白的风格写一首关于人工智能的诗",或者让翻译模型 "将这段合同翻译成商务日语" 时,输入的这句话就是 Prompt。…...
java 实现excel文件转pdf | 无水印 | 无限制
文章目录 目录 文章目录 前言 1.项目远程仓库配置 2.pom文件引入相关依赖 3.代码破解 二、Excel转PDF 1.代码实现 2.Aspose.License.xml 授权文件 总结 前言 java处理excel转pdf一直没找到什么好用的免费jar包工具,自己手写的难度,恐怕高级程序员花费一年的事件,也…...
解锁数据库简洁之道:FastAPI与SQLModel实战指南
在构建现代Web应用程序时,与数据库的交互无疑是核心环节。虽然传统的数据库操作方式(如直接编写SQL语句与psycopg2交互)赋予了我们精细的控制权,但在面对日益复杂的业务逻辑和快速迭代的需求时,这种方式的开发效率和可…...

CentOS下的分布式内存计算Spark环境部署
一、Spark 核心架构与应用场景 1.1 分布式计算引擎的核心优势 Spark 是基于内存的分布式计算框架,相比 MapReduce 具有以下核心优势: 内存计算:数据可常驻内存,迭代计算性能提升 10-100 倍(文档段落:3-79…...
在四层代理中还原真实客户端ngx_stream_realip_module
一、模块原理与价值 PROXY Protocol 回溯 第三方负载均衡(如 HAProxy、AWS NLB、阿里 SLB)发起上游连接时,将真实客户端 IP/Port 写入 PROXY Protocol v1/v2 头。Stream 层接收到头部后,ngx_stream_realip_module 从中提取原始信息…...

CocosCreator 之 JavaScript/TypeScript和Java的相互交互
引擎版本: 3.8.1 语言: JavaScript/TypeScript、C、Java 环境:Window 参考:Java原生反射机制 您好,我是鹤九日! 回顾 在上篇文章中:CocosCreator Android项目接入UnityAds 广告SDK。 我们简单讲…...
WEB3全栈开发——面试专业技能点P2智能合约开发(Solidity)
一、Solidity合约开发 下面是 Solidity 合约开发 的概念、代码示例及讲解,适合用作学习或写简历项目背景说明。 🧠 一、概念简介:Solidity 合约开发 Solidity 是一种专门为 以太坊(Ethereum)平台编写智能合约的高级编…...

RabbitMQ入门4.1.0版本(基于java、SpringBoot操作)
RabbitMQ 一、RabbitMQ概述 RabbitMQ RabbitMQ最初由LShift和CohesiveFT于2007年开发,后来由Pivotal Software Inc.(现为VMware子公司)接管。RabbitMQ 是一个开源的消息代理和队列服务器,用 Erlang 语言编写。广泛应用于各种分布…...

在Mathematica中实现Newton-Raphson迭代的收敛时间算法(一般三次多项式)
考察一般的三次多项式,以r为参数: p[z_, r_] : z^3 (r - 1) z - r; roots[r_] : z /. Solve[p[z, r] 0, z]; 此多项式的根为: 尽管看起来这个多项式是特殊的,其实一般的三次多项式都是可以通过线性变换化为这个形式…...
【Elasticsearch】Elasticsearch 在大数据生态圈的地位 实践经验
Elasticsearch 在大数据生态圈的地位 & 实践经验 1.Elasticsearch 的优势1.1 Elasticsearch 解决的核心问题1.1.1 传统方案的短板1.1.2 Elasticsearch 的解决方案 1.2 与大数据组件的对比优势1.3 关键优势技术支撑1.4 Elasticsearch 的竞品1.4.1 全文搜索领域1.4.2 日志分析…...