神经网络基础知识:LeNet的搭建-训练-预测
1.参考视频:
2.1 pytorch官方demo(Lenet)_哔哩哔哩_bilibili
2.总结:
(1)LeNet网络就是 我最开始用来预测mnist数据集的那个网络,简单的2个conv+2个maxpool+3个linear层
(2)up主整理的train.py等内容里面的细节分析值得学习
(3)对于预测代码的撰写,可以参考代码的predict.py文件
3.几个文件的源代码我都贴一下(都不多——但很精):
(1)首先是 model.py:
import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 16, 5)self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, 5)self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.relu(self.conv1(x)) # input(3, 32, 32) output(16, 28, 28)x = self.pool1(x) # output(16, 14, 14)x = F.relu(self.conv2(x)) # output(32, 10, 10)x = self.pool2(x) # output(32, 5, 5)x = x.view(-1, 32*5*5) # output(32*5*5)x = F.relu(self.fc1(x)) # output(120)x = F.relu(self.fc2(x)) # output(84)x = self.fc3(x) # output(10)return x
模型 == 2个conv + 2个max_pool + 3个linear
(2) train.py训练模型的文件:
import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transformsdef main():# 定义transform的数据增强transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 处理cifar10的 train和val的数据集的问题# 50000张训练图片# 第一次使用时要将download设置为True才会自动去下载数据集train_set = torchvision.datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,shuffle=True, num_workers=0)# 10000张验证图片# 第一次使用时要将download设置为True才会自动去下载数据集val_set = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)val_loader = torch.utils.data.DataLoader(val_set, batch_size=5000,shuffle=False, num_workers=0)val_data_iter = iter(val_loader)val_image, val_label = next(val_data_iter)# classes = ('plane', 'car', 'bird', 'cat',# 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 训练前的准备: 实例化model网络net , 定义 loss函数 CrossEntropyLoss() 和 Adam优化器net = LeNet()loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.001)# 开始训练:zero_grad() + outputs + loss backward + optim stepfor epoch in range(5): # loop over the dataset multiple timesrunning_loss = 0.0for step, data in enumerate(train_loader, start=0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = loss_function(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if step % 500 == 499: # print every 500 mini-batcheswith torch.no_grad():outputs = net(val_image) # [batch, 10]predict_y = torch.max(outputs, dim=1)[1]accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0print('Finished Training')# 最后把 model的 参数save 为一个.pth文件save_path = './Lenet.pth'torch.save(net.state_dict(), save_path)if __name__ == '__main__':main()
分析:数据集划分 + 实例化网络_优化器_loss函数 + 分epoch开始寻 + save_pth权重
(3)predict.py:
import torch
import torchvision.transforms as transforms
from PIL import Imagefrom model import LeNetdef main():# 将需要检测图像 裁剪为32*32transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')#实例化网络 + 才入权重net = LeNet()net.load_state_dict(torch.load('Lenet.pth'))# 打开图像,转换格式im = Image.open('1.jpg')im = transform(im) # [C, H, W]im = torch.unsqueeze(im, dim=0) # [N, C, H, W]# 输入到网络中, 得到预测的结果with torch.no_grad():outputs = net(im)predict = torch.max(outputs, dim=1)[1].numpy()print(classes[int(predict)])if __name__ == '__main__':main()
predict == 处理图像 + 实例化权重 + 得到预测结果
相关文章:
神经网络基础知识:LeNet的搭建-训练-预测
1.参考视频: 2.1 pytorch官方demo(Lenet)_哔哩哔哩_bilibili 2.总结: (1)LeNet网络就是 我最开始用来预测mnist数据集的那个网络,简单的2个conv2个maxpool3个linear层 (2)up主整理的train.py…...
SpringMVC 学习(七)之报文信息转换器 HttpMessageConverter
目录 1 HttpMessageConverter 介绍 2 RequestBody 注解 3 ResponseBody 注解 4 RequestEntity 5 ResponseEntity 6 RestController 注解 1 HttpMessageConverter 介绍 HttpMessageConverter 报文信息转换器,将请求报文(如JSON、XML、HTML等&#x…...
浅谈密码学
文章目录 每日一句正能量前言什么是密码学对称加密简述加密语法Kerckhoffs原则常用的加密算法现代密码学的原则威胁模型(按强度增加的顺序) 密码学的应用领域后记 每日一句正能量 人生在世,谁也不能做到让任何人都喜欢,所以没必要…...
Android 混淆是啥玩意儿?
什么是混淆 Android混淆,是伴随着Android系统的流行而产生的一种Android APP保护技术,用于保护APP不被破解和逆向分析。简单的说,就是将原本正常的项目文件,对其类、方法、字段,重新命名a,b,c…之类的字母,…...
【嵌入式——QT】QListWidget
QListWidget类提供了一个基于项的列表小部件,QListWidgetItem是列表中的项,该篇文章中涉及到的功能有添加列表项,插入列表项,删除列表项,清空列表,向上移动列表项,向下移动列表项。 常用API a…...
爬虫入门到精通_基础篇5(PyQuery库_PyQuery说明,初始化,基本CSS选择器,查找元素,遍历,获取信息,DOM操作)
1 PyQuery说明: PyQuery是python中一个强大而又灵活的网页解析库,如果你觉得正则写起来太麻烦,又觉得BeautifulSoup语法太难记,如果你熟悉jQuery的语法那么,PyQuery就是你绝佳的选择。 安装 pip3 install pyquery2 …...
用冒泡排序模拟C语言中的内置快排函数qsort!
目录 编辑 1.回调函数的介绍 2. 回调函数实现转移表 3. 冒泡排序的实现 4. qsort的介绍和使用 5. qsort的模拟实现 6. 完结散花 悟已往之不谏,知来者犹可追 创作不易,宝子们!如果这篇文章对你们有帮助的话,别忘了给个免…...
智慧公厕:打造智慧城市环境卫生新标杆
随着科技的不断发展和城市化进程的加速推进,智慧城市建设已经成为各地政府和企业关注的焦点。而作为智慧城市环境卫生管理的基础设施,智慧公厕的建设和发展也备受重视,被誉为智慧城市的新标杆。本文以智慧公厕源头厂家广州中期科技有限公司&a…...
【学习版】Microsoft Office 2021安装破解教程
本文转载自知乎:https://zhuanlan.zhihu.com/p/655653158 由本人二次整理修改 用到的软件为:Office Tool Plus,下载链接:Office Tool Plus 官方网站 - 一键部署 Office (landian.vip) 下载页面:(随机找个站…...
基于java Springboot实现课程评分系统设计和实现
基于java Springboot实现课程评分系统设计和实现 博主介绍:多年java开发经验,专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 央顺技术团队 Java毕设项目精品实战案例《1000套》 欢迎点赞 收藏 ⭐留言 文末获取源…...
git操作基本指令
1.查看用户名 git config user.name 2.查看密码 git config user.password 3.查看邮箱 git config user.email 4.修改用户名 git config --global user.name "xxx(新用户名)" 5.修改密码 git config --global user.password "xxx(新密码)" 6.修改…...
YOLO算法
YOLO介绍 YOLO,全称为You Only Look Once: Unified, Real-Time Object Detection,是一种实时目标检测算法。目标检测是计算机视觉领域的一个重要任务,它不仅需要识别图像中的物体类别,还需要确定它们的位置。与分类任务只关注对…...
【Android】更改手机主题导致app数据丢失问题
情景:在使用app过程中更改系统主题(比如从浅色主题改为深色主题),这时activity销毁重建了(即走了onPause、onStop、onSaveInstanceState、onDestroy、onCreate、onRestoreInstanceState、onStart、onResume的生命周期&…...
Dell R730 2U服务器实践3:安装英伟达上代专业AI训练Nvidia P4计算卡
Dell R730是一款非常流行的服务器,2U的机箱可以放入两张显卡,这次先用一张英伟达上代专业级AI训练卡:P4卡做实验,本文记录安装过程。 简洁步骤: 打开机箱将P4显卡插在4号槽位关闭机箱安装驱动 详细步骤: 对…...
Nacos环境搭建 -- 服务注册与发现
为什么需要服务治理 在未引入服务治理模块之前,服务之间的通信是服务间直接发起并调用来实现的。只要知道了对应服务的服务名称、IP地址、端口号,就能够发起服务通信。比如A服务的IP地址为192.168.1.100:9000,B服务直接向该IP地址发起请求就…...
Linux了解
简介 Linux是一种自由和开放源代码的类UNIX操作系统,由芬兰的Linus Torvalds于1991年首次发布。Linux最初是作为支持英特尔x86架构的个人电脑的一个自由操作系统,现在已经被移植到更多的计算机硬件平台,如手机、平板电脑、路由器、视频游戏控…...
Keil新版本安装编译器ARMCompiler 5.06
0x00 缘起 我手头的项目在使用最新版本的编译器后,烧录后无法正常运行,故安装5.06,测试后发现程序运行正常,以下为编译器的安装步骤。 0x01 解决方法 1. 下载编译器安装文件,可以去ARM官网下载,也可以使用我…...
【基础训练 || Test-1】
总言 主要内容:一些习题。 文章目录 总言一、选择1、for循环、操作符(逗号表达式)2、格式化输出(转换说明符)3、for循环、操作符(逗号表达式、赋值和判等)4、if语句、操作符ÿ…...
Python读取hbase数据库
1. hbase连接 首先用hbase shell 命令来进入到hbase数据库,然后用list命令来查看hbase下所有表,以其中表“DB_level0”为例,可以看到库名“baotouyiqi”是拼接的,python代码访问时先连接: def hbase_connection(hbase…...
LeetCode41题:缺失的第一个正数(python3)
这道题写的时候完全没有思路,看了很久的题解,才总结出来。 class Solution:def firstMissingPositive(self, nums: List[int]) -> int:nums_set set(nums)n len(nums)for i in range(1, n 1):if i not in nums_set:return ireturn n 1...
日语AI面试高效通关秘籍:专业解读与青柚面试智能助攻
在如今就业市场竞争日益激烈的背景下,越来越多的求职者将目光投向了日本及中日双语岗位。但是,一场日语面试往往让许多人感到步履维艰。你是否也曾因为面试官抛出的“刁钻问题”而心生畏惧?面对生疏的日语交流环境,即便提前恶补了…...
iOS 26 携众系统重磅更新,但“苹果智能”仍与国行无缘
美国西海岸的夏天,再次被苹果点燃。一年一度的全球开发者大会 WWDC25 如期而至,这不仅是开发者的盛宴,更是全球数亿苹果用户翘首以盼的科技春晚。今年,苹果依旧为我们带来了全家桶式的系统更新,包括 iOS 26、iPadOS 26…...
脑机新手指南(八):OpenBCI_GUI:从环境搭建到数据可视化(下)
一、数据处理与分析实战 (一)实时滤波与参数调整 基础滤波操作 60Hz 工频滤波:勾选界面右侧 “60Hz” 复选框,可有效抑制电网干扰(适用于北美地区,欧洲用户可调整为 50Hz)。 平滑处理&…...
【SpringBoot】100、SpringBoot中使用自定义注解+AOP实现参数自动解密
在实际项目中,用户注册、登录、修改密码等操作,都涉及到参数传输安全问题。所以我们需要在前端对账户、密码等敏感信息加密传输,在后端接收到数据后能自动解密。 1、引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId...
使用van-uploader 的UI组件,结合vue2如何实现图片上传组件的封装
以下是基于 vant-ui(适配 Vue2 版本 )实现截图中照片上传预览、删除功能,并封装成可复用组件的完整代码,包含样式和逻辑实现,可直接在 Vue2 项目中使用: 1. 封装的图片上传组件 ImageUploader.vue <te…...
Nginx server_name 配置说明
Nginx 是一个高性能的反向代理和负载均衡服务器,其核心配置之一是 server 块中的 server_name 指令。server_name 决定了 Nginx 如何根据客户端请求的 Host 头匹配对应的虚拟主机(Virtual Host)。 1. 简介 Nginx 使用 server_name 指令来确定…...
C++ 基础特性深度解析
目录 引言 一、命名空间(namespace) C 中的命名空间 与 C 语言的对比 二、缺省参数 C 中的缺省参数 与 C 语言的对比 三、引用(reference) C 中的引用 与 C 语言的对比 四、inline(内联函数…...
什么是EULA和DPA
文章目录 EULA(End User License Agreement)DPA(Data Protection Agreement)一、定义与背景二、核心内容三、法律效力与责任四、实际应用与意义 EULA(End User License Agreement) 定义: EULA即…...
MySQL JOIN 表过多的优化思路
当 MySQL 查询涉及大量表 JOIN 时,性能会显著下降。以下是优化思路和简易实现方法: 一、核心优化思路 减少 JOIN 数量 数据冗余:添加必要的冗余字段(如订单表直接存储用户名)合并表:将频繁关联的小表合并成…...
k8s从入门到放弃之HPA控制器
k8s从入门到放弃之HPA控制器 Kubernetes中的Horizontal Pod Autoscaler (HPA)控制器是一种用于自动扩展部署、副本集或复制控制器中Pod数量的机制。它可以根据观察到的CPU利用率(或其他自定义指标)来调整这些对象的规模,从而帮助应用程序在负…...
