当前位置: 首页 > news >正文

Pytorch训练固定随机种子(单卡场景和分布式训练场景)

模型的训练是一个随机过程,固定随机种子可以帮助我们复现实验结果。

接下来介绍一个模型训练过程中固定随机种子的代码,并对每条语句的作用都会进行解释。

def seed_reproducer(seed=2333):random.seed(seed)os.environ["PYTHONHASHSEED"] = str(seed)np.random.seed(seed)torch.manual_seed(seed)if torch.cuda.is_available():torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = Falsetorch.backends.cudnn.enabled = True

这是一个自定义函数,函数的参数就是我们传入的种子的数值,类型为int,作用就是消除训练过程中的随机性,以确保实验得可重复性,具体使用方法为在初始化模型和dataset前调用该函数即可。

接下来逐句讲解每个语句的作用。


random.seed(seed)

  • 作用:设置python内置random模块的种子,确保所有使用random模块产生的随机数序列是确定的;
  • 应用场景:适用于任何使用random模块的地方,例如数据增强、采样等。

os.environ["PYTHONHASHSEED"] = str(seed)

  • 作用:设置python的哈希种子,python字典和其他哈希表结构依赖于哈希函数,而哈希函数的行为在不同运行之间可能会不同,通过设置PYTHONHASHSEED环境变量,可使哈希结果在同一种子下保持一致;
  • 应用场景:确保字典键值对顺序的一致性,避免因哈希碰撞引起的非确定性行为。

np.random.seed(seed)

  • 作用:设置Numpy库的随机种子;
  • 应用场景:适用于所有使用Numpy生成随机数的地方,例如初始化权重、数据打乱等。

torch.manual_seed(seed)

  • 作用:为Pytorch设置全局种子,确保所有使用Pytorch身材的随机数(包括张量操作)都是确定性的;
  • 应用场景:适用于所有使用Pytorch生成随机数的地方。

 为了方便结合代码进行理解,在这里单独把后半部分的代码复制一下。

if torch.cuda.is_available():torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = Falsetorch.backends.cudnn.enabled = True

if torch.cuda.is_available():

条件判断,只有当CUDA可用时才进行以下设置,这确保了代码可以在没有GPU的环境中正常运行。

torch.cuda.manual_seed(seed)

  • 作用:为当前GPU设置随机种子,这确保了所有在当前GPU上生成的随机数都是确定的;
  • 应用场景:适用于单个GPU的随机数生成。

torch.cuda.manual_seed_all(seed)

  • 作用:为所有可用的GPU设置相同的随机种子,这确保了在多GPU环境中,每个GPU上随机生成的随机数都是一致的;
  • 应用场景:适用于多GPU环境中的随机数生成。

torch.backends.cudnn.deterministic = True

  • 作用:确保cuDNN使用确定性的算法,某些cuDNN算法是具有随机性的,启用此选项可以提高结果的可重复性,但是可能会降低性能;
  • 应用场景:适用于需要严格可重复性的实验。

torch.backends.cudnn.benchmark = False

  • 作用:禁用cuDNN的自动选择最佳卷积算法的功能,默认情况下cuDNN会在首次运行时尝试找到最适合硬件的算法,这可能会导致结果的不确定性,禁用此选项可以确保每次都是用相同的算法;
  • 应用场景:适用于需要严格可重复性的实验。

torch.backends.cudnn.enabled = True

  • 作用:启用cuDNN,虽然设置了deterministicbenchmark参数来控制cuDNN的行为,但仍然需要确保cuDNN是启用的;
  • 应用场景:所有需要使用GPU加速pytorch计算的场景。

多卡训练的情况

对于多卡训练的情况,设置随机种子的方式需要特别注意,以确保每个进程(或者称为“rank”)生成的随机数序列是不同的,同时还需要保证整个训练过程的可重复性。基于上述要求,随机种子可使用如下代码进行修改:

seed = args.seed + utils.get_rank()
  • args.seed:我们指定的seed的数值,通常在模型训练的配置文件中,或者通过命令行参数传入模型的训练脚本;
  • utils.get_rank():这是自定义的一个函数,位于自定义库utils中,具体作用是获得当前进程的全局序号,比如在进行分布式训练时,有2台机器(我们成之为2个节点),每台机器有8张GPU,则一共会有2*8=16个进程,每个进程都会有唯一的序号,从0~15。

因此我们看到,对于分布式训练场景,随机种子只需要确定一个全局种子(args.seed)在加上一个增量(utils.get_rank()),而这个增量对于每个进程来说是固定的。

下面是get_rank()函数的具体实现:

import torch.distributed as distdef is_dist_avail_and_initialized():# 判断当前环境中是否支持分布式训练if not dist.is_available():return False# 检查当前环境是否已经成功初始化了分布式训练环境if not dist.is_initialized():return Falsereturn Truedef get_rank():# 判断分布式训练是否可用且是否已成功初始化if not is_dist_avail_and_initialized():return 0return dist.get_rank()

相关文章:

Pytorch训练固定随机种子(单卡场景和分布式训练场景)

模型的训练是一个随机过程,固定随机种子可以帮助我们复现实验结果。 接下来介绍一个模型训练过程中固定随机种子的代码,并对每条语句的作用都会进行解释。 def seed_reproducer(seed2333):random.seed(seed)os.environ["PYTHONHASHSEED"] s…...

Conda + JuiceFS :增强 AI 开发环境共享能力

Conda 是当前 AI 应用开发领域中非常流行的环境和包管理系统,因其能够简单便捷地创建与系统资源相隔离的虚拟环境广受欢迎。 Conda 支持在不同的操作系统上重建相同的工作环境,但在环境共享复用方面仍存在一些挑战。比如,在不同机器上复用相…...

人工智能-人机交互的机会

目录 引言HCI领域的发展机会人工智能领域的崛起与机会博雅智信的HCI与AI辅导服务结语 引言 在人类科技不断进步的今天,HCI(人机交互)和人工智能(AI)是两个密切相关且充满潜力的领域。HCI研究如何优化人类与计算机之间…...

【系统架构核心服务设计】使用 Redis ZSET 实现排行榜服务

目录 一、排行榜的应用场景 二、排行榜技术的特点 三、使用Redis ZSET实现排行榜 3.1 引入依赖 3.2 配置Redis连接 3.3 创建实体类(可选) 3.4 编写 Redis 操作服务层 3.5 编写控制器层 3.6 测试 3.6.1 测试 addMovieScore 接口 3.6.2 测试 g…...

elasticsearch基础总结

最近实习,项目用的elasticseatch做的存储库,但是之前对于es接触的不多,查询语法有些不熟,每次想写个DSL查询时都要gpt或者施展搜索大法,所以索性就自己总结总结,以后忘了也方便查。所以这篇文章会持续更新。…...

【慕伏白教程】Zerotier 连接与简单配置

文章目录 下载与安装WindowsLinuxapt安装官方脚本安装 Zerotier 配置新建网络网络配置 终端配置WindowsLinux 下载与安装 Windows 进入Zerotier官方下载网站,点击下载 在下载目录找到安装文件,双击打开后点击 Install 开始安装 安装完成后,…...

Brain.js(九):LSTMTimeStep 实战教程 - 未来短期内的股市指数预测 - 实操要谨慎

系列的前一文RNNTimeStep 实战教程 - 股票价格预测 讲述了如何使用RNN时间序列预测实时的股价, 在这一节中,我们将深入学习如何利用 JavaScript 在浏览器环境下使用 LSTMTimeStep 进行股市指数的短期预测。通过本次实战教程,你将了解到如何用…...

C# 字符串(String)

文章目录 前言创建 String 对象的方式1. 通过给 String 变量指定一个字符串2. 通过使用 String 类构造函数3. 通过使用字符串串联运算符( )4. 通过检索属性或调用一个返回字符串的方法5. 通过格式化方法来转换一个值或对象为它的字符串表示形式 String …...

二进制文件

大多数人听到“二进制”的时候,脑海里可能马上就会联想到电影《黑客帝国》中由“0”和“1”组成的矩阵。 笔者不打算在这里详细讨论二进制的运算、反码、补码之类枯燥的东西,但有几个和开发相关的概念需要做一点澄清和普及。因为这些内容就像空气——用…...

【电子元器件】音频功放种类

本文章是笔者整理的备忘笔记。希望在帮助自己温习避免遗忘的同时,也能帮助其他需要参考的朋友。如有谬误,欢迎大家进行指正。 一、概述 音频功放将小信号的幅值提高至有用电平,同时保留小信号的细节,这称为线性度。放大器的线性…...

linux之vim

一、模式转换命令 vim主要有三种模式:命令模式(Normal Mode)、输入模式(Insert Mode)和底线命令模式(Command-Line Mode)。 从命令模式切换到输入模式:i:在当前光标所在…...

QT的ui界面显示不全问题(适应高分辨率屏幕)

//自动适应高分辨率 QCoreApplication::setAttribute(Qt::AA_EnableHighDpiScaling);一、问题 电脑分辨率高,默认情况下,打开QT的ui界面,显示不全按钮内容 二、解决方案 如果自己的电脑分辨率较高,可以尝试以下方案:自…...

数据结构--串、数组和广义表

串 定义:串(String)是由零个或多个字符组成的有限序列。 子串:串中任意个连续字符组成的子序列称为该串的子串。 主串:包含子串的串相应地称为主串。 字符位置:字符在该序列中的序号为该字符在串中的位置…...

LLMs之Agent之Lares:Lares的简介、安装和使用方法、案例应用之详细攻略

LLMs之Agent之Lares:Lares的简介、安装和使用方法、案例应用之详细攻略 导读:这篇博文介绍了 Lares,一个由简单的 AI 代理驱动的智能家居助手模拟器,它展现出令人惊讶的解决问题能力。 >> 背景痛点:每天都有新的…...

1-1.mysql2 之 mysql2 初识(mysql2 初识案例、初识案例挖掘)

一、mysql2 概述 mysql2 是一个用于 Node.js 的 MySQL 客户端库 mysql2 是 mysql 库的一个改进版本,提供了更好的性能和更多的功能 使用 mysql2 之前,需要先安装它 npm install mysql2 二、mysql2 初识案例 1、数据库准备 创建数据库 testdb CREAT…...

企业邮箱为什么不能经常群发邮件?

企业邮箱是用企业域名作为后缀的邮箱,虽然企业邮箱确实具备群发邮件的功能,但它更适用于企业内部的群发,而非用于外部推广。如果是在企业邮件域内进行群发,通常可以借助企业邮箱的邮件列表来实现。然而,对于域外的大量…...

集成运算放大电路反馈判断

集成运算放大电路 一种具有很高放大倍数的多级直接耦合放大电路,因最初用于信号运算而得名,简称集成运放或运放 模拟集成电路中的典型组件,是发展最快、品种最多、应用最广的一种 反馈 将放大电路输出信号的一部分或全部通过某种电路引回到输…...

媒体查询、浏览器一帧渲染过程

文章目录 媒体查询语法示例根据视口宽度应用不同的样式根据设备像素比应用不同的样式根据方向应用不同的样式 使用场景 浏览器一帧的渲染过程 媒体查询 媒体查询(Media Query)是CSS3中的一个重要特性,它允许开发者根据设备的特定条件&#x…...

高级排序算法(一):快速排序详解

引言 当我们处理大规模数据时,像冒泡排序、选择排序这样的基础排序算法就有点力不从心了。这时候,快速排序(Quick Sort)就派上用场了。 作为一种基于分治法的高效排序算法,快速排序在大多数情况下可以在O(n log n)的时…...

3.2 网络协议IP

欢迎大家订阅【计算机网络】学习专栏,开启你的计算机网络学习之旅! 文章目录 1 定义2 虚拟互连网络3 分组在互联网中的传送4 IPv4 地址 1 定义 网际协议 IP是 TCP/IP 体系中两个最主要的协议之一,也是最重要的互连网协议之一。IPv4 和 IPv6 …...

观成科技:隐蔽隧道工具Ligolo-ng加密流量分析

1.工具介绍 Ligolo-ng是一款由go编写的高效隧道工具,该工具基于TUN接口实现其功能,利用反向TCP/TLS连接建立一条隐蔽的通信信道,支持使用Let’s Encrypt自动生成证书。Ligolo-ng的通信隐蔽性体现在其支持多种连接方式,适应复杂网…...

Flask RESTful 示例

目录 1. 环境准备2. 安装依赖3. 修改main.py4. 运行应用5. API使用示例获取所有任务获取单个任务创建新任务更新任务删除任务 中文乱码问题: 下面创建一个简单的Flask RESTful API示例。首先,我们需要创建环境,安装必要的依赖,然后…...

脑机新手指南(八):OpenBCI_GUI:从环境搭建到数据可视化(下)

一、数据处理与分析实战 (一)实时滤波与参数调整 基础滤波操作 60Hz 工频滤波:勾选界面右侧 “60Hz” 复选框,可有效抑制电网干扰(适用于北美地区,欧洲用户可调整为 50Hz)。 平滑处理&…...

边缘计算医疗风险自查APP开发方案

核心目标:在便携设备(智能手表/家用检测仪)部署轻量化疾病预测模型,实现低延迟、隐私安全的实时健康风险评估。 一、技术架构设计 #mermaid-svg-iuNaeeLK2YoFKfao {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg…...

macOS多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用

文章目录 问题现象问题原因解决办法 问题现象 macOS启动台(Launchpad)多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用。 问题原因 很明显,都是Google家的办公全家桶。这些应用并不是通过独立安装的…...

土地利用/土地覆盖遥感解译与基于CLUE模型未来变化情景预测;从基础到高级,涵盖ArcGIS数据处理、ENVI遥感解译与CLUE模型情景模拟等

🔍 土地利用/土地覆盖数据是生态、环境和气象等诸多领域模型的关键输入参数。通过遥感影像解译技术,可以精准获取历史或当前任何一个区域的土地利用/土地覆盖情况。这些数据不仅能够用于评估区域生态环境的变化趋势,还能有效评价重大生态工程…...

工业自动化时代的精准装配革新:迁移科技3D视觉系统如何重塑机器人定位装配

AI3D视觉的工业赋能者 迁移科技成立于2017年,作为行业领先的3D工业相机及视觉系统供应商,累计完成数亿元融资。其核心技术覆盖硬件设计、算法优化及软件集成,通过稳定、易用、高回报的AI3D视觉系统,为汽车、新能源、金属制造等行…...

IT供电系统绝缘监测及故障定位解决方案

随着新能源的快速发展,光伏电站、储能系统及充电设备已广泛应用于现代能源网络。在光伏领域,IT供电系统凭借其持续供电性好、安全性高等优势成为光伏首选,但在长期运行中,例如老化、潮湿、隐裂、机械损伤等问题会影响光伏板绝缘层…...

深度学习习题2

1.如果增加神经网络的宽度,精确度会增加到一个特定阈值后,便开始降低。造成这一现象的可能原因是什么? A、即使增加卷积核的数量,只有少部分的核会被用作预测 B、当卷积核数量增加时,神经网络的预测能力会降低 C、当卷…...

2025季度云服务器排行榜

在全球云服务器市场,各厂商的排名和地位并非一成不变,而是由其独特的优势、战略布局和市场适应性共同决定的。以下是根据2025年市场趋势,对主要云服务器厂商在排行榜中占据重要位置的原因和优势进行深度分析: 一、全球“三巨头”…...