神经网络基础知识:LeNet的搭建-训练-预测
1.参考视频:
2.1 pytorch官方demo(Lenet)_哔哩哔哩_bilibili
2.总结:
(1)LeNet网络就是 我最开始用来预测mnist数据集的那个网络,简单的2个conv+2个maxpool+3个linear层
(2)up主整理的train.py等内容里面的细节分析值得学习
(3)对于预测代码的撰写,可以参考代码的predict.py文件
3.几个文件的源代码我都贴一下(都不多——但很精):
(1)首先是 model.py:
import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 16, 5)self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, 5)self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.relu(self.conv1(x)) # input(3, 32, 32) output(16, 28, 28)x = self.pool1(x) # output(16, 14, 14)x = F.relu(self.conv2(x)) # output(32, 10, 10)x = self.pool2(x) # output(32, 5, 5)x = x.view(-1, 32*5*5) # output(32*5*5)x = F.relu(self.fc1(x)) # output(120)x = F.relu(self.fc2(x)) # output(84)x = self.fc3(x) # output(10)return x
模型 == 2个conv + 2个max_pool + 3个linear
(2) train.py训练模型的文件:
import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transformsdef main():# 定义transform的数据增强transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 处理cifar10的 train和val的数据集的问题# 50000张训练图片# 第一次使用时要将download设置为True才会自动去下载数据集train_set = torchvision.datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,shuffle=True, num_workers=0)# 10000张验证图片# 第一次使用时要将download设置为True才会自动去下载数据集val_set = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)val_loader = torch.utils.data.DataLoader(val_set, batch_size=5000,shuffle=False, num_workers=0)val_data_iter = iter(val_loader)val_image, val_label = next(val_data_iter)# classes = ('plane', 'car', 'bird', 'cat',# 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 训练前的准备: 实例化model网络net , 定义 loss函数 CrossEntropyLoss() 和 Adam优化器net = LeNet()loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.001)# 开始训练:zero_grad() + outputs + loss backward + optim stepfor epoch in range(5): # loop over the dataset multiple timesrunning_loss = 0.0for step, data in enumerate(train_loader, start=0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = loss_function(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if step % 500 == 499: # print every 500 mini-batcheswith torch.no_grad():outputs = net(val_image) # [batch, 10]predict_y = torch.max(outputs, dim=1)[1]accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0print('Finished Training')# 最后把 model的 参数save 为一个.pth文件save_path = './Lenet.pth'torch.save(net.state_dict(), save_path)if __name__ == '__main__':main()
分析:数据集划分 + 实例化网络_优化器_loss函数 + 分epoch开始寻 + save_pth权重
(3)predict.py:
import torch
import torchvision.transforms as transforms
from PIL import Imagefrom model import LeNetdef main():# 将需要检测图像 裁剪为32*32transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')#实例化网络 + 才入权重net = LeNet()net.load_state_dict(torch.load('Lenet.pth'))# 打开图像,转换格式im = Image.open('1.jpg')im = transform(im) # [C, H, W]im = torch.unsqueeze(im, dim=0) # [N, C, H, W]# 输入到网络中, 得到预测的结果with torch.no_grad():outputs = net(im)predict = torch.max(outputs, dim=1)[1].numpy()print(classes[int(predict)])if __name__ == '__main__':main()
predict == 处理图像 + 实例化权重 + 得到预测结果
相关文章:
神经网络基础知识:LeNet的搭建-训练-预测
1.参考视频: 2.1 pytorch官方demo(Lenet)_哔哩哔哩_bilibili 2.总结: (1)LeNet网络就是 我最开始用来预测mnist数据集的那个网络,简单的2个conv2个maxpool3个linear层 (2)up主整理的train.py…...
SpringMVC 学习(七)之报文信息转换器 HttpMessageConverter
目录 1 HttpMessageConverter 介绍 2 RequestBody 注解 3 ResponseBody 注解 4 RequestEntity 5 ResponseEntity 6 RestController 注解 1 HttpMessageConverter 介绍 HttpMessageConverter 报文信息转换器,将请求报文(如JSON、XML、HTML等&#x…...
浅谈密码学
文章目录 每日一句正能量前言什么是密码学对称加密简述加密语法Kerckhoffs原则常用的加密算法现代密码学的原则威胁模型(按强度增加的顺序) 密码学的应用领域后记 每日一句正能量 人生在世,谁也不能做到让任何人都喜欢,所以没必要…...
Android 混淆是啥玩意儿?
什么是混淆 Android混淆,是伴随着Android系统的流行而产生的一种Android APP保护技术,用于保护APP不被破解和逆向分析。简单的说,就是将原本正常的项目文件,对其类、方法、字段,重新命名a,b,c…之类的字母,…...
【嵌入式——QT】QListWidget
QListWidget类提供了一个基于项的列表小部件,QListWidgetItem是列表中的项,该篇文章中涉及到的功能有添加列表项,插入列表项,删除列表项,清空列表,向上移动列表项,向下移动列表项。 常用API a…...
爬虫入门到精通_基础篇5(PyQuery库_PyQuery说明,初始化,基本CSS选择器,查找元素,遍历,获取信息,DOM操作)
1 PyQuery说明: PyQuery是python中一个强大而又灵活的网页解析库,如果你觉得正则写起来太麻烦,又觉得BeautifulSoup语法太难记,如果你熟悉jQuery的语法那么,PyQuery就是你绝佳的选择。 安装 pip3 install pyquery2 …...
用冒泡排序模拟C语言中的内置快排函数qsort!
目录 编辑 1.回调函数的介绍 2. 回调函数实现转移表 3. 冒泡排序的实现 4. qsort的介绍和使用 5. qsort的模拟实现 6. 完结散花 悟已往之不谏,知来者犹可追 创作不易,宝子们!如果这篇文章对你们有帮助的话,别忘了给个免…...
智慧公厕:打造智慧城市环境卫生新标杆
随着科技的不断发展和城市化进程的加速推进,智慧城市建设已经成为各地政府和企业关注的焦点。而作为智慧城市环境卫生管理的基础设施,智慧公厕的建设和发展也备受重视,被誉为智慧城市的新标杆。本文以智慧公厕源头厂家广州中期科技有限公司&a…...
【学习版】Microsoft Office 2021安装破解教程
本文转载自知乎:https://zhuanlan.zhihu.com/p/655653158 由本人二次整理修改 用到的软件为:Office Tool Plus,下载链接:Office Tool Plus 官方网站 - 一键部署 Office (landian.vip) 下载页面:(随机找个站…...
基于java Springboot实现课程评分系统设计和实现
基于java Springboot实现课程评分系统设计和实现 博主介绍:多年java开发经验,专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 央顺技术团队 Java毕设项目精品实战案例《1000套》 欢迎点赞 收藏 ⭐留言 文末获取源…...
git操作基本指令
1.查看用户名 git config user.name 2.查看密码 git config user.password 3.查看邮箱 git config user.email 4.修改用户名 git config --global user.name "xxx(新用户名)" 5.修改密码 git config --global user.password "xxx(新密码)" 6.修改…...
YOLO算法
YOLO介绍 YOLO,全称为You Only Look Once: Unified, Real-Time Object Detection,是一种实时目标检测算法。目标检测是计算机视觉领域的一个重要任务,它不仅需要识别图像中的物体类别,还需要确定它们的位置。与分类任务只关注对…...
【Android】更改手机主题导致app数据丢失问题
情景:在使用app过程中更改系统主题(比如从浅色主题改为深色主题),这时activity销毁重建了(即走了onPause、onStop、onSaveInstanceState、onDestroy、onCreate、onRestoreInstanceState、onStart、onResume的生命周期&…...
Dell R730 2U服务器实践3:安装英伟达上代专业AI训练Nvidia P4计算卡
Dell R730是一款非常流行的服务器,2U的机箱可以放入两张显卡,这次先用一张英伟达上代专业级AI训练卡:P4卡做实验,本文记录安装过程。 简洁步骤: 打开机箱将P4显卡插在4号槽位关闭机箱安装驱动 详细步骤: 对…...
Nacos环境搭建 -- 服务注册与发现
为什么需要服务治理 在未引入服务治理模块之前,服务之间的通信是服务间直接发起并调用来实现的。只要知道了对应服务的服务名称、IP地址、端口号,就能够发起服务通信。比如A服务的IP地址为192.168.1.100:9000,B服务直接向该IP地址发起请求就…...
Linux了解
简介 Linux是一种自由和开放源代码的类UNIX操作系统,由芬兰的Linus Torvalds于1991年首次发布。Linux最初是作为支持英特尔x86架构的个人电脑的一个自由操作系统,现在已经被移植到更多的计算机硬件平台,如手机、平板电脑、路由器、视频游戏控…...
Keil新版本安装编译器ARMCompiler 5.06
0x00 缘起 我手头的项目在使用最新版本的编译器后,烧录后无法正常运行,故安装5.06,测试后发现程序运行正常,以下为编译器的安装步骤。 0x01 解决方法 1. 下载编译器安装文件,可以去ARM官网下载,也可以使用我…...
【基础训练 || Test-1】
总言 主要内容:一些习题。 文章目录 总言一、选择1、for循环、操作符(逗号表达式)2、格式化输出(转换说明符)3、for循环、操作符(逗号表达式、赋值和判等)4、if语句、操作符ÿ…...
Python读取hbase数据库
1. hbase连接 首先用hbase shell 命令来进入到hbase数据库,然后用list命令来查看hbase下所有表,以其中表“DB_level0”为例,可以看到库名“baotouyiqi”是拼接的,python代码访问时先连接: def hbase_connection(hbase…...
LeetCode41题:缺失的第一个正数(python3)
这道题写的时候完全没有思路,看了很久的题解,才总结出来。 class Solution:def firstMissingPositive(self, nums: List[int]) -> int:nums_set set(nums)n len(nums)for i in range(1, n 1):if i not in nums_set:return ireturn n 1...
探索XScene-UEPlugin:如何实现高斯泼溅模型在虚幻引擎5中的高效可视化与混合渲染
探索XScene-UEPlugin:如何实现高斯泼溅模型在虚幻引擎5中的高效可视化与混合渲染 【免费下载链接】XScene-UEPlugin A Unreal Engine 5 (UE5) based plugin aiming to provide real-time visulization, management, editing, and scalable hybrid rendering of Guas…...
如何完整备份你的QQ空间记忆:GetQzonehistory终极指南
如何完整备份你的QQ空间记忆:GetQzonehistory终极指南 【免费下载链接】GetQzonehistory 获取QQ空间发布的历史说说 项目地址: https://gitcode.com/GitHub_Trending/ge/GetQzonehistory 在数字时代,我们的记忆越来越多地存储在云端。你是否曾担心…...
杰理之app ota升级过程中IO无法维持会掉【篇】
u盘升级则可以维持...
从MCAS系统缺陷看软件安全:一个传感器故障如何导致波音737MAX两次空难?
从MCAS系统缺陷看航空软件安全设计的致命盲区 当一架现代客机以每小时800公里的速度巡航在万米高空时,它的每一个飞行动作背后都有数百万行代码在实时运算。2018年至2019年发生的两起波音737MAX空难,将航空电子系统中一个名为MCAS的软件模块推上了风口浪…...
Phi-4-mini-reasoning 3.8B 轻量模型Python入门实战:零基础快速上手AI推理
Phi-4-mini-reasoning 3.8B 轻量模型Python入门实战:零基础快速上手AI推理 1. 为什么选择Phi-4-mini-reasoning Phi-4-mini-reasoning是一款专为推理任务优化的轻量级大模型,参数规模3.8B,在保持较高推理能力的同时大幅降低了硬件需求。对于…...
从零开始:用CloudCompare完成平面距离测量的完整工作流
从零开始:用CloudCompare完成平面距离测量的完整工作流 在三维数据处理领域,精确测量平面间的距离是许多工程和科研项目的关键步骤。无论是建筑行业的BIM模型验证,还是制造业的质量控制,亦或是地质勘探中的层位分析,都…...
Unity 2023.2 项目升级C# 9.0?先看看这5个不支持的语法特性(附替代方案)
Unity 2023.2项目升级C# 9.0避坑指南:5个不支持的语法特性与实战解决方案 当你将Unity项目升级到2023.2版本,发现IDE智能提示中闪烁着诱人的C# 9.0新特性时,先别急着重构代码。上周我的团队就遭遇了这样的场景:在将大型项目迁移到…...
Android + OpenCV 实战指南:从环境搭建到图像处理(超详细)
1. Android与OpenCV环境搭建全攻略 第一次接触OpenCV的Android开发者往往会卡在环境配置这一步。我当年踩过的坑现在可以帮你完美避开。OpenCV作为计算机视觉领域的瑞士军刀,在移动端同样能发挥强大威力,但首先得让它跑起来。 核心工具准备: …...
从医学影像数据到三维可视化:MRIcroGL如何改变你的研究流程
从医学影像数据到三维可视化:MRIcroGL如何改变你的研究流程 【免费下载链接】MRIcroGL v1.2 GLSL volume rendering. Able to view NIfTI, DICOM, MGH, MHD, NRRD, AFNI format images. 项目地址: https://gitcode.com/gh_mirrors/mr/MRIcroGL 你是否曾经面对…...
QMCDecode:3步解锁QQ音乐加密文件,让音乐真正属于你
QMCDecode:3步解锁QQ音乐加密文件,让音乐真正属于你 【免费下载链接】QMCDecode QQ音乐QMC格式转换为普通格式(qmcflac转flac,qmc0,qmc3转mp3, mflac,mflac0等转flac),仅支持macOS,可自动识别到QQ音乐下载目录…...
