神经网络基础知识:LeNet的搭建-训练-预测
1.参考视频:
2.1 pytorch官方demo(Lenet)_哔哩哔哩_bilibili
2.总结:
(1)LeNet网络就是 我最开始用来预测mnist数据集的那个网络,简单的2个conv+2个maxpool+3个linear层
(2)up主整理的train.py等内容里面的细节分析值得学习
(3)对于预测代码的撰写,可以参考代码的predict.py文件
3.几个文件的源代码我都贴一下(都不多——但很精):
(1)首先是 model.py:
import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 16, 5)self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, 5)self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.relu(self.conv1(x)) # input(3, 32, 32) output(16, 28, 28)x = self.pool1(x) # output(16, 14, 14)x = F.relu(self.conv2(x)) # output(32, 10, 10)x = self.pool2(x) # output(32, 5, 5)x = x.view(-1, 32*5*5) # output(32*5*5)x = F.relu(self.fc1(x)) # output(120)x = F.relu(self.fc2(x)) # output(84)x = self.fc3(x) # output(10)return x
模型 == 2个conv + 2个max_pool + 3个linear
(2) train.py训练模型的文件:
import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transformsdef main():# 定义transform的数据增强transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 处理cifar10的 train和val的数据集的问题# 50000张训练图片# 第一次使用时要将download设置为True才会自动去下载数据集train_set = torchvision.datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,shuffle=True, num_workers=0)# 10000张验证图片# 第一次使用时要将download设置为True才会自动去下载数据集val_set = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)val_loader = torch.utils.data.DataLoader(val_set, batch_size=5000,shuffle=False, num_workers=0)val_data_iter = iter(val_loader)val_image, val_label = next(val_data_iter)# classes = ('plane', 'car', 'bird', 'cat',# 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 训练前的准备: 实例化model网络net , 定义 loss函数 CrossEntropyLoss() 和 Adam优化器net = LeNet()loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.001)# 开始训练:zero_grad() + outputs + loss backward + optim stepfor epoch in range(5): # loop over the dataset multiple timesrunning_loss = 0.0for step, data in enumerate(train_loader, start=0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = loss_function(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if step % 500 == 499: # print every 500 mini-batcheswith torch.no_grad():outputs = net(val_image) # [batch, 10]predict_y = torch.max(outputs, dim=1)[1]accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0print('Finished Training')# 最后把 model的 参数save 为一个.pth文件save_path = './Lenet.pth'torch.save(net.state_dict(), save_path)if __name__ == '__main__':main()
分析:数据集划分 + 实例化网络_优化器_loss函数 + 分epoch开始寻 + save_pth权重
(3)predict.py:
import torch
import torchvision.transforms as transforms
from PIL import Imagefrom model import LeNetdef main():# 将需要检测图像 裁剪为32*32transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')#实例化网络 + 才入权重net = LeNet()net.load_state_dict(torch.load('Lenet.pth'))# 打开图像,转换格式im = Image.open('1.jpg')im = transform(im) # [C, H, W]im = torch.unsqueeze(im, dim=0) # [N, C, H, W]# 输入到网络中, 得到预测的结果with torch.no_grad():outputs = net(im)predict = torch.max(outputs, dim=1)[1].numpy()print(classes[int(predict)])if __name__ == '__main__':main()
predict == 处理图像 + 实例化权重 + 得到预测结果
相关文章:
神经网络基础知识:LeNet的搭建-训练-预测
1.参考视频: 2.1 pytorch官方demo(Lenet)_哔哩哔哩_bilibili 2.总结: (1)LeNet网络就是 我最开始用来预测mnist数据集的那个网络,简单的2个conv2个maxpool3个linear层 (2)up主整理的train.py…...

SpringMVC 学习(七)之报文信息转换器 HttpMessageConverter
目录 1 HttpMessageConverter 介绍 2 RequestBody 注解 3 ResponseBody 注解 4 RequestEntity 5 ResponseEntity 6 RestController 注解 1 HttpMessageConverter 介绍 HttpMessageConverter 报文信息转换器,将请求报文(如JSON、XML、HTML等&#x…...

浅谈密码学
文章目录 每日一句正能量前言什么是密码学对称加密简述加密语法Kerckhoffs原则常用的加密算法现代密码学的原则威胁模型(按强度增加的顺序) 密码学的应用领域后记 每日一句正能量 人生在世,谁也不能做到让任何人都喜欢,所以没必要…...

Android 混淆是啥玩意儿?
什么是混淆 Android混淆,是伴随着Android系统的流行而产生的一种Android APP保护技术,用于保护APP不被破解和逆向分析。简单的说,就是将原本正常的项目文件,对其类、方法、字段,重新命名a,b,c…之类的字母,…...

【嵌入式——QT】QListWidget
QListWidget类提供了一个基于项的列表小部件,QListWidgetItem是列表中的项,该篇文章中涉及到的功能有添加列表项,插入列表项,删除列表项,清空列表,向上移动列表项,向下移动列表项。 常用API a…...

爬虫入门到精通_基础篇5(PyQuery库_PyQuery说明,初始化,基本CSS选择器,查找元素,遍历,获取信息,DOM操作)
1 PyQuery说明: PyQuery是python中一个强大而又灵活的网页解析库,如果你觉得正则写起来太麻烦,又觉得BeautifulSoup语法太难记,如果你熟悉jQuery的语法那么,PyQuery就是你绝佳的选择。 安装 pip3 install pyquery2 …...

用冒泡排序模拟C语言中的内置快排函数qsort!
目录 编辑 1.回调函数的介绍 2. 回调函数实现转移表 3. 冒泡排序的实现 4. qsort的介绍和使用 5. qsort的模拟实现 6. 完结散花 悟已往之不谏,知来者犹可追 创作不易,宝子们!如果这篇文章对你们有帮助的话,别忘了给个免…...

智慧公厕:打造智慧城市环境卫生新标杆
随着科技的不断发展和城市化进程的加速推进,智慧城市建设已经成为各地政府和企业关注的焦点。而作为智慧城市环境卫生管理的基础设施,智慧公厕的建设和发展也备受重视,被誉为智慧城市的新标杆。本文以智慧公厕源头厂家广州中期科技有限公司&a…...

【学习版】Microsoft Office 2021安装破解教程
本文转载自知乎:https://zhuanlan.zhihu.com/p/655653158 由本人二次整理修改 用到的软件为:Office Tool Plus,下载链接:Office Tool Plus 官方网站 - 一键部署 Office (landian.vip) 下载页面:(随机找个站…...

基于java Springboot实现课程评分系统设计和实现
基于java Springboot实现课程评分系统设计和实现 博主介绍:多年java开发经验,专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 央顺技术团队 Java毕设项目精品实战案例《1000套》 欢迎点赞 收藏 ⭐留言 文末获取源…...
git操作基本指令
1.查看用户名 git config user.name 2.查看密码 git config user.password 3.查看邮箱 git config user.email 4.修改用户名 git config --global user.name "xxx(新用户名)" 5.修改密码 git config --global user.password "xxx(新密码)" 6.修改…...

YOLO算法
YOLO介绍 YOLO,全称为You Only Look Once: Unified, Real-Time Object Detection,是一种实时目标检测算法。目标检测是计算机视觉领域的一个重要任务,它不仅需要识别图像中的物体类别,还需要确定它们的位置。与分类任务只关注对…...
【Android】更改手机主题导致app数据丢失问题
情景:在使用app过程中更改系统主题(比如从浅色主题改为深色主题),这时activity销毁重建了(即走了onPause、onStop、onSaveInstanceState、onDestroy、onCreate、onRestoreInstanceState、onStart、onResume的生命周期&…...

Dell R730 2U服务器实践3:安装英伟达上代专业AI训练Nvidia P4计算卡
Dell R730是一款非常流行的服务器,2U的机箱可以放入两张显卡,这次先用一张英伟达上代专业级AI训练卡:P4卡做实验,本文记录安装过程。 简洁步骤: 打开机箱将P4显卡插在4号槽位关闭机箱安装驱动 详细步骤: 对…...

Nacos环境搭建 -- 服务注册与发现
为什么需要服务治理 在未引入服务治理模块之前,服务之间的通信是服务间直接发起并调用来实现的。只要知道了对应服务的服务名称、IP地址、端口号,就能够发起服务通信。比如A服务的IP地址为192.168.1.100:9000,B服务直接向该IP地址发起请求就…...
Linux了解
简介 Linux是一种自由和开放源代码的类UNIX操作系统,由芬兰的Linus Torvalds于1991年首次发布。Linux最初是作为支持英特尔x86架构的个人电脑的一个自由操作系统,现在已经被移植到更多的计算机硬件平台,如手机、平板电脑、路由器、视频游戏控…...

Keil新版本安装编译器ARMCompiler 5.06
0x00 缘起 我手头的项目在使用最新版本的编译器后,烧录后无法正常运行,故安装5.06,测试后发现程序运行正常,以下为编译器的安装步骤。 0x01 解决方法 1. 下载编译器安装文件,可以去ARM官网下载,也可以使用我…...

【基础训练 || Test-1】
总言 主要内容:一些习题。 文章目录 总言一、选择1、for循环、操作符(逗号表达式)2、格式化输出(转换说明符)3、for循环、操作符(逗号表达式、赋值和判等)4、if语句、操作符ÿ…...

Python读取hbase数据库
1. hbase连接 首先用hbase shell 命令来进入到hbase数据库,然后用list命令来查看hbase下所有表,以其中表“DB_level0”为例,可以看到库名“baotouyiqi”是拼接的,python代码访问时先连接: def hbase_connection(hbase…...

LeetCode41题:缺失的第一个正数(python3)
这道题写的时候完全没有思路,看了很久的题解,才总结出来。 class Solution:def firstMissingPositive(self, nums: List[int]) -> int:nums_set set(nums)n len(nums)for i in range(1, n 1):if i not in nums_set:return ireturn n 1...
Python爬虫实战:研究MechanicalSoup库相关技术
一、MechanicalSoup 库概述 1.1 库简介 MechanicalSoup 是一个 Python 库,专为自动化交互网站而设计。它结合了 requests 的 HTTP 请求能力和 BeautifulSoup 的 HTML 解析能力,提供了直观的 API,让我们可以像人类用户一样浏览网页、填写表单和提交请求。 1.2 主要功能特点…...
ES6从入门到精通:前言
ES6简介 ES6(ECMAScript 2015)是JavaScript语言的重大更新,引入了许多新特性,包括语法糖、新数据类型、模块化支持等,显著提升了开发效率和代码可维护性。 核心知识点概览 变量声明 let 和 const 取代 var…...
FFmpeg 低延迟同屏方案
引言 在实时互动需求激增的当下,无论是在线教育中的师生同屏演示、远程办公的屏幕共享协作,还是游戏直播的画面实时传输,低延迟同屏已成为保障用户体验的核心指标。FFmpeg 作为一款功能强大的多媒体框架,凭借其灵活的编解码、数据…...
【Java学习笔记】Arrays类
Arrays 类 1. 导入包:import java.util.Arrays 2. 常用方法一览表 方法描述Arrays.toString()返回数组的字符串形式Arrays.sort()排序(自然排序和定制排序)Arrays.binarySearch()通过二分搜索法进行查找(前提:数组是…...

使用分级同态加密防御梯度泄漏
抽象 联邦学习 (FL) 支持跨分布式客户端进行协作模型训练,而无需共享原始数据,这使其成为在互联和自动驾驶汽车 (CAV) 等领域保护隐私的机器学习的一种很有前途的方法。然而,最近的研究表明&…...

【JavaWeb】Docker项目部署
引言 之前学习了Linux操作系统的常见命令,在Linux上安装软件,以及如何在Linux上部署一个单体项目,大多数同学都会有相同的感受,那就是麻烦。 核心体现在三点: 命令太多了,记不住 软件安装包名字复杂&…...

全志A40i android7.1 调试信息打印串口由uart0改为uart3
一,概述 1. 目的 将调试信息打印串口由uart0改为uart3。 2. 版本信息 Uboot版本:2014.07; Kernel版本:Linux-3.10; 二,Uboot 1. sys_config.fex改动 使能uart3(TX:PH00 RX:PH01),并让boo…...

SAP学习笔记 - 开发26 - 前端Fiori开发 OData V2 和 V4 的差异 (Deepseek整理)
上一章用到了V2 的概念,其实 Fiori当中还有 V4,咱们这一章来总结一下 V2 和 V4。 SAP学习笔记 - 开发25 - 前端Fiori开发 Remote OData Service(使用远端Odata服务),代理中间件(ui5-middleware-simpleproxy)-CSDN博客…...

视频行为标注工具BehaviLabel(源码+使用介绍+Windows.Exe版本)
前言: 最近在做行为检测相关的模型,用的是时空图卷积网络(STGCN),但原有kinetic-400数据集数据质量较低,需要进行细粒度的标注,同时粗略搜了下已有开源工具基本都集中于图像分割这块,…...

面向无人机海岸带生态系统监测的语义分割基准数据集
描述:海岸带生态系统的监测是维护生态平衡和可持续发展的重要任务。语义分割技术在遥感影像中的应用为海岸带生态系统的精准监测提供了有效手段。然而,目前该领域仍面临一个挑战,即缺乏公开的专门面向海岸带生态系统的语义分割基准数据集。受…...