当前位置: 首页 > news >正文

PyTorch搭建LeNet训练集详细实现

一、下载训练集

导包

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

 ToTensor()函数:

把图像[heigh x width x channels] 转换为 [channels x height x width]

Normalize() 数据标准化函数:

最后一行是标准化数值计算公式

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 50000张训练图片
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)

参数解释: 

root='./data':数据集下载的路径,我下载到当前目录下的data文件夹,下载完成后会自动创建 

train=True:当前为训练集

download=True:下载数据集时设置为True,下载完成后改为False

transform=transform :设置对图像进行预处理的函数

运行下载数据集结果为: 

 下载完成后生成了data文件夹

二、导入训练集 

# 导入训练集
trainloader = torch.utils.data.DataLoader(trainset, batch_size=36,shuffle=True, num_workers=0)

参数解释: 

        trainset:把刚刚下载的数据导入进来

        batch_size=36:一批数据的大小

        shuffle=True:训练集中的数据是否打乱(一般默认打乱)

        num_workers=0:载入数据的现成数,在lunix操作系统下,可以设置为别的参数,在windows操作系统系统下,默认为0.

三、下载测试集

# 10000张测试图片
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=10000,shuffle=False, num_workers=0)
test_data_iter = iter(testloader)
test_image, test_lable = test_data_iter.next()classes = ('plane', 'car', 'bird', 'cat',   # 数据集中的分类,设置为元组,不可变类'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

参数解释:

test_data_iter = iter(testloader):通过iter()函数把testloader转化成可迭代的迭代器
test_image, test_lable = test_data_iter.next():通过next()方法可以获得测试的图像和图像对应的标签值。

 四、查看导入的图片

在中间过程打印图片进行查看,后续会注释掉

def imshow(img):img = img / 2 + 0.5nping = img.numpy()plt.imshow(np.transpose(nping, (1, 2, 0)))plt.show()# print labels
print(' '.join('%5s' % classes[test_lable[j]] for j in range(4)))
# show images
imshow(torchvision.utils.make_grid(test_image))

运行结果

 图片很模糊,因为像素很低。

上面识别出来的结果都对了。

 我遇到的问题:
一开始有结果但是没有图片,我以为时matplotlib的问题,我重新安装并且更新了版本,但是我再运行后报错更多了,报错提示我 AttributeError: module 'numpy' has no attribute 'bool',我就知道是numpy的问题了,我重新安装并且更新了版本结果还是不行,我百度了一下,发现不是越新的版本越好,我重新下载了1.23.2这个版本的numpy,下载完成后运行就出来结果了。

pip install numpy==1.23.2

这个也只是中间过程,后续会注释或者删了。


五、将创建的模型实例化

创建模型请看PyTorch搭建LeNet神经网络-CSDN博客

# 将创建的模型实例化
net = LeNet()  # 实例化
loss_fuction = nn.CrossEntropyLoss()  # 定义损失函数# 通过优化器将所有可训练的参数都进行训练,lr是learningrate学习率
optimizer = optim.Adam(net.parameters(), lr=0.001)#通过for循环实现训练过程,循环几次就是将训练集迭代多少次
for epoch in range(5):running_loss = 0.0  # 用来累加在学习过程中的损失for step, data in enumerate(trainloader, start=0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()   # 历时损失梯度清零。# forward + backward + optimizeoutputs = net(inputs)loss = loss_fuction(outputs, labels)  # 计算神经网络的预测值和真实标签之间的损失loss.backward()optimizer.step()  # step()函数实现参数更新# print statistics  打印数据的过程running_loss += loss.item()if step % 500 == 499:  # 每隔500步打印一次数据的信息with torch.no_grad():  # 上下文管理器outputs = net(test_image)predict_y = torch.max(outputs, dim=1)[1]accuracy = (predict_y == test_lable).sum().item() / test_lable.size(0)print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0print('Finished Training')# 将模型保存到文件夹中
save_path = './Lenet.pth'
torch.save(net.state_dict(), save_path)

 详细解释:

比较重点的单独解释了,其他的在注释中。

 optimizer.zero_grad()   # 历时损失梯度清零。

 ? 为什么每计算一个batch,就要调用一次 optimizer.zero_grad()函数

=> 通过清楚历史梯度,就会对计算的历史梯度进行累加。通过这个特性,能变相的实现一个很大的batch数值的训练(因为batch数值越大,训练效果越好)

 with torch.no_grad():  # 上下文管理器

 上下文管理器: 在接下来的计算过程中,不再去计算每个节点的误差损失梯度。

如果不调用这个函数,将会在测试过程中占用更多的算力,消耗更多的资源和占用更多的内存资源,导致内存容易崩。

print函数中打印参数解释:

print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))

epoch + 1:迭代到第几轮了

step + 1:某一轮的第几步

running_loss / 500:训练过程中500步平均训练误差

accuracy:准确率

运行结果

相关文章:

PyTorch搭建LeNet训练集详细实现

一、下载训练集 导包 import torch import torchvision import torch.nn as nn from model import LeNet import torch.optim as optim import torchvision.transforms as transforms import matplotlib.pyplot as plt import numpy as npToTensor()函数: 把图像…...

R语言复现:中国Charls数据库一篇现况调查论文的缺失数据填补方法

编者 在临床研究中,数据缺失是不可避免的,甚至没有缺失,数据的真实性都会受到质疑。 那我们该如何应对缺失的数据?放着不管?还是重新开始?不妨试着对缺失值进行填补,简单又高效。毕竟对于统计师来说&#…...

解决Git:Author identity unknown Please tell me who you are.

报错信息: 意思: 作者身份未知 ***请告诉我你是谁。 解决办法: git config --global user.name "你的名字"git config --global user.email "你的邮箱"...

Flink StreamTask启动和执行源码分析

文章目录 前言StreamTask 部署启动Task 线程启动StreamTask 初始化StreamTask 执行 前言 Flink的StreamTask的启动和执行是一个复杂的过程,涉及多个关键步骤。以下是StreamTask启动和执行的主要流程: 初始化:StreamTask的初始化阶段涉及多个…...

【MySQL 系列】MySQL 语句篇_DCL 语句

DCL( Data Control Language,数据控制语言)用于对数据访问权限进行控制,定义数据库、表、字段、用户的访问权限和安全级别。主要关键字包括 GRANT、 REVOKE 等。 文章目录 1、MySQL 中的 DCL 语句1.1、数据控制语言--DCL1.2、MySQ…...

什么是序列化?为什么需要序列化?

1、典型回答 序列化(Serialization)序列化是将对象转换为可存储或传输的形式的过程(例如: 将对象转换为字节流) 反序列化(Deserialization) 是将序列化后的数据(例如: 二进制文件)转换回原始对象的过程。通过反序列化,可以从存储介质 (如磁盘、数据库) 或通过网络…...

Linux本地搭建FastDFS系统

文章目录 前言1. 本地搭建FastDFS文件系统1.1 环境安装1.2 安装libfastcommon1.3 安装FastDFS1.4 配置Tracker1.5 配置Storage1.6 测试上传下载1.7 与Nginx整合1.8 安装Nginx1.9 配置Nginx 2. 局域网测试访问FastDFS3. 安装cpolar内网穿透4. 配置公网访问地址5. 固定公网地址5.…...

docker和docker-compose安装

一、docker安装 1、移除旧版本 依次执行如下命令移除旧版本docker,如未安装过无需执行 yum -y remove docker docker-client docker-client-latest docker-common docker-latest docker-latest-logrotate docker-logrotate docker-selinux docker-engine-selinux…...

深入理解Spring的ApplicationContext:案例详解与应用

深入理解Spring的ApplicationContext:案例详解与应用 在Spring框架的丰富生态中,ApplicationContext扮演着至关重要的角色。作为BeanFactory的扩展,ApplicationContext不仅继承了其所有功能,还引入了更多高级特性,使得…...

6.Java并发编程—深入剖析Java Executors:探索创建线程的5种神奇方式

Executors快速创建线程池的方法 Java通过Executors 工厂提供了5种创建线程池的方法,具体方法如下 方法名描述newSingleThreadExecutor()创建一个单线程的线程池,该线程池中只有一个工作线程。所有任务按照提交的顺序依次执行,保证任务的顺序性…...

英语阅读挑战

英语阅读真是令人头痛的东西。可怜的子航想利用寒假时间突破英语难题。当他拿到一篇英语阅读时,他很好奇作者最喜欢用那些字母。 输入 一句30词以内的英语句子 输出 统计每个字母出现的次数 样例输入 复制 However,the British dont have a history of exporting th…...

备战蓝桥之思维

平台重叠真的坑 给你一句样例,如果你觉得自己的代码没问题那就试试吧 2 1 1 3 1 0 4 正确答案 0 0 0 0 P1105 平台 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) import java.awt.Checkbox; import java.awt.PageAttributes.OriginType; import java.io.B…...

09 string的实现

注意 实现仿cplus官网的的string类&#xff0c;对部分主要功能实现 实现 头文件 #pragma once #include <iostream> #include <assert.h> #include <string>namespace mystring {class string{friend std::ostream& operator<<(std::ostream&a…...

Git 进行版本控制时,配置 user.name 和 user.email

在使用 Git 进行版本控制时&#xff0c;配置 user.name 和 user.email 是一个非常重要的初始步骤&#xff0c;但不是绝对必须的。这两个配置项定义了当你进行提交&#xff08;commit&#xff09;时用于标识提交者的信息。 为什么建议配置 user.name 和 user.email 标识提交者…...

传统开发读写优化与HBase

目录: 一、传统开发数据读写性能优化 1. Mysql 分表、主从复制与读写分离 2. Redis(缓存型数据库)主从复制与读写分离 二、HBase 一、传统开发数据读写性能优化 1、Mysql 分表、主从复制与读写分离 mysql分库分表方案 一种分表方案&#xff1a;设置表A 表B 表A 自增列从1开始…...

【OpenGL实现 03】纹理贴图原理和实现

目录 一、说明二、纹理贴图原理2.1 纹理融合原理2.2 UV坐标原理 三、生成纹理对象3.1 需要在VAO上绑定纹理坐标3.2 纹理传递3.3 纹理buffer生成 四、代码实现&#xff1a;五、着色器4.1 片段4.2 顶点 五、后记 一、说明 本篇叙述在画出图元的时候&#xff0c;如何贴图纹理图片…...

FDU 2021 | 二叉树关键节点的个数

文章目录 1. 题目描述2. 我的尝试 1. 题目描述 给定一颗二叉树&#xff0c;树的每个节点的值为一个正整数。如果从根节点到节点 N 的路径上不存在比节点 N 的值大的节点&#xff0c;那么节点 N 被认为是树上的关键节点。求树上所有的关键节点的个数。请写出程序&#xff0c;并…...

精读《React Conf 2019 - Day2》

1 引言 这是继 精读《React Conf 2019 - Day1》 之后的第二篇&#xff0c;补充了 React Conf 2019 第二天的内容。 2 概述 & 精读 第二天的内容更为精彩&#xff0c;笔者会重点介绍比较干货的部分。 Fast refresh Fast refresh 是更好的 react-hot-loader 替代方案&am…...

向ChatGPT高效提问模板

PS: ChatGPT无限次数&#xff0c;无需魔法&#xff0c;登录即可使用,网页打开下面 tj4.mnsfdx.net [点击跳转链接](http://tj4.mnsfdx.net/) 我想请你XXXX&#xff0c;请问我应该如何向你提问才能得到最满意的答案&#xff0c;请提供全面、详细的建议&#xff0c;针对每一个建…...

android metaRTC编译

参考文章&#xff1a; metaRTC3.0稳定版本编译指南_metartc 编译-CSDN博客 源码下载&#xff1a; Releases metartc/metaRTC GitHub 版本v6.0-b4即可...

深度学习在微纳光子学中的应用

深度学习在微纳光子学中的主要应用方向 深度学习与微纳光子学的结合主要集中在以下几个方向&#xff1a; 逆向设计 通过神经网络快速预测微纳结构的光学响应&#xff0c;替代传统耗时的数值模拟方法。例如设计超表面、光子晶体等结构。 特征提取与优化 从复杂的光学数据中自…...

Chapter03-Authentication vulnerabilities

文章目录 1. 身份验证简介1.1 What is authentication1.2 difference between authentication and authorization1.3 身份验证机制失效的原因1.4 身份验证机制失效的影响 2. 基于登录功能的漏洞2.1 密码爆破2.2 用户名枚举2.3 有缺陷的暴力破解防护2.3.1 如果用户登录尝试失败次…...

ssc377d修改flash分区大小

1、flash的分区默认分配16M、 / # df -h Filesystem Size Used Available Use% Mounted on /dev/root 1.9M 1.9M 0 100% / /dev/mtdblock4 3.0M...

理解 MCP 工作流:使用 Ollama 和 LangChain 构建本地 MCP 客户端

&#x1f31f; 什么是 MCP&#xff1f; 模型控制协议 (MCP) 是一种创新的协议&#xff0c;旨在无缝连接 AI 模型与应用程序。 MCP 是一个开源协议&#xff0c;它标准化了我们的 LLM 应用程序连接所需工具和数据源并与之协作的方式。 可以把它想象成你的 AI 模型 和想要使用它…...

基础测试工具使用经验

背景 vtune&#xff0c;perf, nsight system等基础测试工具&#xff0c;都是用过的&#xff0c;但是没有记录&#xff0c;都逐渐忘了。所以写这篇博客总结记录一下&#xff0c;只要以后发现新的用法&#xff0c;就记得来编辑补充一下 perf 比较基础的用法&#xff1a; 先改这…...

Python爬虫(二):爬虫完整流程

爬虫完整流程详解&#xff08;7大核心步骤实战技巧&#xff09; 一、爬虫完整工作流程 以下是爬虫开发的完整流程&#xff0c;我将结合具体技术点和实战经验展开说明&#xff1a; 1. 目标分析与前期准备 网站技术分析&#xff1a; 使用浏览器开发者工具&#xff08;F12&…...

Python如何给视频添加音频和字幕

在Python中&#xff0c;给视频添加音频和字幕可以使用电影文件处理库MoviePy和字幕处理库Subtitles。下面将详细介绍如何使用这些库来实现视频的音频和字幕添加&#xff0c;包括必要的代码示例和详细解释。 环境准备 在开始之前&#xff0c;需要安装以下Python库&#xff1a;…...

《基于Apache Flink的流处理》笔记

思维导图 1-3 章 4-7章 8-11 章 参考资料 源码&#xff1a; https://github.com/streaming-with-flink 博客 https://flink.apache.org/bloghttps://www.ververica.com/blog 聚会及会议 https://flink-forward.orghttps://www.meetup.com/topics/apache-flink https://n…...

【Redis】笔记|第8节|大厂高并发缓存架构实战与优化

缓存架构 代码结构 代码详情 功能点&#xff1a; 多级缓存&#xff0c;先查本地缓存&#xff0c;再查Redis&#xff0c;最后才查数据库热点数据重建逻辑使用分布式锁&#xff0c;二次查询更新缓存采用读写锁提升性能采用Redis的发布订阅机制通知所有实例更新本地缓存适用读多…...

GitHub 趋势日报 (2025年06月06日)

&#x1f4ca; 由 TrendForge 系统生成 | &#x1f310; https://trendforge.devlive.org/ &#x1f310; 本日报中的项目描述已自动翻译为中文 &#x1f4c8; 今日获星趋势图 今日获星趋势图 590 cognee 551 onlook 399 project-based-learning 348 build-your-own-x 320 ne…...