神经网络基础知识:LeNet的搭建-训练-预测
1.参考视频:
2.1 pytorch官方demo(Lenet)_哔哩哔哩_bilibili
2.总结:
(1)LeNet网络就是 我最开始用来预测mnist数据集的那个网络,简单的2个conv+2个maxpool+3个linear层
(2)up主整理的train.py等内容里面的细节分析值得学习
(3)对于预测代码的撰写,可以参考代码的predict.py文件
3.几个文件的源代码我都贴一下(都不多——但很精):
(1)首先是 model.py:
import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 16, 5)self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, 5)self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.relu(self.conv1(x)) # input(3, 32, 32) output(16, 28, 28)x = self.pool1(x) # output(16, 14, 14)x = F.relu(self.conv2(x)) # output(32, 10, 10)x = self.pool2(x) # output(32, 5, 5)x = x.view(-1, 32*5*5) # output(32*5*5)x = F.relu(self.fc1(x)) # output(120)x = F.relu(self.fc2(x)) # output(84)x = self.fc3(x) # output(10)return x
模型 == 2个conv + 2个max_pool + 3个linear
(2) train.py训练模型的文件:
import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transformsdef main():# 定义transform的数据增强transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 处理cifar10的 train和val的数据集的问题# 50000张训练图片# 第一次使用时要将download设置为True才会自动去下载数据集train_set = torchvision.datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,shuffle=True, num_workers=0)# 10000张验证图片# 第一次使用时要将download设置为True才会自动去下载数据集val_set = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)val_loader = torch.utils.data.DataLoader(val_set, batch_size=5000,shuffle=False, num_workers=0)val_data_iter = iter(val_loader)val_image, val_label = next(val_data_iter)# classes = ('plane', 'car', 'bird', 'cat',# 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 训练前的准备: 实例化model网络net , 定义 loss函数 CrossEntropyLoss() 和 Adam优化器net = LeNet()loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.001)# 开始训练:zero_grad() + outputs + loss backward + optim stepfor epoch in range(5): # loop over the dataset multiple timesrunning_loss = 0.0for step, data in enumerate(train_loader, start=0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = loss_function(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if step % 500 == 499: # print every 500 mini-batcheswith torch.no_grad():outputs = net(val_image) # [batch, 10]predict_y = torch.max(outputs, dim=1)[1]accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0print('Finished Training')# 最后把 model的 参数save 为一个.pth文件save_path = './Lenet.pth'torch.save(net.state_dict(), save_path)if __name__ == '__main__':main()
分析:数据集划分 + 实例化网络_优化器_loss函数 + 分epoch开始寻 + save_pth权重
(3)predict.py:
import torch
import torchvision.transforms as transforms
from PIL import Imagefrom model import LeNetdef main():# 将需要检测图像 裁剪为32*32transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')#实例化网络 + 才入权重net = LeNet()net.load_state_dict(torch.load('Lenet.pth'))# 打开图像,转换格式im = Image.open('1.jpg')im = transform(im) # [C, H, W]im = torch.unsqueeze(im, dim=0) # [N, C, H, W]# 输入到网络中, 得到预测的结果with torch.no_grad():outputs = net(im)predict = torch.max(outputs, dim=1)[1].numpy()print(classes[int(predict)])if __name__ == '__main__':main()
predict == 处理图像 + 实例化权重 + 得到预测结果
相关文章:
神经网络基础知识:LeNet的搭建-训练-预测
1.参考视频: 2.1 pytorch官方demo(Lenet)_哔哩哔哩_bilibili 2.总结: (1)LeNet网络就是 我最开始用来预测mnist数据集的那个网络,简单的2个conv2个maxpool3个linear层 (2)up主整理的train.py…...
SpringMVC 学习(七)之报文信息转换器 HttpMessageConverter
目录 1 HttpMessageConverter 介绍 2 RequestBody 注解 3 ResponseBody 注解 4 RequestEntity 5 ResponseEntity 6 RestController 注解 1 HttpMessageConverter 介绍 HttpMessageConverter 报文信息转换器,将请求报文(如JSON、XML、HTML等&#x…...
浅谈密码学
文章目录 每日一句正能量前言什么是密码学对称加密简述加密语法Kerckhoffs原则常用的加密算法现代密码学的原则威胁模型(按强度增加的顺序) 密码学的应用领域后记 每日一句正能量 人生在世,谁也不能做到让任何人都喜欢,所以没必要…...
Android 混淆是啥玩意儿?
什么是混淆 Android混淆,是伴随着Android系统的流行而产生的一种Android APP保护技术,用于保护APP不被破解和逆向分析。简单的说,就是将原本正常的项目文件,对其类、方法、字段,重新命名a,b,c…之类的字母,…...
【嵌入式——QT】QListWidget
QListWidget类提供了一个基于项的列表小部件,QListWidgetItem是列表中的项,该篇文章中涉及到的功能有添加列表项,插入列表项,删除列表项,清空列表,向上移动列表项,向下移动列表项。 常用API a…...
爬虫入门到精通_基础篇5(PyQuery库_PyQuery说明,初始化,基本CSS选择器,查找元素,遍历,获取信息,DOM操作)
1 PyQuery说明: PyQuery是python中一个强大而又灵活的网页解析库,如果你觉得正则写起来太麻烦,又觉得BeautifulSoup语法太难记,如果你熟悉jQuery的语法那么,PyQuery就是你绝佳的选择。 安装 pip3 install pyquery2 …...
用冒泡排序模拟C语言中的内置快排函数qsort!
目录 编辑 1.回调函数的介绍 2. 回调函数实现转移表 3. 冒泡排序的实现 4. qsort的介绍和使用 5. qsort的模拟实现 6. 完结散花 悟已往之不谏,知来者犹可追 创作不易,宝子们!如果这篇文章对你们有帮助的话,别忘了给个免…...
智慧公厕:打造智慧城市环境卫生新标杆
随着科技的不断发展和城市化进程的加速推进,智慧城市建设已经成为各地政府和企业关注的焦点。而作为智慧城市环境卫生管理的基础设施,智慧公厕的建设和发展也备受重视,被誉为智慧城市的新标杆。本文以智慧公厕源头厂家广州中期科技有限公司&a…...
【学习版】Microsoft Office 2021安装破解教程
本文转载自知乎:https://zhuanlan.zhihu.com/p/655653158 由本人二次整理修改 用到的软件为:Office Tool Plus,下载链接:Office Tool Plus 官方网站 - 一键部署 Office (landian.vip) 下载页面:(随机找个站…...
基于java Springboot实现课程评分系统设计和实现
基于java Springboot实现课程评分系统设计和实现 博主介绍:多年java开发经验,专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 央顺技术团队 Java毕设项目精品实战案例《1000套》 欢迎点赞 收藏 ⭐留言 文末获取源…...
git操作基本指令
1.查看用户名 git config user.name 2.查看密码 git config user.password 3.查看邮箱 git config user.email 4.修改用户名 git config --global user.name "xxx(新用户名)" 5.修改密码 git config --global user.password "xxx(新密码)" 6.修改…...
YOLO算法
YOLO介绍 YOLO,全称为You Only Look Once: Unified, Real-Time Object Detection,是一种实时目标检测算法。目标检测是计算机视觉领域的一个重要任务,它不仅需要识别图像中的物体类别,还需要确定它们的位置。与分类任务只关注对…...
【Android】更改手机主题导致app数据丢失问题
情景:在使用app过程中更改系统主题(比如从浅色主题改为深色主题),这时activity销毁重建了(即走了onPause、onStop、onSaveInstanceState、onDestroy、onCreate、onRestoreInstanceState、onStart、onResume的生命周期&…...
Dell R730 2U服务器实践3:安装英伟达上代专业AI训练Nvidia P4计算卡
Dell R730是一款非常流行的服务器,2U的机箱可以放入两张显卡,这次先用一张英伟达上代专业级AI训练卡:P4卡做实验,本文记录安装过程。 简洁步骤: 打开机箱将P4显卡插在4号槽位关闭机箱安装驱动 详细步骤: 对…...
Nacos环境搭建 -- 服务注册与发现
为什么需要服务治理 在未引入服务治理模块之前,服务之间的通信是服务间直接发起并调用来实现的。只要知道了对应服务的服务名称、IP地址、端口号,就能够发起服务通信。比如A服务的IP地址为192.168.1.100:9000,B服务直接向该IP地址发起请求就…...
Linux了解
简介 Linux是一种自由和开放源代码的类UNIX操作系统,由芬兰的Linus Torvalds于1991年首次发布。Linux最初是作为支持英特尔x86架构的个人电脑的一个自由操作系统,现在已经被移植到更多的计算机硬件平台,如手机、平板电脑、路由器、视频游戏控…...
Keil新版本安装编译器ARMCompiler 5.06
0x00 缘起 我手头的项目在使用最新版本的编译器后,烧录后无法正常运行,故安装5.06,测试后发现程序运行正常,以下为编译器的安装步骤。 0x01 解决方法 1. 下载编译器安装文件,可以去ARM官网下载,也可以使用我…...
【基础训练 || Test-1】
总言 主要内容:一些习题。 文章目录 总言一、选择1、for循环、操作符(逗号表达式)2、格式化输出(转换说明符)3、for循环、操作符(逗号表达式、赋值和判等)4、if语句、操作符ÿ…...
Python读取hbase数据库
1. hbase连接 首先用hbase shell 命令来进入到hbase数据库,然后用list命令来查看hbase下所有表,以其中表“DB_level0”为例,可以看到库名“baotouyiqi”是拼接的,python代码访问时先连接: def hbase_connection(hbase…...
LeetCode41题:缺失的第一个正数(python3)
这道题写的时候完全没有思路,看了很久的题解,才总结出来。 class Solution:def firstMissingPositive(self, nums: List[int]) -> int:nums_set set(nums)n len(nums)for i in range(1, n 1):if i not in nums_set:return ireturn n 1...
k8s从入门到放弃之Ingress七层负载
k8s从入门到放弃之Ingress七层负载 在Kubernetes(简称K8s)中,Ingress是一个API对象,它允许你定义如何从集群外部访问集群内部的服务。Ingress可以提供负载均衡、SSL终结和基于名称的虚拟主机等功能。通过Ingress,你可…...
基于Flask实现的医疗保险欺诈识别监测模型
基于Flask实现的医疗保险欺诈识别监测模型 项目截图 项目简介 社会医疗保险是国家通过立法形式强制实施,由雇主和个人按一定比例缴纳保险费,建立社会医疗保险基金,支付雇员医疗费用的一种医疗保险制度, 它是促进社会文明和进步的…...
最新SpringBoot+SpringCloud+Nacos微服务框架分享
文章目录 前言一、服务规划二、架构核心1.cloud的pom2.gateway的异常handler3.gateway的filter4、admin的pom5、admin的登录核心 三、code-helper分享总结 前言 最近有个活蛮赶的,根据Excel列的需求预估的工时直接打骨折,不要问我为什么,主要…...
什么?连接服务器也能可视化显示界面?:基于X11 Forwarding + CentOS + MobaXterm实战指南
文章目录 什么是X11?环境准备实战步骤1️⃣ 服务器端配置(CentOS)2️⃣ 客户端配置(MobaXterm)3️⃣ 验证X11 Forwarding4️⃣ 运行自定义GUI程序(Python示例)5️⃣ 成功效果
最近需要在离线机器上运行软件,所以得把软件用docker打包起来,大部分功能都没问题,出了一个奇怪的事情。同样的代码,在本机上用vscode可以运行起来,但是打包之后在docker里出现了问题。使用的是dialog组件,…...
微软PowerBI考试 PL300-在 Power BI 中清理、转换和加载数据
微软PowerBI考试 PL300-在 Power BI 中清理、转换和加载数据 Power Query 具有大量专门帮助您清理和准备数据以供分析的功能。 您将了解如何简化复杂模型、更改数据类型、重命名对象和透视数据。 您还将了解如何分析列,以便知晓哪些列包含有价值的数据,…...
HashMap中的put方法执行流程(流程图)
1 put操作整体流程 HashMap 的 put 操作是其最核心的功能之一。在 JDK 1.8 及以后版本中,其主要逻辑封装在 putVal 这个内部方法中。整个过程大致如下: 初始判断与哈希计算: 首先,putVal 方法会检查当前的 table(也就…...
Linux离线(zip方式)安装docker
目录 基础信息操作系统信息docker信息 安装实例安装步骤示例 遇到的问题问题1:修改默认工作路径启动失败问题2 找不到对应组 基础信息 操作系统信息 OS版本:CentOS 7 64位 内核版本:3.10.0 相关命令: uname -rcat /etc/os-rele…...
CVE-2023-25194源码分析与漏洞复现(Kafka JNDI注入)
漏洞概述 漏洞名称:Apache Kafka Connect JNDI注入导致的远程代码执行漏洞 CVE编号:CVE-2023-25194 CVSS评分:8.8 影响版本:Apache Kafka 2.3.0 - 3.3.2 修复版本:≥ 3.4.0 漏洞类型:反序列化导致的远程代…...
PLC入门【4】基本指令2(SET RST)
04 基本指令2 PLC编程第四课基本指令(2) 1、运用上接课所学的基本指令完成个简单的实例编程。 2、学习SET--置位指令 3、RST--复位指令 打开软件(FX-TRN-BEG-C),从 文件 - 主画面,“B: 让我们学习基本的”- “B-3.控制优先程序”。 点击“梯形图编辑”…...
