当前位置: 首页 > news >正文

(三)Pytorch快速搭建卷积神经网络模型实现手写数字识别(代码+详细注解)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录

  • 前言
    • Q1:卷积网络和传统网络的区别
    • Q2:卷积神经网络的架构
    • Q3:卷积神经网络中的参数共享,也是比传统网络的优势所在
    • 4、 具体的实现代码+网络搭建


前言

深度学习pytorch系列第三篇啦,之前更了FC,NN,这篇是卷积神经网络(cNN)模型实现手写数字识别,依然是重在理解哈,具体的理解内容我都以注释的形式放在了代码中,我就直接放代码了,因为我把一些知识点和理解的东西用注释的形式写了


首先是关于卷积神经网络的一些点

Q1:卷积网络和传统网络的区别

传统网络只适合结构化数据,不适合图像数据,由于图像数据的数据量大(表现为像素点多),传统网络需要使用的参数量太大

Q2:卷积神经网络的架构

卷积神经网络包括:输入层,卷积层,池化层,全连接层
重点介绍卷积层!!
卷积就是针对每个区域去计算特征。可以这样做的原因是:图片是有像素点构成的,针对每个像素点进行处理,需要的参数量过于庞大,并且相邻的像素点之间是存在联系的
特征图的个数与卷积核的个数一致。每个卷积核通过对输入特征图进行卷积操作,生成一个输出特征图。因此,卷积核的个数决定了输出的特征图的个数。
使用不同的卷积核学习同一个位置,可以得到不同的特征图,从而使特征多样化
卷积核的大小一般使用3*3
卷积核的大小规格一般是固定的,卷积核的数量理论上是越多越好
卷积层涉及的参数有:滑动窗口步长,卷积核尺寸,边缘填充,卷积核个数
卷积结果计算公式:长:h2=(h1-Fh+2p)/s +1 宽:w2=(w1-Fw+2p)/s +1
其中:w1,h1表示输入的宽度,长度;w2和h2表示输出特征图的宽度、长度,F表示卷积核的长和宽,s表示滑动窗口的补偿,p表示边界填充
经过卷积操作后,特征图的长和宽也可以保持不变
池化层的作用就是筛选好的特征,pool是只筛选位置的,channel是全部使用的
池化也称为下采样,(一次只能下采样原来的一半,不能直接224-16)
卷积神经网络由多个block组成,重点就在于怎么设计这个block的组成
关于卷积神经网络的层数,带权重参数的就算是一层,6个conn+1个fc,就可以说是7层网络结构

Q3:卷积神经网络中的参数共享,也是比传统网络的优势所在

同一个卷积核在各个位置上的参数都是一致的
权重参数的个数与输入数据的大小无关

4、 具体的实现代码+网络搭建

# 读取数据
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms
# transforms  进行预处理,比如进行tensor转换
import matplotlib.pyplot as plt
import numpy as np
#全连接:batch*28*28,全连接各个像素点之间无关
# cnn:batch*1*28*28  ,多了一个参数channel,卷积会综合考虑一个窗口之间的关系,因此各个像素点并不是独立的,卷积网络更适合处理图像数据
# 定义超参数
input_size = 28  #图像的总尺寸28*28
num_classes = 10  #标签的种类数
num_epochs = 3  #训练的总循环周期
batch_size = 64  #一个撮(批次)的大小,64张图片
# 训练集
train_dataset = datasets.MNIST(root='./data',train=True,transform=transforms.ToTensor(),download=True)
# 测试集
test_dataset = datasets.MNIST(root='./data',train=False,transform=transforms.ToTensor())# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=True)
# 卷积网络模块构建
# 一般卷积层,relu层,池化层可以写成一个套餐
# 注意卷积最后结果还是一个特征图,需要把图转换成向量才能做分类或者回归任务
# 定义一个网络
class CNN(nn.Module):def __init__(self):#         构造函数# 卷积网络一般是组合进行的:conv pool relu可以当一个组合super(CNN, self).__init__()self.conv1 = nn.Sequential(  # 输入大小 (1, 28, 28)nn.Conv2d(  # 2d卷积做任务in_channels=1,  # 灰度图out_channels=16,  # 要得到几多少个特征图,就是卷积核的个数,相当于有16个卷积核kernel_size=5,  # 卷积核大小 5*5的stride=1,  # 步长padding=2,  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1,一般是这么希望的#                                             如果不能整除pytorch采用向下取整),  # 输出的特征图为 (16, 28, 28)nn.ReLU(),  # relu层nn.MaxPool2d(kernel_size=2),  # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14),一般是pooling后是之前的一半)self.conv2 = nn.Sequential(  # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(16, 32, 5, 1, 2),  # 输出 (32, 14, 14)nn.ReLU(),  # relu层nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),  # 输出 (32, 7, 7))self.conv3 = nn.Sequential(  # 下一个套餐的输入 (32, 7, 7)nn.Conv2d(32, 64, 5, 1, 2),  # 输出 (64, 7, 7)nn.ReLU(),  # 输出 (64, 7, 7))# 只有pool的时候才会筛选特征self.out = nn.Linear(64 * 7 * 7, 10)  # 全连接层得到的结果,最后的任务是10分类任务,进行一个wx+b的操作去做分类def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)  # flatten操作,结果为:(batch_size, 32 * 7 * 7),和reshape操作一样# reshape操作:总的大小是不变的,提供一个维度后,后边的维度自动计算# 比如当前的x:64*7*7,x.size:64,也就是要从三维转成两维,总的大小不变,就变为64*49这样,-1可以简单的看成一个占位符号# 变换维度,开始是64*7*7,转成batchsize*特征个数,比如64*49output = self.out(x)return output
# 定义准确率
def accuracy(predictions, labels):pred = torch.max(predictions.data, 1)[1] # 最大值是多少,最大值的索引,只要索引就可以rights = pred.eq(labels.data.view_as(pred)).sum()return rights, len(labels)
# 训练网络模型
# 实例化
net = CNN()
# 损失函数
criterion = nn.CrossEntropyLoss()
# 优化器,学习率是0.001
optimizer = optim.Adam(net.parameters(), lr=0.001)  # 定义优化器,普通的随机梯度下降算法
# 开始训练循环
for epoch in range(num_epochs):# 当前epoch的结果保存下来train_rights = []for batch_idx, (data, target) in enumerate(train_loader):  # 针对容器中的每一个批进行循环net.train()output = net(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()right = accuracy(output, target)train_rights.append(right)# 每一个batch都进行训练,每一百个batch进行一次评估if batch_idx % 100 == 0:net.eval()val_rights = []for (data, target) in test_loader:output = net(data)right = accuracy(output, target)val_rights.append(right)# 准确率计算train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(epoch, batch_idx * batch_size, len(train_loader.dataset),100. * batch_idx / len(train_loader),loss.data,100. * train_r[0].numpy() / train_r[1],100. * val_r[0].numpy() / val_r[1]))

实现结果
在这里插入图片描述

相关文章:

(三)Pytorch快速搭建卷积神经网络模型实现手写数字识别(代码+详细注解)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言Q1:卷积网络和传统网络的区别Q2:卷积神经网络的架构Q3:卷积神经网络中的参数共享,也是比传统网络的优势所在4、 具体的实现代码网络搭建…...

【代码】多种调度模式下的光储电站经济性最优 储能容量配置分析matlab/yalmip

程序名称:多种调度模式下的光储电站经济性最优储能容量配置分析 实现平台:matlab-yalmip-cplex/gurobi 代码简介:代码主要做的是一个光储电站经济最优储能容量配置的问题,对光储电站中储能的容量进行优化,以实现经济…...

深度学习今年来经典模型优缺点总结,包括卷积、循环卷积、Transformer、LSTM、GANs等

文章目录 1、卷积神经网络(Convolutional Neural Networks,CNN)1.1 优点1.2 缺点1.3 应用场景1.4 网络图 2、循环神经网络(Recurrent Neural Networks,RNNs)2.1 优点2.2 缺点2.3 应用场景2.4 网络图 3、长短…...

ChatGPT成为“帮凶”:生成虚假数据集支持未知科学假设

ChatGPT 自发布以来,就成为了大家的好帮手,学生党和打工人更是每天都离不开。 然而这次好帮手 ChatGPT 却帮过头了,莫名奇妙的成为了“帮凶”,一位研究人员利用 ChatGPT 创建了虚假的数据集,用来支持未知的科学假设。…...

c#利用Forms.Timer定时检测Tcp连接状态

目的:本地创建客户端连接服务器端,如果连接正常显示连接正常如果连接异常显示连接异常。 using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.Linq; using System.T…...

空间注意力:改变我们理解图像的方式

空间注意力:改变我们理解图像的方式 欢迎来到深度学习和计算机视觉的新时代,在这里,空间注意力机制正改变着我们理解和处理图像的方式。本文将深入探讨空间注意力的概念,它如何工作,以及为什么它在现代图像处理技术中…...

【模型报错记录】‘PromptForGeneration‘ object has no attribute ‘can_generate‘

通过这个连接中的方法解决: “PromptForGeneration”对象没有属性“can_generate” 期刊 #277 thunlp/OpenPrompt GitHub的 问题描述:在使用model.generate() 的时候报错:PromptForGeneration object has no attribute can_generate 解决方法…...

mysql学习记录

关系型数据库:不是把所有的数据全部存储在一起,而是分类存储在一起。 常见的数据库 关系型:oracle大型收费,mysql小型免费。 sql语言(操作数据库) structured query language 结构化查询语言 1.DDL 数据定义语言 创建数…...

Hdoop学习笔记(HDP)-Part.11 安装Kerberos

目录 Part.01 关于HDP Part.02 核心组件原理 Part.03 资源规划 Part.04 基础环境配置 Part.05 Yum源配置 Part.06 安装OracleJDK Part.07 安装MySQL Part.08 部署Ambari集群 Part.09 安装OpenLDAP Part.10 创建集群 Part.11 安装Kerberos Part.12 安装HDFS Part.13 安装Ranger …...

浅谈UML的概念和模型之UML九种图

1、用例图(use case diagrams) 【概念】描述用户需求,从用户的角度描述系统的功能 【描述方式】椭圆表示某个用例;人形符号表示角色 【目的】帮组开发团队以一种可视化的方式理解系统的功能需求 【用例图】 2、静态图 类图&…...

杨志丰:OceanBase助力企业应对数据库转型深水区挑战

11 月 16 日,OceanBase 在北京顺利举办 2023 年度发布会,正式宣布:将持续践行“一体化”产品战略,为关键业务负载打造一体化数据库。OceanBase 产品总经理杨志丰发表了《助力企业应对数据库转型深水区挑战》主题演讲。 以下为演讲…...

版本控制系统Git学习笔记-Git分支操作

文章目录 概述一、Git分支简介1.1 基本概念1.2 创建分支1.3 分支切换1.4 删除分支 二、新建和合并分支2.1 工作流程示意图2.2 新建分支2.3 合并分支2.4 分支示例2.4.1 当前除了主分支,再次创建了两个分支2.4.2 先合并test1分支2.4.3 合并testbranch分支 2.5 解决合并…...

分布式系统中最基础的 CAP 理论及其应用

对于开发或设计分布式系统的架构师、工程师来说,CAP 是必须要掌握的基础理论,CAP 理论可以帮助架构师对系统设计中目标进行取舍,合理地规划系统拆分的维度。下面我们先讲讲分布式系统的特点。 分布式系统的特点 随着移动互联网的快速发展&a…...

计算机视觉(OpenCV+TensorFlow)

计算机视觉(OpenCVTensorFlow) 文章目录 计算机视觉(OpenCVTensorFlow)前言3.图像金字塔3.1 高斯金字塔3.2 拉普拉斯金字塔 4.图像轮廓图像边缘和图像轮廓的区别检测图像绘制边缘 5.轮廓近似外接矩形外接圆 6. 模板匹配6.1 什么是…...

shell语法

概论 shell是我们通过命令行与操作系统沟通的语言 shell脚本可以直接在命令行中执行,也可以将一套逻辑组织成一个文件,方便复用。 DA Terminal中的命令行可以看成是一个“shell脚本在逐行执行”。 1.脚本示例 新建一个test.sh文件,内容如…...

JAXB的XmlAttribute注解

JAXB的XmlAttribute注解,将一个JavaBean属性映射到一个XML属性。 例如,下面的Java代码,将属性currency映射到了XML的属性currency: package com.thb;import jakarta.xml.bind.annotation.XmlAttribute; import jakarta.xml.bind…...

【代码】基于改进差分进化算法的微电网调度研究matlab

程序名称:基于改进差分进化算法的微电网调度研究 实现平台:matlab 代码简介:了进一步提升差分进化算法的优化性能,结合粒子群(PSO)算法的进化机制,提出一种混合多重随机变异粒子差分进化算法(DE-PSO)。所提算法不仅使用粒子群差分变异策略和…...

计算机基础知识63

Django的条件查询&#xff1a;查询函数 exclude exclude&#xff1a;返回不满足条件的数据 res Author.objects.exclude(pk1) print(res) # <QuerySet [<Author: Author object (2)>, <Author: Author object (3)>]> order_by 1、按照 id 升序排序 res …...

springboot虚拟请求——测试

springboot虚拟请求 表现层测试 web环境模拟测试 虚拟请求状态匹配——执行状态的匹配 Testvoid testStatus(Autowired MockMvc mvc) throws Exception { // //http://localhost:8080/books// 创建一个虚拟请求&#xff0c;当前访问的是booksMockHttpServletRequestBui…...

计算机视觉各个方向概述

计算机视觉发展很长时间了&#xff0c;由传统的计算机视觉到现在如火如荼的计算机视觉多模态&#xff0c;有很多的方向&#xff0c;每一个方向都是一个研究门类&#xff0c;有些已经比较成熟&#xff0c;有些还处于一个开始的阶段&#xff0c;相对于文本语言的处理&#xff0c;…...

AIGC: 关于ChatGPT中API接口调用相关准备工作

ChatGPT之API接口相关 通过页面和GPT交流获取信息相比直接调用GPT的API而言是非常有限的 页面上的GPT是比较封闭的&#xff0c;而且只允许我们去输入文本的信息 我们需要借助GPT的API开发来激发AI工具的无限可能&#xff0c;实现更多个性化需求 1 &#xff09;使用API 使用A…...

【Java Web学习笔记】 1 - HTML入门

项目代码 https://github.com/yinhai1114/JavaWeb_LearningCode/tree/main/html 零、网页的组成 HTML是网页内容的载体。内容就是网页制作者放在页面上想要让用户浏览的信息&#xff0c;可以包含文字、图片视频等。 CSS样式是表现。就像网页的外衣。比如&#xff0c;标题字体、…...

基于windows系统使用Python对于pc当前的所有窗口的相关操作接口

对于windows系统的电脑使用Python可以对其当前的窗口进行宏观的查询等操作 派生博客1:python对pc的窗口进行操作(windows) 派生博客2python获取当前pc的分辨率(windows) 派生博客3使用uiautomation模块来对基于windows系统的pc中的前端界面进行自动化测试(查找控件&#xff…...

30秒搞定一个属于你的问答机器人,快速抓取网站内容

我的新书《Android App开发入门与实战》已于2020年8月由人民邮电出版社出版&#xff0c;欢迎购买。点击进入详情 文章目录 简介运行效果GitHub地址 简介 爬取一个网站的内容&#xff0c;然后让这个内容变成你自己的私有知识库&#xff0c;并且还可以搭建一个基于私有知识库的问…...

JPA数据源Oracle异常记录

代码执行异常 ObjectOptimisticLockingFailureException org.springframework.orm.ObjectOptimisticLockingFailureException: Batch update returned unexpected row count from update [0]; actual row count: 0; expected: 1; nested exception is org.hibernate.StaleSta…...

抽奖送平板是骗局!!!

在街上被派传单&#xff0c;然后扫了码抽奖中了平板&#xff0c;被领到卖电器门店兑奖。他们给我在宜嘉商城上充值4980&#xff0c;我现场给他们付了4980元&#xff0c;签了他们的业务办理单&#xff0c;上面有违约者赔款30%违约金字样。我领走了荣耀畅玩40plus手机一台。第二天…...

json.decoder.JSONDecodeError: Extra data: line 1 column 332 (char 331)

项目场景&#xff1a; 提示&#xff1a;扩充数据集时&#xff0c;同步修改json标签中的"imagePath"字段的值&#xff0c;出现json文件读写不一致问题。 采用open函数读写模式修改json文件字段。open(jsonF.json, r)。 问题描述 运行修改json文件报错&#xff1a;j…...

rust持续学习 COW

COW我第一次看见还以为是奶牛 很奇怪是个啥 后来了解到是clone on write 缩写的&#xff0c;大乌龙啊 这个有两种enum,一种是borrow&#xff0c;一种是own rust中&#xff0c;数据读写经常涉及到所有权 这个borrow&#xff0c;很显然&#xff0c;就是不可变借用了 own就是可以写…...

【计算机网络】14、DHCP

文章目录 一、概述1.1 好处 二、概念2.1 分配 IP2.2 控制租赁时间2.3 DHCP 的其他网络功能2.4 IP地址范围和用户类别2.5 安全 三、DHCP 消息3.1 DHCP discover message3.2 DHCP offers a message 如果没有 DHCP&#xff0c;IT管理者必须手动选出可用的 ip&#xff0c;这太耗时了…...

【FPGA】Verilog:计数器 | 异步计数器 | 同步计数器 | 2位二进制计数器的实现 | 4位十进制计数器的实现

目录 Ⅰ. 实践说明 0x00 计数器(Counter) 0x01 异步计数器(Asynchronous Counter)...