Pytorch训练固定随机种子(单卡场景和分布式训练场景)
模型的训练是一个随机过程,固定随机种子可以帮助我们复现实验结果。
接下来介绍一个模型训练过程中固定随机种子的代码,并对每条语句的作用都会进行解释。
def seed_reproducer(seed=2333):random.seed(seed)os.environ["PYTHONHASHSEED"] = str(seed)np.random.seed(seed)torch.manual_seed(seed)if torch.cuda.is_available():torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = Falsetorch.backends.cudnn.enabled = True
这是一个自定义函数,函数的参数就是我们传入的种子的数值,类型为int,作用就是消除训练过程中的随机性,以确保实验得可重复性,具体使用方法为在初始化模型和dataset前调用该函数即可。
接下来逐句讲解每个语句的作用。
random.seed(seed)
- 作用:设置python内置random模块的种子,确保所有使用random模块产生的随机数序列是确定的;
- 应用场景:适用于任何使用random模块的地方,例如数据增强、采样等。
os.environ["PYTHONHASHSEED"] = str(seed)
- 作用:设置python的哈希种子,python字典和其他哈希表结构依赖于哈希函数,而哈希函数的行为在不同运行之间可能会不同,通过设置PYTHONHASHSEED环境变量,可使哈希结果在同一种子下保持一致;
- 应用场景:确保字典键值对顺序的一致性,避免因哈希碰撞引起的非确定性行为。
np.random.seed(seed)
- 作用:设置Numpy库的随机种子;
- 应用场景:适用于所有使用Numpy生成随机数的地方,例如初始化权重、数据打乱等。
torch.manual_seed(seed)
- 作用:为Pytorch设置全局种子,确保所有使用Pytorch身材的随机数(包括张量操作)都是确定性的;
- 应用场景:适用于所有使用Pytorch生成随机数的地方。
为了方便结合代码进行理解,在这里单独把后半部分的代码复制一下。
if torch.cuda.is_available():torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = Falsetorch.backends.cudnn.enabled = True
if torch.cuda.is_available():
条件判断,只有当CUDA可用时才进行以下设置,这确保了代码可以在没有GPU的环境中正常运行。
torch.cuda.manual_seed(seed)
- 作用:为当前GPU设置随机种子,这确保了所有在当前GPU上生成的随机数都是确定的;
- 应用场景:适用于单个GPU的随机数生成。
torch.cuda.manual_seed_all(seed)
- 作用:为所有可用的GPU设置相同的随机种子,这确保了在多GPU环境中,每个GPU上随机生成的随机数都是一致的;
- 应用场景:适用于多GPU环境中的随机数生成。
torch.backends.cudnn.deterministic = True
- 作用:确保cuDNN使用确定性的算法,某些cuDNN算法是具有随机性的,启用此选项可以提高结果的可重复性,但是可能会降低性能;
- 应用场景:适用于需要严格可重复性的实验。
torch.backends.cudnn.benchmark = False
- 作用:禁用cuDNN的自动选择最佳卷积算法的功能,默认情况下cuDNN会在首次运行时尝试找到最适合硬件的算法,这可能会导致结果的不确定性,禁用此选项可以确保每次都是用相同的算法;
- 应用场景:适用于需要严格可重复性的实验。
torch.backends.cudnn.enabled = True
- 作用:启用cuDNN,虽然设置了deterministic和benchmark参数来控制cuDNN的行为,但仍然需要确保cuDNN是启用的;
- 应用场景:所有需要使用GPU加速pytorch计算的场景。
多卡训练的情况
对于多卡训练的情况,设置随机种子的方式需要特别注意,以确保每个进程(或者称为“rank”)生成的随机数序列是不同的,同时还需要保证整个训练过程的可重复性。基于上述要求,随机种子可使用如下代码进行修改:
seed = args.seed + utils.get_rank()
- args.seed:我们指定的seed的数值,通常在模型训练的配置文件中,或者通过命令行参数传入模型的训练脚本;
- utils.get_rank():这是自定义的一个函数,位于自定义库utils中,具体作用是获得当前进程的全局序号,比如在进行分布式训练时,有2台机器(我们成之为2个节点),每台机器有8张GPU,则一共会有2*8=16个进程,每个进程都会有唯一的序号,从0~15。
因此我们看到,对于分布式训练场景,随机种子只需要确定一个全局种子(args.seed)在加上一个增量(utils.get_rank()),而这个增量对于每个进程来说是固定的。
下面是get_rank()函数的具体实现:
import torch.distributed as distdef is_dist_avail_and_initialized():# 判断当前环境中是否支持分布式训练if not dist.is_available():return False# 检查当前环境是否已经成功初始化了分布式训练环境if not dist.is_initialized():return Falsereturn Truedef get_rank():# 判断分布式训练是否可用且是否已成功初始化if not is_dist_avail_and_initialized():return 0return dist.get_rank()
相关文章:
Pytorch训练固定随机种子(单卡场景和分布式训练场景)
模型的训练是一个随机过程,固定随机种子可以帮助我们复现实验结果。 接下来介绍一个模型训练过程中固定随机种子的代码,并对每条语句的作用都会进行解释。 def seed_reproducer(seed2333):random.seed(seed)os.environ["PYTHONHASHSEED"] s…...
Conda + JuiceFS :增强 AI 开发环境共享能力
Conda 是当前 AI 应用开发领域中非常流行的环境和包管理系统,因其能够简单便捷地创建与系统资源相隔离的虚拟环境广受欢迎。 Conda 支持在不同的操作系统上重建相同的工作环境,但在环境共享复用方面仍存在一些挑战。比如,在不同机器上复用相…...
人工智能-人机交互的机会
目录 引言HCI领域的发展机会人工智能领域的崛起与机会博雅智信的HCI与AI辅导服务结语 引言 在人类科技不断进步的今天,HCI(人机交互)和人工智能(AI)是两个密切相关且充满潜力的领域。HCI研究如何优化人类与计算机之间…...
【系统架构核心服务设计】使用 Redis ZSET 实现排行榜服务
目录 一、排行榜的应用场景 二、排行榜技术的特点 三、使用Redis ZSET实现排行榜 3.1 引入依赖 3.2 配置Redis连接 3.3 创建实体类(可选) 3.4 编写 Redis 操作服务层 3.5 编写控制器层 3.6 测试 3.6.1 测试 addMovieScore 接口 3.6.2 测试 g…...
elasticsearch基础总结
最近实习,项目用的elasticseatch做的存储库,但是之前对于es接触的不多,查询语法有些不熟,每次想写个DSL查询时都要gpt或者施展搜索大法,所以索性就自己总结总结,以后忘了也方便查。所以这篇文章会持续更新。…...
【慕伏白教程】Zerotier 连接与简单配置
文章目录 下载与安装WindowsLinuxapt安装官方脚本安装 Zerotier 配置新建网络网络配置 终端配置WindowsLinux 下载与安装 Windows 进入Zerotier官方下载网站,点击下载 在下载目录找到安装文件,双击打开后点击 Install 开始安装 安装完成后,…...
Brain.js(九):LSTMTimeStep 实战教程 - 未来短期内的股市指数预测 - 实操要谨慎
系列的前一文RNNTimeStep 实战教程 - 股票价格预测 讲述了如何使用RNN时间序列预测实时的股价, 在这一节中,我们将深入学习如何利用 JavaScript 在浏览器环境下使用 LSTMTimeStep 进行股市指数的短期预测。通过本次实战教程,你将了解到如何用…...
C# 字符串(String)
文章目录 前言创建 String 对象的方式1. 通过给 String 变量指定一个字符串2. 通过使用 String 类构造函数3. 通过使用字符串串联运算符( )4. 通过检索属性或调用一个返回字符串的方法5. 通过格式化方法来转换一个值或对象为它的字符串表示形式 String …...
二进制文件
大多数人听到“二进制”的时候,脑海里可能马上就会联想到电影《黑客帝国》中由“0”和“1”组成的矩阵。 笔者不打算在这里详细讨论二进制的运算、反码、补码之类枯燥的东西,但有几个和开发相关的概念需要做一点澄清和普及。因为这些内容就像空气——用…...
【电子元器件】音频功放种类
本文章是笔者整理的备忘笔记。希望在帮助自己温习避免遗忘的同时,也能帮助其他需要参考的朋友。如有谬误,欢迎大家进行指正。 一、概述 音频功放将小信号的幅值提高至有用电平,同时保留小信号的细节,这称为线性度。放大器的线性…...
linux之vim
一、模式转换命令 vim主要有三种模式:命令模式(Normal Mode)、输入模式(Insert Mode)和底线命令模式(Command-Line Mode)。 从命令模式切换到输入模式:i:在当前光标所在…...
QT的ui界面显示不全问题(适应高分辨率屏幕)
//自动适应高分辨率 QCoreApplication::setAttribute(Qt::AA_EnableHighDpiScaling);一、问题 电脑分辨率高,默认情况下,打开QT的ui界面,显示不全按钮内容 二、解决方案 如果自己的电脑分辨率较高,可以尝试以下方案:自…...
数据结构--串、数组和广义表
串 定义:串(String)是由零个或多个字符组成的有限序列。 子串:串中任意个连续字符组成的子序列称为该串的子串。 主串:包含子串的串相应地称为主串。 字符位置:字符在该序列中的序号为该字符在串中的位置…...
LLMs之Agent之Lares:Lares的简介、安装和使用方法、案例应用之详细攻略
LLMs之Agent之Lares:Lares的简介、安装和使用方法、案例应用之详细攻略 导读:这篇博文介绍了 Lares,一个由简单的 AI 代理驱动的智能家居助手模拟器,它展现出令人惊讶的解决问题能力。 >> 背景痛点:每天都有新的…...
1-1.mysql2 之 mysql2 初识(mysql2 初识案例、初识案例挖掘)
一、mysql2 概述 mysql2 是一个用于 Node.js 的 MySQL 客户端库 mysql2 是 mysql 库的一个改进版本,提供了更好的性能和更多的功能 使用 mysql2 之前,需要先安装它 npm install mysql2 二、mysql2 初识案例 1、数据库准备 创建数据库 testdb CREAT…...
企业邮箱为什么不能经常群发邮件?
企业邮箱是用企业域名作为后缀的邮箱,虽然企业邮箱确实具备群发邮件的功能,但它更适用于企业内部的群发,而非用于外部推广。如果是在企业邮件域内进行群发,通常可以借助企业邮箱的邮件列表来实现。然而,对于域外的大量…...
集成运算放大电路反馈判断
集成运算放大电路 一种具有很高放大倍数的多级直接耦合放大电路,因最初用于信号运算而得名,简称集成运放或运放 模拟集成电路中的典型组件,是发展最快、品种最多、应用最广的一种 反馈 将放大电路输出信号的一部分或全部通过某种电路引回到输…...
媒体查询、浏览器一帧渲染过程
文章目录 媒体查询语法示例根据视口宽度应用不同的样式根据设备像素比应用不同的样式根据方向应用不同的样式 使用场景 浏览器一帧的渲染过程 媒体查询 媒体查询(Media Query)是CSS3中的一个重要特性,它允许开发者根据设备的特定条件&#x…...
高级排序算法(一):快速排序详解
引言 当我们处理大规模数据时,像冒泡排序、选择排序这样的基础排序算法就有点力不从心了。这时候,快速排序(Quick Sort)就派上用场了。 作为一种基于分治法的高效排序算法,快速排序在大多数情况下可以在O(n log n)的时…...
3.2 网络协议IP
欢迎大家订阅【计算机网络】学习专栏,开启你的计算机网络学习之旅! 文章目录 1 定义2 虚拟互连网络3 分组在互联网中的传送4 IPv4 地址 1 定义 网际协议 IP是 TCP/IP 体系中两个最主要的协议之一,也是最重要的互连网协议之一。IPv4 和 IPv6 …...
KubeSphere 容器平台高可用:环境搭建与可视化操作指南
Linux_k8s篇 欢迎来到Linux的世界,看笔记好好学多敲多打,每个人都是大神! 题目:KubeSphere 容器平台高可用:环境搭建与可视化操作指南 版本号: 1.0,0 作者: 老王要学习 日期: 2025.06.05 适用环境: Ubuntu22 文档说…...
蓝牙 BLE 扫描面试题大全(2):进阶面试题与实战演练
前文覆盖了 BLE 扫描的基础概念与经典问题蓝牙 BLE 扫描面试题大全(1):从基础到实战的深度解析-CSDN博客,但实际面试中,企业更关注候选人对复杂场景的应对能力(如多设备并发扫描、低功耗与高发现率的平衡)和前沿技术的…...
Qt Http Server模块功能及架构
Qt Http Server 是 Qt 6.0 中引入的一个新模块,它提供了一个轻量级的 HTTP 服务器实现,主要用于构建基于 HTTP 的应用程序和服务。 功能介绍: 主要功能 HTTP服务器功能: 支持 HTTP/1.1 协议 简单的请求/响应处理模型 支持 GET…...
SpringBoot+uniapp 的 Champion 俱乐部微信小程序设计与实现,论文初版实现
摘要 本论文旨在设计并实现基于 SpringBoot 和 uniapp 的 Champion 俱乐部微信小程序,以满足俱乐部线上活动推广、会员管理、社交互动等需求。通过 SpringBoot 搭建后端服务,提供稳定高效的数据处理与业务逻辑支持;利用 uniapp 实现跨平台前…...
VTK如何让部分单位不可见
最近遇到一个需求,需要让一个vtkDataSet中的部分单元不可见,查阅了一些资料大概有以下几种方式 1.通过颜色映射表来进行,是最正规的做法 vtkNew<vtkLookupTable> lut; //值为0不显示,主要是最后一个参数,透明度…...
CMake 从 GitHub 下载第三方库并使用
有时我们希望直接使用 GitHub 上的开源库,而不想手动下载、编译和安装。 可以利用 CMake 提供的 FetchContent 模块来实现自动下载、构建和链接第三方库。 FetchContent 命令官方文档✅ 示例代码 我们将以 fmt 这个流行的格式化库为例,演示如何: 使用 FetchContent 从 GitH…...
今日科技热点速览
🔥 今日科技热点速览 🎮 任天堂Switch 2 正式发售 任天堂新一代游戏主机 Switch 2 今日正式上线发售,主打更强图形性能与沉浸式体验,支持多模态交互,受到全球玩家热捧 。 🤖 人工智能持续突破 DeepSeek-R1&…...
实现弹窗随键盘上移居中
实现弹窗随键盘上移的核心思路 在Android中,可以通过监听键盘的显示和隐藏事件,动态调整弹窗的位置。关键点在于获取键盘高度,并计算剩余屏幕空间以重新定位弹窗。 // 在Activity或Fragment中设置键盘监听 val rootView findViewById<V…...
根据万维钢·精英日课6的内容,使用AI(2025)可以参考以下方法:
根据万维钢精英日课6的内容,使用AI(2025)可以参考以下方法: 四个洞见 模型已经比人聪明:以ChatGPT o3为代表的AI非常强大,能运用高级理论解释道理、引用最新学术论文,生成对顶尖科学家都有用的…...
Pinocchio 库详解及其在足式机器人上的应用
Pinocchio 库详解及其在足式机器人上的应用 Pinocchio (Pinocchio is not only a nose) 是一个开源的 C 库,专门用于快速计算机器人模型的正向运动学、逆向运动学、雅可比矩阵、动力学和动力学导数。它主要关注效率和准确性,并提供了一个通用的框架&…...
