当前位置: 首页 > news >正文

【复现DeepSeek-R1之Open R1实战】系列4:跑通GRPO!

目录

  • 1 配置环境
  • 2 训练
    • 2.1 命令和配置参数
    • 2.2 num_generations
      • 2.2.1 参数定义
      • 2.2.2 参数含义
      • 2.2.3 示例
      • 2.2.4 使用场景
      • 2.2.5 示例代码
    • 2.3 显存占用和耗时
  • 3 结果

1 配置环境

关于环境配置,可以参考这篇博文:【复现DeepSeek-R1之Open R1实战】系列1:跑通SFT(一步步操作,手把手教学)

关于flash-attention依赖库的安装问题,运行以下命令,等待一小时左右,依赖库就安装成功了:

pip install flash-attn --no-cache-dir

2 训练

2.1 命令和配置参数

训练的命令如下,和SFT差不多:

ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \--num_processes=7 src/open_r1/grpo.py \--config /nfs/ofs-902-1/fusion/zhongyudong/open-r1/recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml

我们需要修改config配置文件:recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml,主要是将和Huggingface的链接关掉,修改模型路径、数据集路径、GPU个数(num_processes=GPU个数-1,因为vLLM使用了一张卡)。

在训练过程中,我发现torch的DDP不稳定,容易接收不到Worker的信号导致训练失败,所以保存策略改成了每步都保存(save_strategy: “steps”)。

完整的配置如下:

# Model arguments
model_name_or_path: /nfs/ofs-902-1/pnc/huggingface_hub/Qwen/Qwen2.5-1.5B-Instruct
# model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2# Data training arguments
dataset_name: /nfs/ofs-902-1/fusion/zhongyudong/open-r1/datas/NuminaMath-TIR/data
dataset_configs:
- all
# Num processes is less by 1 as vLLM is using 1 GPU
num_processes: 7# GRPO trainer config
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: true
eval_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:use_reentrant: false
# hub_model_id: Qwen2.5-1.5B-Open-R1-GRPO
# hub_strategy: every_save
learning_rate: 2.0e-05
log_level: info
logging_steps: 5
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 1024
max_steps: -1
num_generations: 7
num_train_epochs: 1
output_dir: /nfs/ofs-902-1/fusion/zhongyudong/open-r1/outputs/Qwen2.5-1.5B-Open-R1-GRPO
overwrite_output_dir: true
per_device_eval_batch_size: 32
per_device_train_batch_size: 16
push_to_hub: false
# report_to:
# - wandb
save_strategy: "steps"
seed: 42
warmup_ratio: 0.1

重点解释一下num_generations这个参数,主要是控制每个提示(Prompt)生成的样本数量。

2.2 num_generations

参数 num_generations 用于指定每个提示(prompt)生成的样本数量。这个参数在生成模型中非常常见,特别是在文本生成、对话系统或其他需要从模型中采样多个输出的任务中。以下是对该参数及其使用场景的详细解释:

2.2.1 参数定义

num_generations (`int` or `None`, *optional*, defaults to `8`):Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)must be divisible by this value.
  • 类型: 可以是整数(int)或 None
  • 默认值: 默认为 8
  • 可选性: 是一个可选参数。

2.2.2 参数含义

  1. 生成数量:

    • num_generations 指定了对于每一个输入的提示(prompt),模型将生成多少个不同的输出样本。
    • 例如,如果你设置 num_generations=3,那么对于每一个输入提示,模型会生成3个不同的输出。
  2. 全局批处理大小的约束:

    • 全局批处理大小(global batch size)是指所有进程和设备上批处理大小的总和,通常计算为 num_processes * per_device_batch_size
    • 这个全局批处理大小必须能够被 num_generations 整除。也就是说,global_batch_size % num_generations == 0 必须成立。
    • 这个约束确保了在分布式训练或多设备环境中,每个设备上的生成任务可以均匀分配。

2.2.3 示例

假设你有以下配置:

  • per_device_batch_size = 4
  • num_processes = 2(即你在使用两个GPU或其他并行计算单元)
  • num_generations = 8

在这种情况下:

  • 全局批处理大小为 global_batch_size = num_processes * per_device_batch_size = 2 * 4 = 8
  • 因为 global_batch_size 等于 num_generations,所以条件满足。

如果我们将 num_generations 改为 6,则:

  • 全局批处理大小仍然是 8,但 8 % 6 != 0,这会导致错误,因为无法均匀分配生成任务。

2.2.4 使用场景

  1. 文本生成

在文本生成任务中,你可能希望从一个提示生成多个不同的输出,以便选择最好的结果或者展示多样性。例如,在故事生成或对话系统中,生成多个候选答案可以让用户有更多的选择。

  1. 对话系统

在对话系统中,生成多个回复可以帮助系统提供更丰富的互动体验。通过生成多个回复,系统可以选择最合适的回答,或者让用户选择他们喜欢的回答。

  1. 多模态生成

在多模态生成任务(如图像字幕生成、视频描述等)中,生成多个输出可以提高生成内容的多样性和准确性。

2.2.5 示例代码

以下是一个简单的示例,展示了如何使用 num_generations 参数:

from transformers import AutoModelForCausalLM, AutoTokenizer# 加载预训练模型和分词器
model_name = "your-pretrained-model"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)# 定义输入提示
prompt = "Once upon a time"# 将提示编码为模型输入格式
input_ids = tokenizer(prompt, return_tensors="pt").input_ids# 设置生成参数
num_generations = 5  # 每个提示生成5个样本# 生成多个样本
outputs = model.generate(input_ids,num_return_sequences=num_generations,  # 设置num_generationsmax_length=50,do_sample=True
)# 解码生成的样本
for i, output in enumerate(outputs):print(f"Generated text {i+1}:")print(tokenizer.decode(output, skip_special_tokens=True))print()

在这个例子中,num_return_sequences 参数对应于 num_generations,它指定了要生成的样本数量。

2.3 显存占用和耗时

8卡H20,每张卡占用40~50G。
显存

要跑将15个多小时。
运行时间

3 结果

一些中间结果:

中间结果

相关文章:

【复现DeepSeek-R1之Open R1实战】系列4:跑通GRPO!

目录 1 配置环境2 训练2.1 命令和配置参数2.2 num_generations2.2.1 参数定义2.2.2 参数含义2.2.3 示例2.2.4 使用场景2.2.5 示例代码 2.3 显存占用和耗时 3 结果 1 配置环境 关于环境配置,可以参考这篇博文:【复现DeepSeek-R1之Open R1实战】系列1&…...

Redis原理简述及发布订阅消息队列

目录 1 什么是Redis 2 Redis 非阻塞IO内部原理 2.1 IO多路复用策略 2.2 Reactor设计模式 3 基于PubSub的消息队列(发布-订阅) 由于集群之后存在多台服务器,并且不同客户端连接的可能是不同的服务器,因此在聊天过程中涉及到服…...

ThreadLocal为什么会内存溢出

每个线程(Thread 对象)内部维护一个 ThreadLocalMap,用于存储该线程的所有 ThreadLocal 变量的键值对: ThreadLocalMap虽然是ThreadLocal的静态内部类,但是Thread 对象的属性,当线程存活时ThreadLocalMap不会被回收。 Key:ThreadLocal 实例的 弱引用(WeakReference)。…...

假面与演员:到底是接口在使用类,还是类在使用接口?编程接口与物理接口的区别又是什么?

前言:本篇文章解释了接口学习过程中的2个常见问题,一个是“为什么是类在使用接口”,另一个一个是“编程接口与物理接口的差异源于所处的抽象层次和交互模式的不同”,旨在揭示编程接口的本质。 Part1.是类在使用接口 当学习接口时…...

数据结构——Makefile、算法、排序(2025.2.13)

目录 一、Makefile 1.功能 2.基本语法和相关操作 (1)创建Makefile文件 (2)编译规则 (3)编译 (4)变量 ①系统变量 ②自定义变量 二、 算法 1.定义 2.算法的设计 &#xff…...

算法之 跳跃游戏

文章目录 55.跳跃游戏思路参考:56.合并区间 55.跳跃游戏 55.跳跃游戏 灵神思路 思路分析: 两种思路,思路1是我们可以直接维护当前到达i的时候所能到达的最右的边界mr,如果i>mr就说明无法到达i,否则就是可以到达;…...

C#中的图形渲染模式

在C#中,图形模式通常用于定义如何渲染或处理图形。可以枚举定义如下四种图形模式:AUTO、GDI、DIB 和 FBO。这些模式可能用于指定不同的图形渲染技术或后端。下面是对这些模式的详细解释: 1. AUTO (自动模式) 含义:自动选择最适合…...

二.数据治理流程架构

1、数据治理流程架构核心思想: 该图描绘了一个以数据标准规范体系为核心,大数据生命周期管理为主线,数据资源中心为依托,并辅以数据质量管理和大数据安全与隐私管理的数据治理流程架构。它旨在通过规范化的流程和技术手段&#x…...

瑞萨RA-T系列芯片ADCGPT功能模块的配合使用

在马达或电源工程中,往往需要采集多路AD信号,且这些信号的优先级和采样时机不相同。本篇介绍在使用RA-T系列芯片建立马达或电源工程时,如何根据需求来设置主要功能模块ADC&GPT,包括采样通道打包和分组,GPT触发启动…...

扩散模型中的马尔可夫链设计演进:从DDPM到Stable Diffusion全解析

一、技术原理与数学推导(附核心公式) 1.1 扩散过程数学建模 马尔可夫链前向过程定义: q(x_{1:T}|x_0) \prod_{t1}^T q(x_t|x_{t-1})噪声调度函数(以余弦调度为例): \beta_t \frac{1 - \cos(\pi t/T)}…...

通俗诠释 DeepSeek-V3 模型的 “671B” ,“37B”与 “128K”,用生活比喻帮你理解模型的秘密!

欢迎来到涛涛聊AI。 在DeepSeek-V3模型的参数描述中,你可能会看到类似“671B 37B 128K”这样的标记。这些字母和数字的组合看起来像密码,但其实它们揭示了模型的“大脑容量”和“工作方式”。我们用日常生活的比喻来解释: 一、数字含义&…...

大模型常识:什么是大模型/大语言模型/LLM

本文原创作者:姚瑞南 AI-agent 大模型运营专家,先后任职于美团、猎聘等中大厂AI训练专家和智能运营专家岗;多年人工智能行业智能产品运营及大模型落地经验,拥有AI外呼方向国家专利与PMP项目管理证书。(转载需经授权) 目录 一、什么是语言模型? 那么什么是语言模…...

iOS 中使用 FFmpeg 进行音视频处理

在 iOS 中使用 FFmpeg 进行音视频处理,通常需要将 FFmpeg 的功能集成到项目中。由于 FFmpeg 是一个 C 库,直接在 iOS 中使用需要进行一些配置和封装。 1. 在 iOS 项目中集成 FFmpeg 方法 1:使用 FFmpeg 预编译库 下载 FFmpeg iOS 预编译库: 可以从以下项目中获取预编译的 …...

SAP-ABAP:SAP的Screen Layout Designer屏幕布局设计器详解及示例

在SAP中,Screen Layout Designer(屏幕布局设计器)是用于设计和维护屏幕(Dynpro)布局的工具。通过Screen Layout Designer,您可以创建和修改屏幕元素(如输入字段、按钮、文本、表格控件等&#x…...

一.数据治理理论架构

1、数据治理核心思想: 数据治理理论架构图描绘了一个由顶层设计、管控机制、核心领域和管理系统四个主要部分组成的数据治理框架。它旨在通过系统化的方法,解决数据治理机制缺失引发的业务和技术问题,并最终提升企业的数据管理水平。 数据治…...

亲测有效!使用Ollama本地部署DeepSeekR1模型,指定目录安装并实现可视化聊天与接口调用

文章目录 一、引言二、准备工作(Ollama 工具介绍与下载)2.1 Ollama介绍2.2 Ollama安装 三、指定目录安装 DeepSeek R1四、Chatbox 可视化聊天搭建4.1 Chatbox下载安装4.2 关联 DeepSeek R1 与 Chatbox 的步骤 五、使用 Ollama 调用 DeepSeek 接口5.1 请求…...

MySQL安装MySQL服务时提示Install-Remove of the Service Denied

文章目录 问题描述排查1.字面意思2.搜索引擎3.官方文档4.源码 处理方法相关扩展 问题描述 MySQL安装MySQL服务时提示Install-Remove of the Service Denied! 详细报错如下: C:\Users\荷塘月色>net start mysql 服务名无效。请键入 NET HELPMSG 2185 以获得更多…...

(Windows | Linux)ssh访问服务器报错:no matching key exchange method found

问题现象 ssh user1192.168.1X.XX Unable to negotiate with 192.168.1X.XX port 22: no matching key exchange method found. Their offer: gss-group1-sha1-toWM5Slw5Ew8Mqkayal2g,diffie-hellman-group-exchange-sha1,diffie-hellman-group14-sha1,diffie-hellman-group1-…...

Linux(centos)系统安装部署MySQL8.0数据库(GLIBC版本)

安装前检查服务器glibc版本,下载对应版本包 rpm -qa | grep glibc mysql安装包及依赖包已整理好,下载地址:https://pan.quark.cn/s/3137acc814c0,下载即可安装 一、下载MySQL mysql安装包及依赖包已整理好,下载地址…...

有哪些滤波,原理是什么,分别在什么时候用

均值滤波(Average Filtering) 原理:通过计算像素点邻域内像素值的平均值来作为该像素点滤波后的新值。例如,对于一个 3x3 的邻域,将 9 个像素值相加然后除以 9 得到滤波后的像素值。优点:简单易实现&#x…...

谷歌浏览器插件

项目中有时候会用到插件 sync-cookie-extension1.0.0:开发环境同步测试 cookie 至 localhost,便于本地请求服务携带 cookie 参考地址:https://juejin.cn/post/7139354571712757767 里面有源码下载下来,加在到扩展即可使用FeHelp…...

《Qt C++ 与 OpenCV:解锁视频播放程序设计的奥秘》

引言:探索视频播放程序设计之旅 在当今数字化时代,多媒体应用已渗透到我们生活的方方面面,从日常的视频娱乐到专业的视频监控、视频会议系统,视频播放程序作为多媒体应用的核心组成部分,扮演着至关重要的角色。无论是在个人电脑、移动设备还是智能电视等平台上,用户都期望…...

实现弹窗随键盘上移居中

实现弹窗随键盘上移的核心思路 在Android中&#xff0c;可以通过监听键盘的显示和隐藏事件&#xff0c;动态调整弹窗的位置。关键点在于获取键盘高度&#xff0c;并计算剩余屏幕空间以重新定位弹窗。 // 在Activity或Fragment中设置键盘监听 val rootView findViewById<V…...

Mobile ALOHA全身模仿学习

一、题目 Mobile ALOHA&#xff1a;通过低成本全身远程操作学习双手移动操作 传统模仿学习&#xff08;Imitation Learning&#xff09;缺点&#xff1a;聚焦与桌面操作&#xff0c;缺乏通用任务所需的移动性和灵活性 本论文优点&#xff1a;&#xff08;1&#xff09;在ALOHA…...

AI病理诊断七剑下天山,医疗未来触手可及

一、病理诊断困局&#xff1a;刀尖上的医学艺术 1.1 金标准背后的隐痛 病理诊断被誉为"诊断的诊断"&#xff0c;医生需通过显微镜观察组织切片&#xff0c;在细胞迷宫中捕捉癌变信号。某省病理质控报告显示&#xff0c;基层医院误诊率达12%-15%&#xff0c;专家会诊…...

pikachu靶场通关笔记19 SQL注入02-字符型注入(GET)

目录 一、SQL注入 二、字符型SQL注入 三、字符型注入与数字型注入 四、源码分析 五、渗透实战 1、渗透准备 2、SQL注入探测 &#xff08;1&#xff09;输入单引号 &#xff08;2&#xff09;万能注入语句 3、获取回显列orderby 4、获取数据库名database 5、获取表名…...

针对药品仓库的效期管理问题,如何利用WMS系统“破局”

案例&#xff1a; 某医药分销企业&#xff0c;主要经营各类药品的批发与零售。由于药品的特殊性&#xff0c;效期管理至关重要&#xff0c;但该企业一直面临效期问题的困扰。在未使用WMS系统之前&#xff0c;其药品入库、存储、出库等环节的效期管理主要依赖人工记录与检查。库…...

解析“道作为序位生成器”的核心原理

解析“道作为序位生成器”的核心原理 以下完整展开道函数的零点调控机制&#xff0c;重点解析"道作为序位生成器"的核心原理与实现框架&#xff1a; 一、道函数的零点调控机制 1. 道作为序位生成器 道在认知坐标系$(x_{\text{物}}, y_{\text{意}}, z_{\text{文}}…...

路由基础-路由表

本篇将会向读者介绍路由的基本概念。 前言 在一个典型的数据通信网络中&#xff0c;往往存在多个不同的IP网段&#xff0c;数据在不同的IP网段之间交互是需要借助三层设备的&#xff0c;这些设备具备路由能力&#xff0c;能够实现数据的跨网段转发。 路由是数据通信网络中最基…...

【Vue】scoped+组件通信+props校验

【scoped作用及原理】 【作用】 默认写在组件中style的样式会全局生效, 因此很容易造成多个组件之间的样式冲突问题 故而可以给组件加上scoped 属性&#xff0c; 令样式只作用于当前组件的标签 作用&#xff1a;防止不同vue组件样式污染 【原理】 给组件加上scoped 属性后…...