当前位置: 首页 > news >正文

Pytorch指定数据加载器使用子进程

torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,num_workers=4, pin_memory=True)

num_workers 参数是 DataLoader 类的一个参数,它指定了数据加载器使用的子进程数量。通过增加 num_workers 的数量,可以并行地读取和预处理数据,从而提高数据加载的速度。

通常情况下,增加 num_workers 的数量可以提高数据加载的效率,因为它可以使数据加载和预处理工作在多个进程中同时进行。然而,当 num_workers 的数量超过一定阈值时,增加更多的进程可能不会再带来更多的性能提升,甚至可能会导致性能下降。

这是因为增加 num_workers 的数量也会增加进程间通信的开销。当 num_workers 的数量过多时,进程间通信的开销可能会超过并行化所带来的收益,从而导致性能下降。

此外,还需要考虑到计算机硬件的限制。如果你的计算机 CPU 核心数量有限,增加 num_workers 的数量也可能会导致性能下降,因为每个进程需要占用 CPU 核心资源。

因此,对于 num_workers 参数的设置,需要根据具体情况进行调整和优化。通常情况下,一个合理的 num_workers 值应该在 2 到 8 之间,具体取决于你的计算机硬件配置和数据集大小等因素。在实际应用中,可以通过尝试不同的 num_workers 值来找到最优的配置。

综上所述,当 num_workers 的值从 4 增加到 8 时,如果你的计算机硬件配置和数据集大小等因素没有发生变化,那么两者之间的性能差异可能会很小,或者甚至没有显著差异。

测试代码如下

import torch
import torchvision
import matplotlib.pyplot as plt
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import timeif __name__ == '__main__':mp.freeze_support()train_on_gpu = torch.cuda.is_available()if not train_on_gpu:print('CUDA is not available. Training on CPU...')else:print('CUDA is available! Training on GPU...')device = torch.device("cuda" if torch.cuda.is_available() else "cpu")batch_size = 4# 设置数据预处理的转换transform = torchvision.transforms.Compose([torchvision.transforms.Resize((512,512)),  # 调整图像大小为 224x224torchvision.transforms.ToTensor(),  # 转换为张量torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 归一化])dataset = torchvision.datasets.ImageFolder('C:\\Users\\ASUS\\PycharmProjects\\pythonProject1\\cats_and_dogs_train',transform=transform)val_ratio = 0.2val_size = int(len(dataset) * val_ratio)train_size = len(dataset) - val_sizetrain_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])train_dataset = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,num_workers=4, pin_memory=True)val_dataset = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True,num_workers=4, pin_memory=True)model = models.resnet18()num_classes = 2for param in model.parameters():param.requires_grad = Falsemodel.fc = nn.Sequential(nn.Dropout(),nn.Linear(model.fc.in_features, num_classes),nn.LogSoftmax(dim=1))optimizer = optim.Adam(model.parameters(), lr=0.001)criterion = nn.CrossEntropyLoss().to(device)model.to(device)filename = "recognize_cats_and_dogs.pt"def save_checkpoint(epoch, model, optimizer, filename):checkpoint = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,}torch.save(checkpoint, filename)num_epochs = 3train_loss = []for epoch in range(num_epochs):running_loss = 0correct = 0total = 0epoch_start_time = time.time()for i, (inputs, labels) in enumerate(train_dataset):# 将数据放到设备上inputs, labels = inputs.to(device), labels.to(device)# 前向计算outputs = model(inputs)# 计算损失和梯度loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()# 更新模型参数optimizer.step()# 记录损失和准确率running_loss += loss.item()train_loss.append(loss.item())_, predicted = torch.max(outputs.data, 1)correct += (predicted == labels).sum().item()total += labels.size(0)accuracy_train = 100 * correct / total# 在测试集上计算准确率with torch.no_grad():running_loss_test = 0correct_test = 0total_test = 0for inputs, labels in val_dataset:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)loss = criterion(outputs, labels)running_loss_test += loss.item()_, predicted = torch.max(outputs.data, 1)correct_test += (predicted == labels).sum().item()total_test += labels.size(0)accuracy_test = 100 * correct_test / total_test# 输出每个 epoch 的损失和准确率epoch_end_time = time.time()epoch_time = epoch_end_time - epoch_start_timeprint("Epoch [{}/{}], Time: {:.4f}s, Loss: {:.4f}, Train Accuracy: {:.2f}%, Loss: {:.4f}, Test Accuracy: {:.2f}%".format(epoch + 1, num_epochs,epoch_time,running_loss / len(val_dataset),accuracy_train, running_loss_test / len(val_dataset), accuracy_test))save_checkpoint(epoch, model, optimizer, filename)plt.plot(train_loss, label='Train Loss')# 添加图例和标签plt.legend()plt.xlabel('Epochs')plt.ylabel('Loss')plt.title('Training Loss')# 显示图形plt.show()

不同num_workers的结果如下

相关文章:

Pytorch指定数据加载器使用子进程

torch.utils.data.DataLoader(train_dataset, batch_sizebatch_size, shuffleTrue,num_workers4, pin_memoryTrue) num_workers 参数是 DataLoader 类的一个参数,它指定了数据加载器使用的子进程数量。通过增加 num_workers 的数量,可以并行地读取和预处…...

【科普】干货!带你从0了解移动机器人(六) (底盘结构类型)

牵引式移动机器人(AGV/AMR),通常由一个牵引车和一个或多个被牵引的车辆组成。牵引车是机器人的核心部分,它具有自主导航和定位功能,可以根据预先设定的路径或地标进行导航,并通过传感器和视觉系统感知周围环…...

爆肝整理,Pytest+Allure+Jenkins自动化测试集成实战(图文详细步骤)

目录:导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜) 前言 1、简介 pytesta…...

微信批量添加好友,让你的人脉迅速增长

在这个数字化时代,微信作为中国最流行的社交平台之一,已经成为了人们生活中不可或缺的一部分。它的广泛使用为我们提供了无限的社交可能性。你是否曾为了扩大人脉圈子而犯愁?今天,我将向你揭示一个高效添加微信好友的秘密武器&…...

3D模型怎么贴法线贴图?

1、法线贴图的原理? 法线贴图(normal mapping)是一种计算机图形技术,用于在低多边形模型上模拟高多边形模型的细节效果。它通过在纹理坐标上存储和应用法线向量的信息来实现。 法线贴图的原理基于光照模型。在渲染过程中&#x…...

QT中文乱码解决方案与乱码的原因

相信大家应该都遇到过中文乱码的问题,有时候改一改中文就不乱码了,但是有时候用同样的方式还是乱码,那么这个乱码到底是什么原因,又该如何彻底解决呢? 总结 先总结一下: Qt5中,将QString()的构…...

sam9x60 boot

...

3D模型格式转换工具HOOPS Exchange:支持国际标准STEP格式!

HOOPS Exchange SDK是一组C软件库,使开发团队能够快速将可靠的2D和3D CAD导入和导出添加到其应用程序中,访问广泛的数据,包括边界表示 (B-REP)、产品制造信息 (PMI)、模型树、视图、持久 ID、样式、构造几何、可视化等,无需依赖任…...

java--死循环与循环嵌套

1.死循环 可以一直执行下去的一种循环,如果没有干预不会停下来的 2.死循环的写法 3.循环嵌套 循环中又包含循环 4.循环嵌套的特点 外部循环每循环一次,内部循环会全部执行完一轮...

基于机器视觉的图像拼接算法 计算机竞赛

前言 图像拼接在实际的应用场景很广,比如无人机航拍,遥感图像等等,图像拼接是进一步做图像理解基础步骤,拼接效果的好坏直接影响接下来的工作,所以一个好的图像拼接算法非常重要。 再举一个身边的例子吧,…...

基于arduino uno + L298 的直流电机驱动proteus仿真设计

一、L298简介: L298是一个集成的单片电路,采用15个导线多瓦和PowerSO20封装。它是一个高电压、高电流双全桥驱动器,旨在接受标准TTL逻辑电平和驱动感应负载,如继电器、螺线管、直流和加速电机。提供两个使输入来使独立于输入信号的…...

cola架构:有限状态机(FSM)源码分析

目录 0. cola状态机简述 1.cola状态机使用实例 2.cola状态机源码解析 2.1 语义模型源码 2.1.1 Condition和Action接口 2.1.2 State 2.1.3 Transition接口 2.1.4 StateMachine接口 2.2 Builder模式 2.2.1 StateMachine Builder模式 2.2.2 ExternalTransitionBuilder-…...

通信仿真软件SystemView安装教程(超详细)

介绍 system view是一种电子仿真工具。它是一个信号级的系统仿真软件,主要用于电路与通信系统的设计和仿真,是一个强有力的动态系统分析工具,能满足从数字信号处理,滤波器设计,直到复杂的通信系统等不同层次的设计&am…...

Go学习第八章——面向“对象”编程(入门——结构体与方法)

Go面向“对象”编程(入门——结构体与方法) 1 结构体1.1 快速入门1.2 内存解析1.3 创建结构体四种方法1.4 注意事项和使用细节 2 方法2.1 方法的声明和调用2.2 快速入门案例2.3 调用机制和传参原理2.4 注意事项和细节2.5 方法和函数区别 3 工厂模式 Gola…...

「滚雪球学Java」:方法函数(章节汇总)

🏆本文收录于「滚雪球学Java」专栏,专业攻坚指数级提升,助你一臂之力,带你早日登顶🚀,欢迎大家关注&&收藏!持续更新中,up!up!up!&#xf…...

数据分析必备原理思路(二)

文章目录 三、主流的数据分析方法与框架使用1. 五个数据分析领域关键的理论基础(1)大数定律(2)罗卡定律(3)幸存者偏差(4)辛普森悖论(5)帕累托最优&#xff08…...

分布式ID系统设计(1)

分布式ID系统设计(1) 在分布式服务中,需要对data和message进行唯一标识。 比如订单、支付等。然后在数据库分库分表之后也需要一个唯一id来表示。 基于DB的自增就肯定不能满足了。这个时候能够生成一个Global的唯一ID的服务就很有必要我们姑且把它叫做id-server 。…...

机器学习(python)笔记整理

目录 一、数据预处理: 1. 缺失值处理: 2. 重复值处理: 3. 数据类型: 二、特征工程: 1. 规范化: 2. 归一化: 3. 标准化(方差): 三、训练模型: 如何计算精确度,召…...

微客云霸王餐系统 1.0 : 全面孵化+高额返佣

1、业务简介。业务模式是消费者以5-10元吃到原价15-25元的外卖,底层逻辑是帮外卖商家做推广,解决新店基础销量、老店增加单量、品牌打万单店的需求。 因为外卖店的平均生命周期只有6个月,不断有新店愿意送霸王餐。部分老店也愿意做活动&…...

极智开发 | Hello world for Manim

欢迎关注我的公众号 [极智视界],获取我的更多经验分享 大家好,我是极智视界,本文分享一下 Hello world for Manim。 邀您加入我的知识星球「极智视界」,星球内有超多好玩的项目实战源码和资源下载,链接:https://t.zsxq.com/0aiNxERDq Manim 是什么呢?Manim 是一个用于创…...

springboot 百货中心供应链管理系统小程序

一、前言 随着我国经济迅速发展,人们对手机的需求越来越大,各种手机软件也都在被广泛应用,但是对于手机进行数据信息管理,对于手机的各种软件也是备受用户的喜爱,百货中心供应链管理系统被用户普遍使用,为方…...

大型活动交通拥堵治理的视觉算法应用

大型活动下智慧交通的视觉分析应用 一、背景与挑战 大型活动(如演唱会、马拉松赛事、高考中考等)期间,城市交通面临瞬时人流车流激增、传统摄像头模糊、交通拥堵识别滞后等问题。以演唱会为例,暖城商圈曾因观众集中离场导致周边…...

【单片机期末】单片机系统设计

主要内容:系统状态机,系统时基,系统需求分析,系统构建,系统状态流图 一、题目要求 二、绘制系统状态流图 题目:根据上述描述绘制系统状态流图,注明状态转移条件及方向。 三、利用定时器产生时…...

工业自动化时代的精准装配革新:迁移科技3D视觉系统如何重塑机器人定位装配

AI3D视觉的工业赋能者 迁移科技成立于2017年,作为行业领先的3D工业相机及视觉系统供应商,累计完成数亿元融资。其核心技术覆盖硬件设计、算法优化及软件集成,通过稳定、易用、高回报的AI3D视觉系统,为汽车、新能源、金属制造等行…...

【OSG学习笔记】Day 16: 骨骼动画与蒙皮(osgAnimation)

骨骼动画基础 骨骼动画是 3D 计算机图形中常用的技术,它通过以下两个主要组件实现角色动画。 骨骼系统 (Skeleton):由层级结构的骨头组成,类似于人体骨骼蒙皮 (Mesh Skinning):将模型网格顶点绑定到骨骼上,使骨骼移动…...

用docker来安装部署freeswitch记录

今天刚才测试一个callcenter的项目,所以尝试安装freeswitch 1、使用轩辕镜像 - 中国开发者首选的专业 Docker 镜像加速服务平台 编辑下面/etc/docker/daemon.json文件为 {"registry-mirrors": ["https://docker.xuanyuan.me"] }同时可以进入轩…...

项目部署到Linux上时遇到的错误(Redis,MySQL,无法正确连接,地址占用问题)

Redis无法正确连接 在运行jar包时出现了这样的错误 查询得知问题核心在于Redis连接失败,具体原因是客户端发送了密码认证请求,但Redis服务器未设置密码 1.为Redis设置密码(匹配客户端配置) 步骤: 1).修…...

均衡后的SNRSINR

本文主要摘自参考文献中的前两篇,相关文献中经常会出现MIMO检测后的SINR不过一直没有找到相关数学推到过程,其中文献[1]中给出了相关原理在此仅做记录。 1. 系统模型 复信道模型 n t n_t nt​ 根发送天线, n r n_r nr​ 根接收天线的 MIMO 系…...

【C++进阶篇】智能指针

C内存管理终极指南:智能指针从入门到源码剖析 一. 智能指针1.1 auto_ptr1.2 unique_ptr1.3 shared_ptr1.4 make_shared 二. 原理三. shared_ptr循环引用问题三. 线程安全问题四. 内存泄漏4.1 什么是内存泄漏4.2 危害4.3 避免内存泄漏 五. 最后 一. 智能指针 智能指…...

为什么要创建 Vue 实例

核心原因:Vue 需要一个「控制中心」来驱动整个应用 你可以把 Vue 实例想象成你应用的**「大脑」或「引擎」。它负责协调模板、数据、逻辑和行为,将它们变成一个活的、可交互的应用**。没有这个实例,你的代码只是一堆静态的 HTML、JavaScript 变量和函数,无法「活」起来。 …...