当前位置: 首页 > news >正文

深度学习-卷积神经网络CNN

案例-图像分类

网络结构: 卷积+BN+激活+池化

数据集介绍

CIFAR-10数据集5万张训练图像、1万张测试图像、10个类别、每个类别有6k个图像,图像大小32×32×3。下图列举了10个类,每一类随机展示了10张图片:

特征图计算

在卷积层和池化层结束后, 将特征图变形成一行n列数据, 计算特征图进行变化, 映射到全连接层时输入层特征为最后一层卷积层经池化后的特征图各维度相乘

具体流程-# Acc: 0.728

# 导包
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchsummary import summary
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor, Compose  # Compose: 数据增强(扩充数据集)
import time
import matplotlib.pyplot as plt
​
batch_size = 16
​
​
# 创建数据集
def create_dataset():torch.manual_seed(21)train = CIFAR10(root='data',train=True,transform=Compose([ToTensor()]))test = CIFAR10(root='data',train=False,transform=Compose([ToTensor()]))return train, test
​
​
# 创建模型
class ImgCls(nn.Module):# 定义网络结构def __init__(self):super(ImgCls, self).__init__()# 定义网络层:卷积层+池化层self.conv1 = nn.Conv2d(3, 16, stride=1, kernel_size=3)self.batch_norm_layer1 = nn.BatchNorm2d(num_features=16, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
​self.conv2 = nn.Conv2d(16, 32, stride=1, kernel_size=3)self.batch_norm_layer2 = nn.BatchNorm2d(num_features=32, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)self.pool2 = nn.MaxPool2d(kernel_size=2, stride=1)
​self.conv3 = nn.Conv2d(32, 64, stride=1, kernel_size=3)self.batch_norm_layer3 = nn.BatchNorm2d(num_features=64, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)self.pool3 = nn.MaxPool2d(kernel_size=2, stride=1)
​self.conv4 = nn.Conv2d(64, 128, stride=1, kernel_size=2)self.batch_norm_layer4 = nn.BatchNorm2d(num_features=128, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
​self.conv5 = nn.Conv2d(128, 256, stride=1, kernel_size=2)self.batch_norm_layer5 = nn.BatchNorm2d(num_features=256, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)self.pool5 = nn.MaxPool2d(kernel_size=2, stride=1)
​# 全连接层self.linear1 = nn.Linear(1024, 2048)self.linear2 = nn.Linear(2048, 1024)self.linear3 = nn.Linear(1024, 512)self.linear4 = nn.Linear(512, 256)self.linear5 = nn.Linear(256, 128)self.out = nn.Linear(128, 10)
​# 定义前向传播def forward(self, x):# 第1层: 卷积+BN+激活+池化x = self.conv1(x)x = self.batch_norm_layer1(x)x = torch.rrelu(x)x = self.pool1(x)
​# 第2层: 卷积+BN+激活+池化x = self.conv2(x)x = self.batch_norm_layer2(x)x = torch.rrelu(x)x = self.pool2(x)
​# 第3层: 卷积+BN+激活+池化x = self.conv3(x)x = self.batch_norm_layer3(x)x = torch.rrelu(x)x = self.pool3(x)
​# 第4层: 卷积+BN+激活+池化x = self.conv4(x)x = self.batch_norm_layer4(x)x = torch.rrelu(x)x = self.pool4(x)
​# 第5层: 卷积+BN+激活+池化x = self.conv5(x)x = self.batch_norm_layer5(x)x = torch.rrelu(x)x = self.pool5(x)
​# 将特征图做成以为向量的形式:相当于特征向量x = x.reshape(x.size(0), -1)  # 将3维特征图转化为1维向量(1, n)
​# 全连接层x = torch.rrelu(self.linear1(x))x = torch.rrelu(self.linear2(x))x = torch.rrelu(self.linear3(x))x = torch.rrelu(self.linear4(x))x = torch.rrelu(self.linear5(x))# 返回输出结果return self.out(x)
​
​
# 训练
def train(model, train_dataset, epochs):torch.manual_seed(21)loss = nn.CrossEntropyLoss()opt = optim.Adam(model.parameters(), lr=1e-4)for epoch in range(epochs):dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)loss_total = 0iter = 0stat_time = time.time()for x, y in dataloader:output = model(x.to(device))loss_value = loss(output, y.to(device))opt.zero_grad()loss_value.backward()opt.step()loss_total += loss_value.item()iter += 1print(f'epoch:{epoch + 1:4d}, loss:{loss_total / iter:6.4f}, time:{time.time() - stat_time:.2f}s')torch.save(model.state_dict(), 'model/img_cls_model.pth')
​
​
# 测试
def test(valid_dataset, model, batch_size):# 构建数据加载器dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
​# 计算精度total_correct = 0# 遍历每个batch的数据,获取预测结果,计算精度for x, y in dataloader:output = model(x.to(device))y_pred = torch.argmax(output, dim=-1)total_correct += (y_pred == y.to(device)).sum()# 打印精度print(f'Acc: {(total_correct.item() / len(valid_dataset))}')
​
​
if __name__ == '__main__':batch_size = 16device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 获取数据集train_data, test_data = create_dataset()
​# # 查看数据集# print(f'数据集类别: {train_data.class_to_idx}')# print(f'训练集: {train_data.data.shape}')# print(f'验证集: {test_data.data.shape}')# print(f'类别数量: {len(np.unique(train_data.targets))}')# # 展示图像# plt.figure(figsize=(8, 8))# plt.imshow(train_data.data[0])# plt.title(train_data.classes[train_data.targets[0]])# plt.show()
​# 实例化模型model = ImgCls().to(device)
​# 查看网络结构summary(model, (3, 32, 32), device='cuda', batch_size=batch_size)
​# 模型训练train(model, train_data, epochs=60)# 加载训练好的模型参数model.load_state_dict(torch.load('model/img_cls_model.pth'))model.eval()# 模型评估test(test_data, model, batch_size=16)   # Acc: 0.728
​

调整网络结构

第一次调整: 训练50轮, Acc: 0.71

第二次调整: 训练30轮, Acc:0.7351

第三次调整: batch_size=8, epoch=50 => Acc: 0.7644

# 导包
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchsummary import summary
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor, Compose  # Compose: 数据增强(扩充数据集)
import time
import matplotlib.pyplot as plt
​
batch_size = 16
​
​
# 创建数据集
def create_dataset():torch.manual_seed(21)train = CIFAR10(root='data',train=True,transform=Compose([ToTensor()]))test = CIFAR10(root='data',train=False,transform=Compose([ToTensor()]))return train, test
​
​
# 创建模型
class ImgCls(nn.Module):# 定义网络结构def __init__(self):super(ImgCls, self).__init__()# 定义网络层:卷积层+池化层self.conv1 = nn.Conv2d(3, 16, stride=1, kernel_size=3, padding=1)self.batch_norm_layer1 = nn.BatchNorm2d(num_features=16, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
​self.conv2 = nn.Conv2d(16, 32, stride=1, kernel_size=3, padding=1)self.batch_norm_layer2 = nn.BatchNorm2d(num_features=32, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
​self.conv3 = nn.Conv2d(32, 64, stride=1, kernel_size=3, padding=1)self.batch_norm_layer3 = nn.BatchNorm2d(num_features=64, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)self.pool3 = nn.MaxPool2d(kernel_size=2, stride=1)
​self.conv4 = nn.Conv2d(64, 128, stride=1, kernel_size=3, padding=1)self.batch_norm_layer4 = nn.BatchNorm2d(num_features=128, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)self.pool4 = nn.MaxPool2d(kernel_size=2, stride=1)
​self.conv5 = nn.Conv2d(128, 256, stride=1, kernel_size=3)self.batch_norm_layer5 = nn.BatchNorm2d(num_features=256, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
​# 全连接层self.linear1 = nn.Linear(1024, 2048)self.linear2 = nn.Linear(2048, 1024)self.linear3 = nn.Linear(1024, 512)self.linear4 = nn.Linear(512, 256)self.linear5 = nn.Linear(256, 128)self.out = nn.Linear(128, 10)
​# 定义前向传播def forward(self, x):# 第1层: 卷积+BN+激活+池化x = self.conv1(x)x = self.batch_norm_layer1(x)x = torch.relu(x)x = self.pool1(x)
​# 第2层: 卷积+BN+激活+池化x = self.conv2(x)x = self.batch_norm_layer2(x)x = torch.relu(x)x = self.pool2(x)
​# 第3层: 卷积+BN+激活+池化x = self.conv3(x)x = self.batch_norm_layer3(x)x = torch.relu(x)x = self.pool3(x)
​# 第4层: 卷积+BN+激活+池化x = self.conv4(x)x = self.batch_norm_layer4(x)x = torch.relu(x)x = self.pool4(x)
​# 第5层: 卷积+BN+激活+池化x = self.conv5(x)x = self.batch_norm_layer5(x)x = torch.rrelu(x)x = self.pool5(x)
​# 将特征图做成以为向量的形式:相当于特征向量x = x.reshape(x.size(0), -1)  # 将3维特征图转化为1维向量(1, n)
​# 全连接层x = torch.relu(self.linear1(x))x = torch.relu(self.linear2(x))x = torch.relu(self.linear3(x))x = torch.relu(self.linear4(x))x = torch.rrelu(self.linear5(x))# 返回输出结果return self.out(x)
​
​
# 训练
def train(model, train_dataset, epochs):torch.manual_seed(21)loss = nn.CrossEntropyLoss()opt = optim.Adam(model.parameters(), lr=1e-4)for epoch in range(epochs):dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)loss_total = 0iter = 0stat_time = time.time()for x, y in dataloader:output = model(x.to(device))loss_value = loss(output, y.to(device))opt.zero_grad()loss_value.backward()opt.step()loss_total += loss_value.item()iter += 1print(f'epoch:{epoch + 1:4d}, loss:{loss_total / iter:6.4f}, time:{time.time() - stat_time:.2f}s')torch.save(model.state_dict(), 'model/img_cls_model1.pth')
​
​
# 测试
def test(valid_dataset, model, batch_size):# 构建数据加载器dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
​# 计算精度total_correct = 0# 遍历每个batch的数据,获取预测结果,计算精度for x, y in dataloader:output = model(x.to(device))y_pred = torch.argmax(output, dim=-1)total_correct += (y_pred == y.to(device)).sum()# 打印精度print(f'Acc: {(total_correct.item() / len(valid_dataset))}')
​
​
if __name__ == '__main__':batch_size = 8device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 获取数据集train_data, test_data = create_dataset()
​# # 查看数据集# print(f'数据集类别: {train_data.class_to_idx}')# print(f'训练集: {train_data.data.shape}')# print(f'验证集: {test_data.data.shape}')# print(f'类别数量: {len(np.unique(train_data.targets))}')# # 展示图像# plt.figure(figsize=(8, 8))# plt.imshow(train_data.data[0])# plt.title(train_data.classes[train_data.targets[0]])# plt.show()
​# 实例化模型model = ImgCls().to(device)
​# 查看网络结构summary(model, (3, 32, 32), device='cuda', batch_size=batch_size)
​# 模型训练train(model, train_data, epochs=50)# 加载训练好的模型参数model.load_state_dict(torch.load('model/img_cls_model1.pth', weights_only=True))model.eval()# 模型评估test(test_data, model, batch_size=16)   # Acc: 0.7644
​

相关文章:

深度学习-卷积神经网络CNN

案例-图像分类 网络结构: 卷积BN激活池化 数据集介绍 CIFAR-10数据集5万张训练图像、1万张测试图像、10个类别、每个类别有6k个图像,图像大小32323。下图列举了10个类,每一类随机展示了10张图片: 特征图计算 在卷积层和池化层结束后, 将特征…...

241114.学习日志——[CSDIY] [Cpp]零基础速成 [03]

​ CSDIY:这是一个非科班学生的努力之路,从今天开始这个系列会长期更新,(最好做到日更),我会慢慢把自己目前对CS的努力逐一上传,帮助那些和我一样有着梦想的玩家取得胜利!&#xff0…...

大模型研究报告 | 2024年中国金融大模型产业发展洞察报告|附34页PDF文件下载

随着生成算法、预训练模型、多模态数据分析等AI技术的聚集融合,AIGC技术的实践效用迎来了行业级大爆发。通用大模型技术的成熟推动了新一轮行业生产力变革,在投入提升与政策扶植的双重作用下,以大模型技术为底座、结合专业化金融能力的金融大…...

数据库SQL——什么是实体-联系模型(E-R模型)?

目录 什么是实体-联系模型? 1.实体集 2.联系集 3.映射基数 一对一(1:1) 一对多(1:n) 多对一(n:1) 多对多(m:n) 全部参与: 4.主码 弱实体集&#xf…...

在 MySQL 8.0 中,SSL 解密失败,在使用 SSL 加密连接时出现了问题

在 MySQL 8.0 中,SSL 解密失败通常指的是在使用 SSL 加密连接时出现了问题,导致无法建立加密通信。这个问题可能由多种原因引起,下面是一些常见的原因和排查方法: 1. 证书配置问题 确保您在 MySQL 配置中使用了正确的 SSL 证书和…...

React Native 全栈开发实战班 - 第四部分:用户界面进阶之动画效果实现

在移动应用中,动画效果 是提升用户体验的重要手段。合理的动画设计可以增强应用的交互性、流畅性和视觉吸引力。React Native 提供了多种实现动画的方式,包括内置的 Animated API、LayoutAnimation 以及第三方库(如 react-native-reanimated&…...

【CICD】GitLab Runner 和执行器(Executor

GitLab Runner 和执行器(Executor)是 GitLab CI/CD 管道中的两个重要组成部分。理解它们之间的关系有助于更好地配置和使用 CI/CD 流水线。runer是gitlab的ci-agent对接gitlab,而执行器是接受runer下发的ci的任务来干活的。也就是说gitrunner…...

实用教程:如何无损修改MP4视频时长

如何在UltraEdit中搜索MP4文件中的“mvhd”关键字 引言 在视频编辑和分析领域,有时我们需要深入到视频文件的底层结构中去。UltraEdit(UE)和UEStudio作为强大的文本编辑器,允许我们以十六进制模式打开和搜索MP4文件。本文将指导…...

mysqldump命令搭配source命令完成数据库迁移备份

mysqldump 命令使用 需保证mysqld在运行中, 这个命令的目的是将数据库导出到文件中,例如 mysqldump -uusername -ppassword database > db.sql 注意该命令不是在MySQL客户端(即MySQL命令行)执行的,而是在系统命…...

生信:TCGA学习(R、RStudio安装与下载、常用语法与常用快捷键)

前置环境 macOS系统,已安装homebrew且会相关命令。 近期在整理草稿区,所以放出该贴。 R语言、RStudio、R包安装 R语言安装 brew install rRStudio安装 官网地址:https://posit.co/download/rstudio-desktop/ R包下载 注意R语言环境自带…...

十三、注解配置SpringMVC

文章目录 1. 创建初始化类,代替web.xml2. 创建SpringConfig配置类,代替spring的配置文件3. 创建WebConfig配置类,代替SpringMVC的配置文件4. 测试功能 1. 创建初始化类,代替web.xml 2. 创建SpringConfig配置类,代替spr…...

为什么海外服务器IP会被封

海外服务器因为免备案而备受用户欢迎,近年来租用海外服务器的用户也越来越多,自然也可能会出现一些问题。 如果服务器IP被封,在该服务器下的所有业务都无法访问,对自己和对用户来说都会有较大的影响。因此,我们应做好相…...

图像处理技术椒盐噪声

椒盐噪声,也称为脉冲噪声,是图像中经常见到的一种噪声。它是一种随机出现的白点或者黑点,可能是亮的区域有黑色像素或是在暗的区域有白色像素(或是两者皆有)。这些白点和黑点会在图像中随机分布,导致图像中…...

[笔记]L6599的极限工作条件考量

0.名词 OTP over tempature protect.OCP over current protectOVP over voltage protectBrownout Protection Undervoltage Protection可能需要考虑hysteresis response.因为要考虑一些高频干扰 1.基本的过流保护逻辑 参考:ST L6599 器件手册 LLC开关电源&#…...

机器学习基础04

目录 1.朴素贝叶斯-分类 1.1贝叶斯分类理论 1.2条件概率 1.3全概率公式 1.4贝叶斯推断 1.5朴素贝叶斯推断 1.6拉普拉斯平滑系数 1.7API 2.决策树-分类 2.1决策树 2.2基于信息增益的决策树建立 2.2.1信息熵 2.2.2信息增益 2.2.3信息增益决策树建立步骤 2.3基于基…...

Ubuntu 20.04 配置开发环境(持续更新)

搜狗输入法不能显示中文 sudo apt install libqt5qml5 libgsettings-qt1 sudo apt install libqt5qml5 libqt5quick5 libqt5quickwidgets5 qml-module-qtquick2 编译环境配置 sudo apt-get update #base tools of ubuntu sudo apt install net-tools gitk tree vim termina…...

Rocky9/Ubuntu使用pip安装python的库mysqlclient失败解决方式

# Rocky9 直接使用pip安装mysqlclient会出现缺少依赖,需要先安装mysql-devel相关依赖。由于rocky9用MariaDB替代了MySQL,所以我们可以通过安装mariadb-devel来安装所需要的依赖。 如果Rocky9已经开启了powertool repo可以直接使用下面命令安装 dnf in…...

探索 HTML 和 CSS 实现的 3D旋转相册

效果演示 这段HTML与CSS代码创建了一个包含10张卡片的3D旋转效果&#xff0c;每张卡片都有自己的边框颜色和图片。通过CSS的3D变换和动画&#xff0c;实现了一个动态的旋转展示效果 HTML <div class"wrapper"><div class"inner" style"-…...

OpenJudge_ 简单英文题_04:0/1 Knapsack

题目 描述 Given the weights and values of N items, put a subset of items into a knapsack of capacity C to get the maximum total value in the knapsack. The total weight of items in the knapsack does not exceed C. 输入 First line: two positive integers N (…...

深入探索离散 Hopfield 神经网络

一、离散 Hopfield 神经网络的起源与发展 离散 Hopfield 神经网络由约翰・霍普菲尔德在 1982 年提出&#xff0c;这一创新性的成果在当时引起了广泛关注&#xff0c;成为早期人工神经网络的重要代表之一。 在那个时期&#xff0c;人工神经网络的发展还处于相对初级的阶段。霍…...

[智能车摄像头是一种安装在汽车上用于辅助驾驶和提高安全性的重要设备]

智能车摄像头是一种安装在汽车上用于辅助驾驶和提高安全性的重要设备。它们通常包括几个不同类型&#xff0c;如前视摄像头、环视摄像头、行车记录仪等。这些摄像头的主要功能有&#xff1a; 前视摄像头&#xff08;Forward Camera&#xff09;&#xff1a;用于提供驾驶员前方…...

前端vue 列表中回显并下拉选择修改标签

1&#xff0c;vue数据列表中进行回显状态并可以在下拉框中选择修改&#xff0c;效果如下 2&#xff0c;vue 页面关键代码 <el-table-column label"审核" align"center" class-name"small-padding fixed-width" prop"status" >&…...

hbase未来的发展趋势

HBase 作为一个开源的分布式、可伸缩的 NoSQL 数据库,依托于 Hadoop 生态系统,以处理海量结构化数据为优势。随着大数据技术的发展,HBase 的发展趋势主要体现在以下几个方面: 1. 与云计算深度集成 随着企业向云迁移,HBase 也正在越来越多地部署在云环境中。云服务商(如 …...

Rust 语言学习笔记(二)

再继续快速学习一下 Rust 的以下几个知识点&#xff0c;就可以开始着手做点小工具了 基本数据类型复合数据类型基本的流程控制 Rust 设计为有效使用内存考虑的&#xff0c;它提供了非常细力度的数据类型&#xff0c;如整数分为有无符号&#xff0c;宽度从 8 位到 128 位&…...

【postman】怎么通过curl看请求报什么错

获取现成的curl方式&#xff1a; 1&#xff0c;拿别人给的curl 2&#xff0c;手机app界面通过charles抓包&#xff0c;点击接口复制curl 3&#xff0c;浏览器界面-开发者工具-选中接口复制curl 拿到curl之后打开postman&#xff0c;点击import&#xff0c;粘贴curl点击send&am…...

Python 编程入门指南(一)

1. Python 简介 Python是一种广泛使用的高级编程语言,因其简洁的语法和强大的功能而备受欢迎。Python由Guido van Rossum于20世纪90年代初设计,旨在提供易于阅读和编写的代码,适合从初学者到专业开发者的各个水平。它是一种解释型语言,这意味着在编写和执行代码之间不需要…...

macOS 设置固定IP

文章目录 以太网Wifi![请添加图片描述](https://i-blog.csdnimg.cn/direct/65546e966cae4b2fa93ec9f0f87009d8.png) 基于 macOS 15.1 以太网 Wifi...

redis实现消息队列的几种方式

一、了解 众所周知&#xff0c;redis是我们日常开发过程中使用最多的非关系型数据库&#xff0c;也是消息中间件。实际上除了常用的rabbitmq、rocketmq、kafka消息队列&#xff08;大家自己下去研究吧~模式都是通用的&#xff09;&#xff0c;我们也能使用redis实现消息队列。…...

debian 系统更新升级

系统升级能够有效避免漏洞导致的损害 不过需要做好提前和后续的测试&#xff0c;避免现有运行程序的错误。 debian安装参考&#xff1a;链接 一、清理过时和不再使用的源 1.清理源 vi /etc/apt/sources.list2.在下面的文件夹下清理不需要的 cd /etc/apt/sources.list.d二、…...

STM32学习笔记-----UART的概念

在 STM32 中&#xff0c;UART&#xff08;Universal Asynchronous Receiver/Transmitter&#xff09;是一种常用的串行通信接口&#xff0c;广泛应用于嵌入式系统中。STM32 提供了丰富的硬件资源来支持 UART 通信&#xff0c;可以通过标准库&#xff08;STM32 HAL 或者标准外设…...