当前位置: 首页 > news >正文

深度学习---卷积神经网络

卷积神经网络概述

卷积神经网络是深度学习在计算机视觉领域的突破性成果。在计算机视觉领域。往往输入的图像都很大,使用全连接网络的话,计算的代价较高。另外图像也很难保留原有的特征,导致图像处理的准确率不高。

卷积神经网络(Convolutional Neural Network)是含有卷积层的神经网络。卷积层的作用就是用来自动学习、提取图像的特征。

CNN网络主要有三部分构成:卷积层、池化层和全连接层构成,其中卷积层负责提取图像中的局部特征;池化层用来大幅降低参数量级(降维);全连接层类似人工神经网络的部分,用来输出想要的结果。

图像概述

图像是由像素点组成的,每个像素点的值范围为:[0,255],像素值越大意味着较亮。比如一张 200x200 的图像,则是由 40000 个像素点组成,如果每个像素点都是 0 的话,意味着这是一张全黑的图像。

彩色图一般都是多通道的图像,所谓多通道可以理解为图像由多个不同的图像层叠加而成,例如看到的彩色图像一般都是由 RGB 三个通道组成的,还有一些图像具有 RGBA 四个通道,最后一个通道为透明通道,该值越小,则图像越透明。

卷积层

卷积计算

Padding

Stride

多通道卷积计算

多卷积核卷积计算

特征图大小

池化层

池化层 (Pooling) 降低维度,缩减模型大小,提高计算速度。即:主要对卷积层学习到的特征图进行下采样(SubSampling)处理。池化层主要有两种:最大池化、平均池化。

池化层计算

Stride

Padding

多通道池化计算

案例-图像分类

CIFAR10 数据集

from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader# 1. 数据集基本信息
def test01():# 加载数据集train = CIFAR10(root='data', train=True, transform=Compose([ToTensor()]))valid = CIFAR10(root='data', train=False, transform=Compose([ToTensor()]))# 数据集数量print('训练集数量:', len(train.targets))print('测试集数量:', len(valid.targets))# 数据集形状print("数据集形状:", train[0][0].shape)# 数据集类别print("数据集类别:", train.class_to_idx)# 2. 数据加载器
def test02():train = CIFAR10(root='data', train=True, transform=Compose([ToTensor()]))dataloader = DataLoader(train, batch_size=8, shuffle=True)for x, y in dataloader:print(x.shape)print(y)breakif __name__ == '__main__':test01()test02()

搭建图像分类网络

class ImageClassification(nn.Module):def __init__(self):super(ImageClassification, self).__init__()self.conv1 = nn.Conv2d(3, 6, stride=1, kernel_size=3)self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.Conv2d(6, 16, stride=1, kernel_size=3)self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)self.linear1 = nn.Linear(576, 120)self.linear2 = nn.Linear(120, 84)self.out = nn.Linear(84, 10)def forward(self, x):x = F.relu(self.conv1(x))x = self.pool1(x)x = F.relu(self.conv2(x))x = self.pool2(x)# 由于最后一个批次可能不够 32,所以需要根据批次数量来 flattenx = x.reshape(x.size(0), -1)x = F.relu(self.linear1(x))x = F.relu(self.linear2(x))return self.out(x)

编写训练函数

使用多分类交叉熵损失函数,Adam 优化器:

def train():# 加载 CIFAR10 训练集, 并将其转换为张量transgform = Compose([ToTensor()])cifar10 = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transgform)# 构建图像分类模型model = ImageClassification()# 构建损失函数criterion = nn.CrossEntropyLoss()# 构建优化方法optimizer = optim.Adam(model.parameters(), lr=1e-3)# 训练轮数epoch = 100for epoch_idx in range(epoch):# 构建数据加载器dataloader = DataLoader(cifar10, batch_size=BATCH_SIZE, shuffle=True)# 样本数量sam_num = 0# 损失总和total_loss = 0.0# 开始时间start = time.time()correct = 0for x, y in dataloader:# 送入模型output = model(x)# 计算损失loss = criterion(output, y)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 参数更新optimizer.step()correct += (torch.argmax(output, dim=-1) == y).sum()total_loss += (loss.item() * len(y))sam_num += len(y)print('epoch:%2s loss:%.5f acc:%.2f time:%.2fs' %(epoch_idx + 1,total_loss / sam_num,correct / sam_num,time.time() - start))# 序列化模型torch.save(model.state_dict(), 'model/image_classification.bin')

编写预测函数

def test():# 加载 CIFAR10 测试集, 并将其转换为张量transgform = Compose([ToTensor()])cifar10 = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transgform)# 构建数据加载器dataloader = DataLoader(cifar10, batch_size=BATCH_SIZE, shuffle=True)# 加载模型model = ImageClassification()model.load_state_dict(torch.load('model/image_classification.bin'))model.eval()total_correct = 0total_samples = 0for x, y in dataloader:output = model(x)total_correct += (torch.argmax(output, dim=-1) == y).sum()total_samples += len(y)print('Acc: %.2f' % (total_correct / total_samples))

总结

可以从以下几个方面来调整网络:

  1. 增加卷积核输出通道数
  2. 增加全连接层的参数量
  3. 调整学习率
  4. 调整优化方法
  5. 修改激活函数
  6. 等等...

把学习率由 1e-3 修改为 1e-4,并网络参数量增加如下代码所示:

class ImageClassification(nn.Module):def __init__(self):super(ImageClassification, self).__init__()self.conv1 = nn.Conv2d(3, 32, stride=1, kernel_size=3)self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.Conv2d(32, 128, stride=1, kernel_size=3)self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)self.linear1 = nn.Linear(128 * 6 * 6, 2048)self.linear2 = nn.Linear(2048, 2048)self.out = nn.Linear(2048, 10)def forward(self, x):x = F.relu(self.conv1(x))x = self.pool1(x)x = F.relu(self.conv2(x))x = self.pool2(x)# 由于最后一个批次可能不够 32,所以需要根据批次数量来 flattenx = x.reshape(x.size(0), -1)x = F.relu(self.linear1(x))x = F.dropout(x, p=0.5)x = F.relu(self.linear2(x))x = F.dropout(x, p=0.5)return self.out(x)

相关文章:

深度学习---卷积神经网络

卷积神经网络概述 卷积神经网络是深度学习在计算机视觉领域的突破性成果。在计算机视觉领域。往往输入的图像都很大,使用全连接网络的话,计算的代价较高。另外图像也很难保留原有的特征,导致图像处理的准确率不高。 卷积神经网络&#xff0…...

Windows系统下安装CouchDB3.3.2教程

安装 前往CouchDB官网 官网点击download下载msi文件 双击该msi文件,一直下一步 创建个人account 设置cookie value 用于进行身份验证和授权。 愉快下载 点击OK 重启 启动 重启电脑后 打开浏览器并访问以下链接:http://127.0.0.1:5984/ 如果没有问…...

JavaScript基础知识(二)

JavaScript基础知识(二) 一、ES2015 基础语法1.变量2.常量3.模板字符串4.结构赋值 二、函数进阶1. 设置默认参数值2. 立即执行函数3. 闭包4. 箭头函数 三、面向对象1. 面向对象概述2. 基本概念3. 新语法 与 旧语法3.1 ES5 面向对象的知识ES5构造函数原型…...

SQL NULL Values(空值)

什么是SQL NULL值? SQL 中,NULL 用于表示缺失的值。数据表中的 NULL 值表示该值所处的字段为空。 具有NULL值的字段是没有值的字段。 如果表中的字段是可选的,则可以插入新记录或更新记录而不向该字段添加值。然后,该字段将被保存…...

云原生Docker网络管理

目录 Docker网络 Docker 网络实现原理 为容器创建端口映射 查看容器的输出和日志信息 Docker 的网络模式 查看docker网络列表 指定容器网络模式 网络模式详解 host模式 container模式 none模式 bridge模式 自定义网络 Docker网络 Docker 网络实现原理 Docker使用Lin…...

聊聊线程池的预热

序 本文主要研究一下线程池的预热 prestartCoreThread java/util/concurrent/ThreadPoolExecutor.java /*** Starts a core thread, causing it to idly wait for work. This* overrides the default policy of starting core threads only when* new tasks are executed. T…...

VueComponent的原型对象

一、prototype 每一个构造函数身上又有一个prototype指向其原型对象。 如果我们在控制台输入如下代码,就能看到Vue构造函数的信息,在他身上可以找到prototype属性,指向的是Vue原型对象: 二、__proto__ 通过构造函数创建的实例对…...

Redis不止能存储字符串,还有List、Set、Hash、Zset,用对了能给你带来哪些优势?

文章目录 🌟 Redis五大数据类型的应用场景🍊 一、String🍊 二、Hash🍊 三、List🍊 四、Set🍊 五、Zset 📕我是廖志伟,一名Java开发工程师、Java领域优质创作者、CSDN博客专家、51CTO…...

Python OpenCV通过灰度平均值进行二值化处理以减少像素误差

Python OpenCV通过灰度平均值进行二值化处理以减少像素误差 前言前提条件相关介绍实验环境通过灰度平均值进行二值化处理以减少像素误差固定阈值二值化代码实现 灰度平均值二值化代码实现 前言 由于本人水平有限,难免出现错漏,敬请批评改正。更多精彩内容…...

[Golang]多返回值函数、defer关键字、内置函数、变参函数、类成员函数、匿名函数

函数 文章目录 函数多返回值函数按值传递、按引用传递类成员函数改变外部变量变参函数defer和追踪说明一些常见操作实现 使用defer实现代码追踪记录函数的参数和返回值 常见的内置函数将函数作为参数闭包实例闭包将函数作为返回值 计算函数执行时间使用内存缓存来提升性能 参考…...

【剑指Offer】:删除链表中的倒数第N个节点(此题是LeetCode上面的)剑指Offer上面是链表中的倒数第K个节点

给定一个链表,删除链表的倒数第 n 个结点,并且返回链表的头结点 示例 1: 输入:head [1,2,3,4,5], n 2 输出:[1,2,3,5] 示例 2: 输入:head [1], n 1 输出:[] 示例 3:…...

acwing第 126 场周赛 (扩展字符串)

5281. 扩展字符串 一、题目要求 某字符串序列 s0,s1,s2,… 的生成规律如下: s0 DKER EPH VOS GOLNJ ER RKH HNG OI RKH UOPMGB CPH VOS FSQVB DLMM VOS QETH SQBsnDKER EPH VOS GOLNJ UKLMH QHNGLNJ Asn−1AB CPH VOS FSQVB DLMM VOS QHNG Asn−1AB,其…...

Milvus 介绍

Milvus 介绍 Milvus 矢量数据库是什么?关键概念非结构化数据嵌入向量向量相似度搜索 为什么是 Milvus?支持哪些索引和指标?索引类型相似度指标(Similarity metrics) 应用示例Milvus 是如何设计的?开发者工具API访问Milvus 生态系统工具 本页…...

Linux绝对路径和相对路径

在 Linux 中,简单的理解一个文件的路径,指的就是该文件存放的位置。 只要我们告诉 Linux 系统某个文件存放的准确位置,那么它就可以找到这个文件。指明一个文件存放的位置,有 2 种方法,分别是使用绝对路径和相对路径。…...

Linux:firewalld防火墙-基础使用(2)

上一章 Linux:firewalld防火墙-介绍(1)-CSDN博客https://blog.csdn.net/w14768855/article/details/133960695?spm1001.2014.3001.5501 我使用的系统为centos7 firewalld启动停止等操作 systemctl start firewalld 开启防火墙 systemct…...

【每日一练】20231023

统计每个字符出现的次数相关问题 方法一&#xff1a;map的put方法遍历 public class Test {public static void main(String[] args) {StringBuilder sb new StringBuilder("");Random ran new Random();for(int i0;i<2000000;i) {sb.append((char) (a ran.n…...

【项目经理】工作流引擎

项目经理之 工作流引擎 一、业务系统管理目的维护信息 二、组织架构管理目的维护信息 三、角色矩阵管理目的维护信息 四、条件变量管理目的维护信息 五、流程模型管理目的维护信息 六、流程版本管理目的维护信息 七、流程监管控制目的维护信息 系列文章版本记录 一、业务系统管…...

025-第三代软件开发-实现需求长时间未操作返回登录界面

第三代软件开发-实现需求长时间未操作返回登录界面 文章目录 第三代软件开发-实现需求长时间未操作返回登录界面项目介绍实现需求长时间未操作返回登录界面实现思路用户操作监控QML 逻辑处理 关键字&#xff1a; Qt、 Qml、 QTimer、 timeout、 eventFilter 项目介绍 欢迎…...

驱动开发LED灯绑定设备文件

头文件 #ifndef __HEAD_H__ #define __HEAD_H__typedef struct {unsigned int MODER;unsigned int OTYPER;unsigned int OSPEEDR;unsigned int PUPDR;unsigned int IDR;unsigned int ODR; }gpio_t;#define PHY_LED1_ADDR 0x50006000 #define PHY_LED2_ADDR 0x50007000 #defin…...

MySql 数据库基础概念,基本简单操作及数据类型介绍

文章目录 数据库基础为什么需要数据库&#xff1f;创建数据库mysql架构SQL语句分类编码集修改数据库属性数据库备份 表的基本操作存在时更新&#xff0c;不存在时插入 数据类型日期类型enum和set 数据库基础 以特定的格式保存文件&#xff0c;叫做数据库&#xff0c;这是狭义上…...

web vue 项目 Docker化部署

Web 项目 Docker 化部署详细教程 目录 Web 项目 Docker 化部署概述Dockerfile 详解 构建阶段生产阶段 构建和运行 Docker 镜像 1. Web 项目 Docker 化部署概述 Docker 化部署的主要步骤分为以下几个阶段&#xff1a; 构建阶段&#xff08;Build Stage&#xff09;&#xff1a…...

51c自动驾驶~合集58

我自己的原文哦~ https://blog.51cto.com/whaosoft/13967107 #CCA-Attention 全局池化局部保留&#xff0c;CCA-Attention为LLM长文本建模带来突破性进展 琶洲实验室、华南理工大学联合推出关键上下文感知注意力机制&#xff08;CCA-Attention&#xff09;&#xff0c;…...

【人工智能】神经网络的优化器optimizer(二):Adagrad自适应学习率优化器

一.自适应梯度算法Adagrad概述 Adagrad&#xff08;Adaptive Gradient Algorithm&#xff09;是一种自适应学习率的优化算法&#xff0c;由Duchi等人在2011年提出。其核心思想是针对不同参数自动调整学习率&#xff0c;适合处理稀疏数据和不同参数梯度差异较大的场景。Adagrad通…...

关于nvm与node.js

1 安装nvm 安装过程中手动修改 nvm的安装路径&#xff0c; 以及修改 通过nvm安装node后正在使用的node的存放目录【这句话可能难以理解&#xff0c;但接着往下看你就了然了】 2 修改nvm中settings.txt文件配置 nvm安装成功后&#xff0c;通常在该文件中会出现以下配置&…...

基础测试工具使用经验

背景 vtune&#xff0c;perf, nsight system等基础测试工具&#xff0c;都是用过的&#xff0c;但是没有记录&#xff0c;都逐渐忘了。所以写这篇博客总结记录一下&#xff0c;只要以后发现新的用法&#xff0c;就记得来编辑补充一下 perf 比较基础的用法&#xff1a; 先改这…...

Robots.txt 文件

什么是robots.txt&#xff1f; robots.txt 是一个位于网站根目录下的文本文件&#xff08;如&#xff1a;https://example.com/robots.txt&#xff09;&#xff0c;它用于指导网络爬虫&#xff08;如搜索引擎的蜘蛛程序&#xff09;如何抓取该网站的内容。这个文件遵循 Robots…...

Linux-07 ubuntu 的 chrome 启动不了

文章目录 问题原因解决步骤一、卸载旧版chrome二、重新安装chorme三、启动不了&#xff0c;报错如下四、启动不了&#xff0c;解决如下 总结 问题原因 在应用中可以看到chrome&#xff0c;但是打不开(说明&#xff1a;原来的ubuntu系统出问题了&#xff0c;这个是备用的硬盘&a…...

初学 pytest 记录

安装 pip install pytest用例可以是函数也可以是类中的方法 def test_func():print()class TestAdd: # def __init__(self): 在 pytest 中不可以使用__init__方法 # self.cc 12345 pytest.mark.api def test_str(self):res add(1, 2)assert res 12def test_int(self):r…...

Java 二维码

Java 二维码 **技术&#xff1a;**谷歌 ZXing 实现 首先添加依赖 <!-- 二维码依赖 --><dependency><groupId>com.google.zxing</groupId><artifactId>core</artifactId><version>3.5.1</version></dependency><de…...

10-Oracle 23 ai Vector Search 概述和参数

一、Oracle AI Vector Search 概述 企业和个人都在尝试各种AI&#xff0c;使用客户端或是内部自己搭建集成大模型的终端&#xff0c;加速与大型语言模型&#xff08;LLM&#xff09;的结合&#xff0c;同时使用检索增强生成&#xff08;Retrieval Augmented Generation &#…...