PyTorch搭建LeNet训练集详细实现
一、下载训练集
导包
import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
ToTensor()函数:
把图像[heigh x width x channels] 转换为 [channels x height x width]
Normalize() 数据标准化函数:
最后一行是标准化数值计算公式
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 50000张训练图片
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
参数解释:
root='./data':数据集下载的路径,我下载到当前目录下的data文件夹,下载完成后会自动创建
train=True:当前为训练集
download=True:下载数据集时设置为True,下载完成后改为False
transform=transform :设置对图像进行预处理的函数
运行下载数据集结果为:
下载完成后生成了data文件夹
二、导入训练集
# 导入训练集
trainloader = torch.utils.data.DataLoader(trainset, batch_size=36,shuffle=True, num_workers=0)
参数解释:
trainset:把刚刚下载的数据导入进来
batch_size=36:一批数据的大小
shuffle=True:训练集中的数据是否打乱(一般默认打乱)
num_workers=0:载入数据的现成数,在lunix操作系统下,可以设置为别的参数,在windows操作系统系统下,默认为0.
三、下载测试集
# 10000张测试图片
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=10000,shuffle=False, num_workers=0)
test_data_iter = iter(testloader)
test_image, test_lable = test_data_iter.next()classes = ('plane', 'car', 'bird', 'cat', # 数据集中的分类,设置为元组,不可变类'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
参数解释:
test_data_iter = iter(testloader):通过iter()函数把testloader转化成可迭代的迭代器
test_image, test_lable = test_data_iter.next():通过next()方法可以获得测试的图像和图像对应的标签值。
四、查看导入的图片
在中间过程打印图片进行查看,后续会注释掉
def imshow(img):img = img / 2 + 0.5nping = img.numpy()plt.imshow(np.transpose(nping, (1, 2, 0)))plt.show()# print labels
print(' '.join('%5s' % classes[test_lable[j]] for j in range(4)))
# show images
imshow(torchvision.utils.make_grid(test_image))
运行结果
图片很模糊,因为像素很低。
上面识别出来的结果都对了。
我遇到的问题:
一开始有结果但是没有图片,我以为时matplotlib的问题,我重新安装并且更新了版本,但是我再运行后报错更多了,报错提示我 AttributeError: module 'numpy' has no attribute 'bool',我就知道是numpy的问题了,我重新安装并且更新了版本结果还是不行,我百度了一下,发现不是越新的版本越好,我重新下载了1.23.2这个版本的numpy,下载完成后运行就出来结果了。
pip install numpy==1.23.2
这个也只是中间过程,后续会注释或者删了。
五、将创建的模型实例化
创建模型请看PyTorch搭建LeNet神经网络-CSDN博客
# 将创建的模型实例化
net = LeNet() # 实例化
loss_fuction = nn.CrossEntropyLoss() # 定义损失函数# 通过优化器将所有可训练的参数都进行训练,lr是learningrate学习率
optimizer = optim.Adam(net.parameters(), lr=0.001)#通过for循环实现训练过程,循环几次就是将训练集迭代多少次
for epoch in range(5):running_loss = 0.0 # 用来累加在学习过程中的损失for step, data in enumerate(trainloader, start=0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad() # 历时损失梯度清零。# forward + backward + optimizeoutputs = net(inputs)loss = loss_fuction(outputs, labels) # 计算神经网络的预测值和真实标签之间的损失loss.backward()optimizer.step() # step()函数实现参数更新# print statistics 打印数据的过程running_loss += loss.item()if step % 500 == 499: # 每隔500步打印一次数据的信息with torch.no_grad(): # 上下文管理器outputs = net(test_image)predict_y = torch.max(outputs, dim=1)[1]accuracy = (predict_y == test_lable).sum().item() / test_lable.size(0)print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0print('Finished Training')# 将模型保存到文件夹中
save_path = './Lenet.pth'
torch.save(net.state_dict(), save_path)
详细解释:
比较重点的单独解释了,其他的在注释中。
optimizer.zero_grad() # 历时损失梯度清零。
? 为什么每计算一个batch,就要调用一次 optimizer.zero_grad()函数
=> 通过清楚历史梯度,就会对计算的历史梯度进行累加。通过这个特性,能变相的实现一个很大的batch数值的训练(因为batch数值越大,训练效果越好)
with torch.no_grad(): # 上下文管理器
上下文管理器: 在接下来的计算过程中,不再去计算每个节点的误差损失梯度。
如果不调用这个函数,将会在测试过程中占用更多的算力,消耗更多的资源和占用更多的内存资源,导致内存容易崩。
print函数中打印参数解释:
print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))
epoch + 1:迭代到第几轮了
step + 1:某一轮的第几步
running_loss / 500:训练过程中500步平均训练误差
accuracy:准确率
运行结果
相关文章:

PyTorch搭建LeNet训练集详细实现
一、下载训练集 导包 import torch import torchvision import torch.nn as nn from model import LeNet import torch.optim as optim import torchvision.transforms as transforms import matplotlib.pyplot as plt import numpy as npToTensor()函数: 把图像…...

R语言复现:中国Charls数据库一篇现况调查论文的缺失数据填补方法
编者 在临床研究中,数据缺失是不可避免的,甚至没有缺失,数据的真实性都会受到质疑。 那我们该如何应对缺失的数据?放着不管?还是重新开始?不妨试着对缺失值进行填补,简单又高效。毕竟对于统计师来说&#…...

解决Git:Author identity unknown Please tell me who you are.
报错信息: 意思: 作者身份未知 ***请告诉我你是谁。 解决办法: git config --global user.name "你的名字"git config --global user.email "你的邮箱"...
Flink StreamTask启动和执行源码分析
文章目录 前言StreamTask 部署启动Task 线程启动StreamTask 初始化StreamTask 执行 前言 Flink的StreamTask的启动和执行是一个复杂的过程,涉及多个关键步骤。以下是StreamTask启动和执行的主要流程: 初始化:StreamTask的初始化阶段涉及多个…...
【MySQL 系列】MySQL 语句篇_DCL 语句
DCL( Data Control Language,数据控制语言)用于对数据访问权限进行控制,定义数据库、表、字段、用户的访问权限和安全级别。主要关键字包括 GRANT、 REVOKE 等。 文章目录 1、MySQL 中的 DCL 语句1.1、数据控制语言--DCL1.2、MySQ…...

什么是序列化?为什么需要序列化?
1、典型回答 序列化(Serialization)序列化是将对象转换为可存储或传输的形式的过程(例如: 将对象转换为字节流) 反序列化(Deserialization) 是将序列化后的数据(例如: 二进制文件)转换回原始对象的过程。通过反序列化,可以从存储介质 (如磁盘、数据库) 或通过网络…...

Linux本地搭建FastDFS系统
文章目录 前言1. 本地搭建FastDFS文件系统1.1 环境安装1.2 安装libfastcommon1.3 安装FastDFS1.4 配置Tracker1.5 配置Storage1.6 测试上传下载1.7 与Nginx整合1.8 安装Nginx1.9 配置Nginx 2. 局域网测试访问FastDFS3. 安装cpolar内网穿透4. 配置公网访问地址5. 固定公网地址5.…...
docker和docker-compose安装
一、docker安装 1、移除旧版本 依次执行如下命令移除旧版本docker,如未安装过无需执行 yum -y remove docker docker-client docker-client-latest docker-common docker-latest docker-latest-logrotate docker-logrotate docker-selinux docker-engine-selinux…...
深入理解Spring的ApplicationContext:案例详解与应用
深入理解Spring的ApplicationContext:案例详解与应用 在Spring框架的丰富生态中,ApplicationContext扮演着至关重要的角色。作为BeanFactory的扩展,ApplicationContext不仅继承了其所有功能,还引入了更多高级特性,使得…...

6.Java并发编程—深入剖析Java Executors:探索创建线程的5种神奇方式
Executors快速创建线程池的方法 Java通过Executors 工厂提供了5种创建线程池的方法,具体方法如下 方法名描述newSingleThreadExecutor()创建一个单线程的线程池,该线程池中只有一个工作线程。所有任务按照提交的顺序依次执行,保证任务的顺序性…...
英语阅读挑战
英语阅读真是令人头痛的东西。可怜的子航想利用寒假时间突破英语难题。当他拿到一篇英语阅读时,他很好奇作者最喜欢用那些字母。 输入 一句30词以内的英语句子 输出 统计每个字母出现的次数 样例输入 复制 However,the British dont have a history of exporting th…...
备战蓝桥之思维
平台重叠真的坑 给你一句样例,如果你觉得自己的代码没问题那就试试吧 2 1 1 3 1 0 4 正确答案 0 0 0 0 P1105 平台 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) import java.awt.Checkbox; import java.awt.PageAttributes.OriginType; import java.io.B…...
09 string的实现
注意 实现仿cplus官网的的string类,对部分主要功能实现 实现 头文件 #pragma once #include <iostream> #include <assert.h> #include <string>namespace mystring {class string{friend std::ostream& operator<<(std::ostream&a…...
Git 进行版本控制时,配置 user.name 和 user.email
在使用 Git 进行版本控制时,配置 user.name 和 user.email 是一个非常重要的初始步骤,但不是绝对必须的。这两个配置项定义了当你进行提交(commit)时用于标识提交者的信息。 为什么建议配置 user.name 和 user.email 标识提交者…...
传统开发读写优化与HBase
目录: 一、传统开发数据读写性能优化 1. Mysql 分表、主从复制与读写分离 2. Redis(缓存型数据库)主从复制与读写分离 二、HBase 一、传统开发数据读写性能优化 1、Mysql 分表、主从复制与读写分离 mysql分库分表方案 一种分表方案:设置表A 表B 表A 自增列从1开始…...

【OpenGL实现 03】纹理贴图原理和实现
目录 一、说明二、纹理贴图原理2.1 纹理融合原理2.2 UV坐标原理 三、生成纹理对象3.1 需要在VAO上绑定纹理坐标3.2 纹理传递3.3 纹理buffer生成 四、代码实现:五、着色器4.1 片段4.2 顶点 五、后记 一、说明 本篇叙述在画出图元的时候,如何贴图纹理图片…...

FDU 2021 | 二叉树关键节点的个数
文章目录 1. 题目描述2. 我的尝试 1. 题目描述 给定一颗二叉树,树的每个节点的值为一个正整数。如果从根节点到节点 N 的路径上不存在比节点 N 的值大的节点,那么节点 N 被认为是树上的关键节点。求树上所有的关键节点的个数。请写出程序,并…...

精读《React Conf 2019 - Day2》
1 引言 这是继 精读《React Conf 2019 - Day1》 之后的第二篇,补充了 React Conf 2019 第二天的内容。 2 概述 & 精读 第二天的内容更为精彩,笔者会重点介绍比较干货的部分。 Fast refresh Fast refresh 是更好的 react-hot-loader 替代方案&am…...
向ChatGPT高效提问模板
PS: ChatGPT无限次数,无需魔法,登录即可使用,网页打开下面 tj4.mnsfdx.net [点击跳转链接](http://tj4.mnsfdx.net/) 我想请你XXXX,请问我应该如何向你提问才能得到最满意的答案,请提供全面、详细的建议,针对每一个建…...
android metaRTC编译
参考文章: metaRTC3.0稳定版本编译指南_metartc 编译-CSDN博客 源码下载: Releases metartc/metaRTC GitHub 版本v6.0-b4即可...
小白初学SpringBoot记录
1.对于通过json返回用户信息时,需要忽略password字段操作: 1.1 pom配置jackson细节: <dependency><groupId>com.fasterxml.jackson.core</groupId><artifactId>jackson-databind</artifactId><version>…...
qt使用笔记二:main.cpp详解
Qt中main.cpp文件详解 main.cpp是Qt应用程序的入口文件,包含程序的启动逻辑。下面我将详细解析其结构和功能。 基本结构 一个典型的Qt main.cpp 文件结构如下: #include <QApplication> // 或者 QGuiApplication/QCoreApplication #include &…...

如何打造一款金融推理工具Financial Reasoning Workflow:WebUI+Ollama+Fin-R1+MCP/RAG
在之前的文章中,我探讨了如何使用具身人工智能,让大语言模型智能体来模仿[当今著名对冲基金经理的投资策略]。 在本文中,我将探讨另一种方法,该方法结合了经过金融推理训练的特定大语言模型(LLM)࿰…...
针对“仅某个地区出现Bug”的原因分析与解决方案
一、核心排查方向(按优先级排序) 地区相关配置差异 检查点: 该地区是否有独立的配置文件或数据库分片?是否启用了地区特定的功能开关(Feature Flag)或AB测试?本地化内容(如语言、时…...
JDK21深度解密 Day 15:JDK21实战最佳实践总结
【JDK21深度解密 Day 15】JDK21实战最佳实践总结 文章简述 本篇文章是《JDK21深度解密:从新特性到生产实践的全栈指南》系列的第15篇,聚焦于JDK21实战最佳实践总结。作为Java历史上最重要的LTS版本之一,JDK21带来了虚拟线程、结构化并发、模式匹配、ZGC优化等革命性特性,…...

阿里云服务器安装nginx并配置前端资源路径(前后端部署到一台服务器并成功访问)
运行以下命令,安装Nginx相关依赖。 yum install -y gcc-c yum install -y pcre pcre-devel yum install -y zlib zlib-devel yum install -y openssl openssl-devel 运行wget命令下载Nginx 1.21.6。 您可以通过Nginx开源社区直接获取对应版本的安装包URL&…...

2025软件供应链安全最佳实践︱证券DevSecOps下供应链与开源治理实践
项目背景:近年来,云计算、AI人工智能、大数据等信息技术的不断发展、各行各业的信息电子化的步伐不断加快、信息化的水平不断提高,网络安全的风险不断累积,金融证券行业面临着越来越多的威胁挑战。特别是近年以来,开源…...
Flask 基础与实战概述
一、Flask 基础知识 什么是 Flask? Flask 是一个基于 Python 的轻量级 Web 框架(微框架)。 特点:核心代码简洁,给予开发者更多选择空间。 与 Django 对比: Django 创建空项目生成多个文件,Flask 仅需一个文件即可实现简单应用(如 "Hello, World!")。 Flask …...

【软件工具】批量OCR指定区域图片自动识别内容重命名软件使用教程及注意事项
批量OCR指定区域图片自动识别内容重命名软件使用教程及注意事项 1、操作步骤1-5: 安装与启动:安装成功后,在桌面或开始菜单找到软件图标,双击启动。 导入图片:进入软件主界面,点击 “导入图片” 按钮&a…...

手写Promise.all
前言 之前在看远方os大佬直播的时候看到有让手写的Promise.all的问题,然后心血来潮自己准备手写一个 开始 首先,我们需要明确原本js提供的Promise.all的特性 Promise.all返回的是一个Promise如果传入的数据中有一个reject即整个all返回的就是reject&…...