LLM - 理解 DeepSeek 的 GPRO (分组相对策略优化) 公式与源码 教程(2)
欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/145640762
GPRO,即 Group Relative Policy Optimization,分组相对的策略优化,是 PPO(Proximal Policy Optimization, 近端策略优化) 的优化版本,省略优化 评论家模型(Critic Model),用于估计价值(Value Function Model),降低模型训练的资源消耗。
GRPO 目标的工作原理如下:
- 为查询生成一组响应。
- 根据预定义的标准(例如准确性、格式),计算每个响应的奖励。
- 比较组内的反应以计算他们的相对优势。
- 更新策略以支持具有更高优势的响应,剪裁(clip)确保的稳定性。
- 规范更新以防止模型偏离基线太远。
GRPO 有效的原因:
- 无需评论:GRPO 依靠群体比较,避免对于单独评估者的需求,从而降低了计算成本。
- 稳定学习:剪裁(clip) 和 KL 正则化确保模型稳步改进,不会出现剧烈波动。
- 高效训练:通过关注相对性能,GRPO 非常适合推理等绝对评分困难的任务。
在 DeepSeekMath (2024.4) 中,使用 GPRO 代替 PPO。

回顾一下 PPO 模型的公式与框架,PPO 是先训练 奖励模型(RM),通过强化学习策略,将奖励模型的能力,学习到大语言模型中,同时,注意模型的输出符合之前的预期,不要偏离过远(KL Divergence)。即:
- RM(Reward Model, 奖励模型): m a x r ϕ { E ( x , y w i n , y l o s s ) ∼ D [ l o g σ ( r ϕ ( x , y w i n ) − r ϕ ( x , y l o s s ) ) ] } \underset{r_{\phi}}{max} \{ {E_{(x,y_{win},y_{loss}) \sim D}}[log \ \sigma(r_{\phi}(x,y_{win}) - r_{\phi}(x,y_{loss}))] \} rϕmax{E(x,ywin,yloss)∼D[log σ(rϕ(x,ywin)−rϕ(x,yloss))]}
- PPO(Proximal Policy Optimization, 近端策略优化): m a x π θ { E x ∼ D , y ∼ π θ ( y ∣ x ) [ r ϕ ( x , y ) ] − β D K L [ π θ ( y ∣ x ) ∣ ∣ π r e f ( y ∣ x ) ] } \underset{\pi_{\theta}}{max} \{ E_{x \sim D,y \sim \pi_{\theta}(y|x)}[r_{\phi}(x,y)] - \beta D_{KL}[\pi_{\theta}(y|x) || \pi_{ref}(y|x)] \} πθmax{Ex∼D,y∼πθ(y∣x)[rϕ(x,y)]−βDKL[πθ(y∣x)∣∣πref(y∣x)]}
- KL 散度(KL Divergence):
D K L [ π θ ( y ∣ x ) ∣ ∣ π r e f ( y ∣ x ) ] = π r e f ( y ∣ x ) l o g π r e f ( y ∣ x ) π θ ( y ∣ x ) = π r e f ( y ∣ x ) ( l o g π r e f ( y ∣ x ) − l o g π θ ( y ∣ x ) ) \begin{align} D_{KL}[\pi_{\theta}(y|x) || \pi_{ref}(y|x)] &= \pi_{ref}(y|x) \ log{\frac{\pi_{ref}(y|x)}{\pi_{\theta}(y|x)}} \\ &= \pi_{ref}(y|x) (log \pi_{ref}(y|x) - log\pi_{\theta}(y|x)) \end{align} DKL[πθ(y∣x)∣∣πref(y∣x)]=πref(y∣x) logπθ(y∣x)πref(y∣x)=πref(y∣x)(logπref(y∣x)−logπθ(y∣x))
其中,Actor 和 Critic 损失函数如下:
a [ i , j ] = r e t u r n s [ i , j ] − v a l u e s [ i , j ] L o s s a c t o r = − 1 M N ∑ i = 1 M ∑ j = 1 N a [ i , j ] × e x p ( l o g _ p r o b [ i , j ] − o l d _ l o g _ p r o b [ i , j ] ) L o s s c r i t i c = 1 2 M N ∑ i = 1 M ∑ j = 1 N ( v a l u e s [ i , j ] − r e t u r n s [ i , j ] ) 2 L o s s = L o s s a c t o r + 0.1 ∗ L o s s c r i t i c \begin{align} a[i,j] &= returns[i,j] - values[i,j] \\ Loss_{actor} &= -\frac{1}{MN} \sum_{i=1}^{M} \sum_{j=1}^{N} a[i,j] \times exp(log\_prob[i,j]-old\_log\_prob[i,j]) \\ Loss_{critic} &= \frac{1}{2MN} \sum_{i=1}^{M} \sum_{j=1}^{N} (values[i,j] - returns[i,j])^{2} \\ Loss & = Loss_{actor} + 0.1*Loss_{critic} \end{align} a[i,j]LossactorLosscriticLoss=returns[i,j]−values[i,j]=−MN1i=1∑Mj=1∑Na[i,j]×exp(log_prob[i,j]−old_log_prob[i,j])=2MN1i=1∑Mj=1∑N(values[i,j]−returns[i,j])2=Lossactor+0.1∗Losscritic
PPO 的奖励(Reward 计算),一般而言,超参数 β = 0.1 \beta=0.1 β=0.1:
r t = r ψ ( q , o ≤ t ) − β l o g ( π θ ( o t ∣ q , o < t ) π r e f ( o t ∣ q , o < t ) ) r_{t} = r_{\psi}(q,o_{\leq t}) - \beta log(\frac{\pi_{\theta}(o_{t}|q,o_{<t})}{\pi_{ref}(o_{t}|q,o_{<t})}) rt=rψ(q,o≤t)−βlog(πref(ot∣q,o<t)πθ(ot∣q,o<t))
在 PPO 中使用的价值函数(Critic Model),通常与策略模型(Policy Model)大小相当,带来内存和计算负担。在强化学习训练中,价值函数作为基线,以减少优势函数计算中的方差。然而,在 大语言模型(LLM) 的场景中,只有最后一个 Token 被奖励模型赋予奖励分数,使训练一个在每个标记处都准确的价值函数,变得复杂。GRPO 无需像 PPO 那样,使用额外的近似价值函数,而是使用同一问题产生的多个采样输出的平均奖励,作为基线。
GRPO 使用基于 组相对(Group Relative) 的优势计算方式,与奖励模型比较特性一致,因为奖励模型通常是在同一问题上不同输出之间的比较数据集上进行训练的。同时,GRPO 没有在 奖励(Reward) 中加入 KL 惩罚,而是直接将训练策略与参考策略之间的KL散度添加到损失函数中,从而避免了在计算优势时增加复杂性。
GPRO 的公式, Q Q Q 表示 Query,即输入的问题,采样出问题 q q q,推理大模型,输出 G G G 个输出 o i o_{i} oi:
J G R P O ( θ ) = E [ q ∼ P ( Q ) , { o i } i = 1 G ∼ π θ o l d ( O ∣ q ) ] L o s s = 1 G ∑ i = 1 G ( m i n ( ( π θ ( o i ∣ q ) π θ o l d ( o i ∣ q ) ) A i , c l i p ( π θ ( o i ∣ q ) π θ o l d ( o i ∣ q ) , 1 − ϵ , 1 + ϵ ) A i − β D K L ( π θ ∣ ∣ π r e f ) ) D K L ( π θ ∣ ∣ π r e f ) = π r e f ( o i ∣ q ) π θ ( o i ∣ q ) − l o g π r e f ( o i ∣ q ) π θ ( o i ∣ q ) − 1 A i = r i − m e a n ( { r 1 , r 2 , … , r G } ) s t d ( { r 1 , r 2 , … , r G } ) \begin{align} J_{GRPO}(\theta) &= \mathbb{E}[q \sim P(Q), \{{o_{i}}\}_{i=1}^{G} \sim \pi_{\theta_{old}}(O|q)] \\ Loss &= \frac{1}{G}\sum_{i=1}^{G}(min((\frac{\pi_{\theta}(o_{i}|q)}{\pi_{\theta_{old}}(o_{i}|q)})A_{i}, clip(\frac{\pi_{\theta}(o_{i}|q)}{\pi_{\theta_{old}}(o_{i}|q)},1-\epsilon,1+\epsilon)A_{i}-\beta \mathbb{D}_{KL}(\pi_{\theta}||\pi_{ref})) \\ \mathbb{D}_{KL}(\pi_{\theta}||\pi_{ref}) &= \frac{\pi_{ref}(o_{i}|q)}{\pi_{\theta}(o_{i}|q)} - log\frac{\pi_{ref}(o_{i}|q)}{\pi_{\theta}(o_{i}|q)} - 1 \\ A_{i} &= \frac{r_{i}-mean(\{r_{1},r_{2},\ldots,r_{G}\})}{std(\{r_{1},r_{2},\ldots,r_{G}\})} \end{align} JGRPO(θ)LossDKL(πθ∣∣πref)Ai=E[q∼P(Q),{oi}i=1G∼πθold(O∣q)]=G1i=1∑G(min((πθold(oi∣q)πθ(oi∣q))Ai,clip(πθold(oi∣q)πθ(oi∣q),1−ϵ,1+ϵ)Ai−βDKL(πθ∣∣πref))=πθ(oi∣q)πref(oi∣q)−logπθ(oi∣q)πref(oi∣q)−1=std({r1,r2,…,rG})ri−mean({r1,r2,…,rG})
GRPO 的 KL 散度,使用蒙特卡洛(Monte-Carlo) 近似计算 KL散度(Kullback-Leibler Divergence),结果始终为正数。
参考源码,TRL - GRPO:
# Advantages
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
# KL 散度
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
# 期望
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
# 联合 loss
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
# mask loss
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
GRPO 的 训练源码:trl/trainer/grpo_trainer.py

PPO 的伪码流程:
policy_model = load_model()
ref_model = policy_model.copy() # 不更新
critic_model = load_reward_model(only_last=False)
reward_model = critic_mode.copy() # 不更新for i in steps:# 1. 采样阶段prompts = sample_prompt()# old_log_probs[i][j](from policy_model), old_values[i][j](from critic_model)responses, old_log_probs, old_values = respond(policy_model, critic_model, prompts)# 2. 反馈阶段scores = reward_model(prompts, responses)# ref_log_probs[i][j](from ref_model)ref_log_probs = analyze_responses(ref_model, prompts, responses) # ref logps# rewards[i][j] = scores[i] - (old_log_probs[i][j] - ref_log_prob[i][j])rewards = reward_func(scores, old_log_probs, ref_log_probs) # 奖励计算# advantages[i][j] = rewards[i][j] - old_values[i][j] advantages = advantage_func(rewards, old_values) # 奖励(r)-价值(v)=优势(a)# 3. 学习阶段for j in ppo_epochs: # 多次更新学习,逐渐靠近奖励log_probs = analyze_responses(policy_model, prompts, responses)values = analyze_responses(critic_model, prompts, responses)# 更新 actor(policy) 模型,学习更新的差异,advantages[i][j]越大,强化动作actor_loss = actor_loss_func(advantages, old_log_probs, log_probs) critic_loss = critic_loss_func(rewards, values) # 更新 critic 模型loss = actor_loss + 0.1 * critic_loss # 更新train(loss, policy_model.parameters(), critic_model.parameters()) # 参数
参考 知乎 - 图解大模型RLHF系列之:人人都能看懂的PPO原理与源码解读
KL 散度的实现,如下:
import torch
import torch.nn.functional as F
# 假设我们有两个概率分布 P 和 Q
P = torch.tensor([0.1, 0.2, 0.7]) # 参考的、真实的
Q = torch.tensor([0.2, 0.3, 0.5]) # 模型生成的
# 计算 Q 的对数概率
log_Q = torch.log(Q)
# 使用 PyTorch 的 kl_div 函数计算 KL 散度
kl_divergence = F.kl_div(log_Q, P, reduction='sum') # 注意先Q后P
print(f"KL Div (PyTorch): {kl_divergence}")
log_P = torch.log(P)
kl_elementwise = P * (log_P - log_Q)
# 对所有元素求和,得到 KL 散度
kl_divergence = torch.sum(kl_elementwise)
参考:
- 知乎 - GRPO: Group Relative Policy Optimization
- GitHub - GRPO Trainer
- Medium - The Math Behind DeepSeek: A Deep Dive into Group Relative Policy Optimization (GRPO)
相关文章:
LLM - 理解 DeepSeek 的 GPRO (分组相对策略优化) 公式与源码 教程(2)
欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/145640762 GPRO,即 Group Relative Policy Optimization,分组相对的策略优化,是 PPO(Proximal Policy Optimiz…...
Github 2025-02-14 Java开源项目日报 Top10
根据Github Trendings的统计,今日(2025-02-14统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Java项目10C#项目1Guava: 谷歌Java核心库 创建周期:3725 天开发语言:Java协议类型:Apache License 2.0Star数量:49867 个Fork数量:10822 次…...
DeepSeek赋能制造业:图表可视化从入门到精通
一、企业数据可视化之困 在数字化浪潮席卷全球的当下,商贸流通企业作为经济活动的关键枢纽,每天都在与海量数据打交道。从商品的采购、库存管理,到销售渠道的拓展、客户关系的维护,各个环节都源源不断地产生数据。这些数据犹如一座蕴含巨大价值的宝藏,然而,如何挖掘并利用…...
Python爬虫技术
Python爬虫技术凭借其高效便捷的特性,已成为数据采集领域的主流工具。以下从技术优势、核心实现、工具框架、反爬策略及注意事项等方面进行系统阐述: 一、Python爬虫的核心优势 语法简洁与开发效率高 Python的语法简洁易读,配合丰富的第三方库…...
C++Primer学习(4.6成员访问运算符)
4.6成员访问运算符 点运算符和箭头运算符都可用于访问成员,其中,点运算符获取类对象的一个成员;箭头运算符与点运算符有关,表达式 ptr->mem等价于(* ptr).mem: string sl"a string",*p &s1; auto ns1.size();//运行string对…...
c++14之std::make_unique
基础介绍 虽然在c11版本std::unique_ptr<T>已经引入,但是在c14版本引入之前,std::unique_ptr<T>的创建还是通过new操作符来完成的。在c14版本已经引入了类似make_shared的std::make_unique,目的是提供更加安全的方法创建std::un…...
服务器linux操作系统安全加固
一、系统更新与补丁管理 更新系统sudo yum update -y # 更新所有软件包 sudo yum install epel-release -y # 安装EPEL扩展源启用自动安全更新sudo yum install yum-cron -y sudo systemctl enable yum-cron sudo systemctl start yum-cron配置 /etc/yum/yum-cron.con…...
原生Three.js 和 Cesium.js 案例 。 智慧城市 数字孪生常用功能列表
对于大多数的开发者来言,看了很多文档可能遇见不到什么有用的,就算有用从文档上看,把代码复制到自己的本地大多数也是不能用的,非常浪费时间和学习成本, 尤其是three.js , cesium.js 这种难度较高ÿ…...
Node.js中Express框架使用指南:从入门到企业级实践
目录 一、Express快速入门 1. 项目初始化 2. 基础服务搭建 3. 添加热更新 二、核心功能详解 1. 路由系统 动态路由参数 路由模块化 2. 中间件机制 自定义中间件 常用官方中间件 3. 模板引擎集成 三、企业级最佳实践 1. 项目结构规范 2. 错误处理方案 3. 安全防护…...
spring 学习 (注解)
目录 前言 常用的注解 须知 1 Conponent注解 demo(案例) 2 ControllerServiceRepository demo(案例) 3 ScopeLazyPostConstructPreDestroy demo(案例) 4 ValueAutowiredQualifierResource demo(案例) 5 Co…...
计算机等级考试——计算机三级——网络技术部分
计算机三级——网络技术部分 一、外部网关协议BGP考点二、IPS入侵防护系统考点三、OSPF协议考点四、弹性分组环——RPR技术 一、外部网关协议BGP考点 高频考点,中考次数:25次 这类知识采用背诵的方式,可以更快速地备考。 BGP是边界网关协议&…...
新版电脑通过wepe安装系统
官方下载链接 WIN10下载 WIN11下载 微PE 启动盘制作 1:选择启动盘的设备 2:选择对应的U盘设备,点击安装就可以,建议大于8g 3:在上方链接下载需要安装的程序包,放入启动盘,按需 更新系统 …...
oracle中decode怎么转换成pg
对于 PostgreSQL 中的 Oracle DECODE 函数,可以使用 CASE 表达式或联合。CASE 表达式根据条件返回第一个匹配的结果,语法为:CASE WHEN 条件 THEN 结果 ELSE 结果 END。联合通过 UNION ALL 操作符组合多个 SELECT 语句,返回一个包含…...
【NLP】循环神经网络RNN
目录 一、词嵌入层 二、循环网络层 2.1 RNN网络原理 2.2 Pytorch RNN API 自然语言处理(Nature language Processing,NLP)研究的主要是通过计算机算法来理解自然语言。对于自然语言来说,处理的数据主要就是人类的语言…...
Linux嵌入式完整镜像烧写到SD卡中的方法(包括对SD卡的介绍)
前言 本篇博文是博文https://blog.csdn.net/wenhao_ir/article/details/145547974 的分支,在本篇博文里我们主要是完成将镜像文件imx-image-full-imx6ull14x14evk-20201209093926.rootfs.wic烧写到SD卡中。 SD卡的介绍 SD卡(Secure Digital卡…...
vscode怎么更新github代码
vscode怎么更新github代码 打开终端: 在 VS Code 中,使用快捷键 Ctrl (Mac 上是 Cmd) 打开终端。 导航到项目目录: 确保你当前所在的终端目录是你的项目目录。如果不是,可以使用 cd 命令导航到项目目录,例如…...
回顾Golang的Channel与Select第二篇
深入掌握Go Channel与Select:从原理到生产级实践 一、Channel基础:不只是数据管道 1.1 通道的完整生命周期(可运行示例) package mainimport ("fmt""time" )func main() {// 创建缓冲通道ch : make(chan i…...
基于mediapipe深度学习的手势数字识别系统python源码
目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 (完整程序运行后无水印) 2.算法运行软件版本 程序运行配置环境: 人工智能算法python程序运行环境安装步骤整理-CSDN博客 3.部分核心…...
JS实现大文件切片上传以及断点续传
切片上传的原理是: 1.因为file对象的基类是blob,所以可以使用slice分割 2.将从input中获取的file对象使用slice进行分割,每5M一片 3.分别上传各个切片,等待切片上传完通知服务端合并(或者传每一片时把切片总数量也传…...
AI编程01-生成前/后端接口对表-豆包(或Deepseek+WPS的AI
前言: 做过全栈的工程师知道,如果一个APP的项目分别是前端/后端两个团队开发的话,那么原型设计之后,通过接口文档进行开发对接是非常必要的。 传统的方法是,大家一起定义一个接口文档,然后,前端和后端的工程师进行为何,现在AI的时代,是不是通过AI能协助呢,显然可以…...
label-studio的使用教程(导入本地路径)
文章目录 1. 准备环境2. 脚本启动2.1 Windows2.2 Linux 3. 安装label-studio机器学习后端3.1 pip安装(推荐)3.2 GitHub仓库安装 4. 后端配置4.1 yolo环境4.2 引入后端模型4.3 修改脚本4.4 启动后端 5. 标注工程5.1 创建工程5.2 配置图片路径5.3 配置工程类型标签5.4 配置模型5.…...
基于Flask实现的医疗保险欺诈识别监测模型
基于Flask实现的医疗保险欺诈识别监测模型 项目截图 项目简介 社会医疗保险是国家通过立法形式强制实施,由雇主和个人按一定比例缴纳保险费,建立社会医疗保险基金,支付雇员医疗费用的一种医疗保险制度, 它是促进社会文明和进步的…...
可靠性+灵活性:电力载波技术在楼宇自控中的核心价值
可靠性灵活性:电力载波技术在楼宇自控中的核心价值 在智能楼宇的自动化控制中,电力载波技术(PLC)凭借其独特的优势,正成为构建高效、稳定、灵活系统的核心解决方案。它利用现有电力线路传输数据,无需额外布…...
Qt Widget类解析与代码注释
#include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this); }Widget::~Widget() {delete ui; }//解释这串代码,写上注释 当然可以!这段代码是 Qt …...
《通信之道——从微积分到 5G》读书总结
第1章 绪 论 1.1 这是一本什么样的书 通信技术,说到底就是数学。 那些最基础、最本质的部分。 1.2 什么是通信 通信 发送方 接收方 承载信息的信号 解调出其中承载的信息 信息在发送方那里被加工成信号(调制) 把信息从信号中抽取出来&am…...
OkHttp 中实现断点续传 demo
在 OkHttp 中实现断点续传主要通过以下步骤完成,核心是利用 HTTP 协议的 Range 请求头指定下载范围: 实现原理 Range 请求头:向服务器请求文件的特定字节范围(如 Range: bytes1024-) 本地文件记录:保存已…...
ServerTrust 并非唯一
NSURLAuthenticationMethodServerTrust 只是 authenticationMethod 的冰山一角 要理解 NSURLAuthenticationMethodServerTrust, 首先要明白它只是 authenticationMethod 的选项之一, 并非唯一 1 先厘清概念 点说明authenticationMethodURLAuthenticationChallenge.protectionS…...
从零实现STL哈希容器:unordered_map/unordered_set封装详解
本篇文章是对C学习的STL哈希容器自主实现部分的学习分享 希望也能为你带来些帮助~ 那咱们废话不多说,直接开始吧! 一、源码结构分析 1. SGISTL30实现剖析 // hash_set核心结构 template <class Value, class HashFcn, ...> class hash_set {ty…...
分布式增量爬虫实现方案
之前我们在讨论的是分布式爬虫如何实现增量爬取。增量爬虫的目标是只爬取新产生或发生变化的页面,避免重复抓取,以节省资源和时间。 在分布式环境下,增量爬虫的实现需要考虑多个爬虫节点之间的协调和去重。 另一种思路:将增量判…...
【JVM】Java虚拟机(二)——垃圾回收
目录 一、如何判断对象可以回收 (一)引用计数法 (二)可达性分析算法 二、垃圾回收算法 (一)标记清除 (二)标记整理 (三)复制 (四ÿ…...
