当前位置: 首页 > news >正文

深度学习-卷积神经网络CNN

案例-图像分类

网络结构: 卷积+BN+激活+池化

数据集介绍

CIFAR-10数据集5万张训练图像、1万张测试图像、10个类别、每个类别有6k个图像,图像大小32×32×3。下图列举了10个类,每一类随机展示了10张图片:

特征图计算

在卷积层和池化层结束后, 将特征图变形成一行n列数据, 计算特征图进行变化, 映射到全连接层时输入层特征为最后一层卷积层经池化后的特征图各维度相乘

具体流程-# Acc: 0.728

# 导包
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchsummary import summary
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor, Compose  # Compose: 数据增强(扩充数据集)
import time
import matplotlib.pyplot as plt
​
batch_size = 16
​
​
# 创建数据集
def create_dataset():torch.manual_seed(21)train = CIFAR10(root='data',train=True,transform=Compose([ToTensor()]))test = CIFAR10(root='data',train=False,transform=Compose([ToTensor()]))return train, test
​
​
# 创建模型
class ImgCls(nn.Module):# 定义网络结构def __init__(self):super(ImgCls, self).__init__()# 定义网络层:卷积层+池化层self.conv1 = nn.Conv2d(3, 16, stride=1, kernel_size=3)self.batch_norm_layer1 = nn.BatchNorm2d(num_features=16, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
​self.conv2 = nn.Conv2d(16, 32, stride=1, kernel_size=3)self.batch_norm_layer2 = nn.BatchNorm2d(num_features=32, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)self.pool2 = nn.MaxPool2d(kernel_size=2, stride=1)
​self.conv3 = nn.Conv2d(32, 64, stride=1, kernel_size=3)self.batch_norm_layer3 = nn.BatchNorm2d(num_features=64, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)self.pool3 = nn.MaxPool2d(kernel_size=2, stride=1)
​self.conv4 = nn.Conv2d(64, 128, stride=1, kernel_size=2)self.batch_norm_layer4 = nn.BatchNorm2d(num_features=128, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
​self.conv5 = nn.Conv2d(128, 256, stride=1, kernel_size=2)self.batch_norm_layer5 = nn.BatchNorm2d(num_features=256, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)self.pool5 = nn.MaxPool2d(kernel_size=2, stride=1)
​# 全连接层self.linear1 = nn.Linear(1024, 2048)self.linear2 = nn.Linear(2048, 1024)self.linear3 = nn.Linear(1024, 512)self.linear4 = nn.Linear(512, 256)self.linear5 = nn.Linear(256, 128)self.out = nn.Linear(128, 10)
​# 定义前向传播def forward(self, x):# 第1层: 卷积+BN+激活+池化x = self.conv1(x)x = self.batch_norm_layer1(x)x = torch.rrelu(x)x = self.pool1(x)
​# 第2层: 卷积+BN+激活+池化x = self.conv2(x)x = self.batch_norm_layer2(x)x = torch.rrelu(x)x = self.pool2(x)
​# 第3层: 卷积+BN+激活+池化x = self.conv3(x)x = self.batch_norm_layer3(x)x = torch.rrelu(x)x = self.pool3(x)
​# 第4层: 卷积+BN+激活+池化x = self.conv4(x)x = self.batch_norm_layer4(x)x = torch.rrelu(x)x = self.pool4(x)
​# 第5层: 卷积+BN+激活+池化x = self.conv5(x)x = self.batch_norm_layer5(x)x = torch.rrelu(x)x = self.pool5(x)
​# 将特征图做成以为向量的形式:相当于特征向量x = x.reshape(x.size(0), -1)  # 将3维特征图转化为1维向量(1, n)
​# 全连接层x = torch.rrelu(self.linear1(x))x = torch.rrelu(self.linear2(x))x = torch.rrelu(self.linear3(x))x = torch.rrelu(self.linear4(x))x = torch.rrelu(self.linear5(x))# 返回输出结果return self.out(x)
​
​
# 训练
def train(model, train_dataset, epochs):torch.manual_seed(21)loss = nn.CrossEntropyLoss()opt = optim.Adam(model.parameters(), lr=1e-4)for epoch in range(epochs):dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)loss_total = 0iter = 0stat_time = time.time()for x, y in dataloader:output = model(x.to(device))loss_value = loss(output, y.to(device))opt.zero_grad()loss_value.backward()opt.step()loss_total += loss_value.item()iter += 1print(f'epoch:{epoch + 1:4d}, loss:{loss_total / iter:6.4f}, time:{time.time() - stat_time:.2f}s')torch.save(model.state_dict(), 'model/img_cls_model.pth')
​
​
# 测试
def test(valid_dataset, model, batch_size):# 构建数据加载器dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
​# 计算精度total_correct = 0# 遍历每个batch的数据,获取预测结果,计算精度for x, y in dataloader:output = model(x.to(device))y_pred = torch.argmax(output, dim=-1)total_correct += (y_pred == y.to(device)).sum()# 打印精度print(f'Acc: {(total_correct.item() / len(valid_dataset))}')
​
​
if __name__ == '__main__':batch_size = 16device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 获取数据集train_data, test_data = create_dataset()
​# # 查看数据集# print(f'数据集类别: {train_data.class_to_idx}')# print(f'训练集: {train_data.data.shape}')# print(f'验证集: {test_data.data.shape}')# print(f'类别数量: {len(np.unique(train_data.targets))}')# # 展示图像# plt.figure(figsize=(8, 8))# plt.imshow(train_data.data[0])# plt.title(train_data.classes[train_data.targets[0]])# plt.show()
​# 实例化模型model = ImgCls().to(device)
​# 查看网络结构summary(model, (3, 32, 32), device='cuda', batch_size=batch_size)
​# 模型训练train(model, train_data, epochs=60)# 加载训练好的模型参数model.load_state_dict(torch.load('model/img_cls_model.pth'))model.eval()# 模型评估test(test_data, model, batch_size=16)   # Acc: 0.728
​

调整网络结构

第一次调整: 训练50轮, Acc: 0.71

第二次调整: 训练30轮, Acc:0.7351

第三次调整: batch_size=8, epoch=50 => Acc: 0.7644

# 导包
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchsummary import summary
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor, Compose  # Compose: 数据增强(扩充数据集)
import time
import matplotlib.pyplot as plt
​
batch_size = 16
​
​
# 创建数据集
def create_dataset():torch.manual_seed(21)train = CIFAR10(root='data',train=True,transform=Compose([ToTensor()]))test = CIFAR10(root='data',train=False,transform=Compose([ToTensor()]))return train, test
​
​
# 创建模型
class ImgCls(nn.Module):# 定义网络结构def __init__(self):super(ImgCls, self).__init__()# 定义网络层:卷积层+池化层self.conv1 = nn.Conv2d(3, 16, stride=1, kernel_size=3, padding=1)self.batch_norm_layer1 = nn.BatchNorm2d(num_features=16, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
​self.conv2 = nn.Conv2d(16, 32, stride=1, kernel_size=3, padding=1)self.batch_norm_layer2 = nn.BatchNorm2d(num_features=32, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
​self.conv3 = nn.Conv2d(32, 64, stride=1, kernel_size=3, padding=1)self.batch_norm_layer3 = nn.BatchNorm2d(num_features=64, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)self.pool3 = nn.MaxPool2d(kernel_size=2, stride=1)
​self.conv4 = nn.Conv2d(64, 128, stride=1, kernel_size=3, padding=1)self.batch_norm_layer4 = nn.BatchNorm2d(num_features=128, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)self.pool4 = nn.MaxPool2d(kernel_size=2, stride=1)
​self.conv5 = nn.Conv2d(128, 256, stride=1, kernel_size=3)self.batch_norm_layer5 = nn.BatchNorm2d(num_features=256, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
​# 全连接层self.linear1 = nn.Linear(1024, 2048)self.linear2 = nn.Linear(2048, 1024)self.linear3 = nn.Linear(1024, 512)self.linear4 = nn.Linear(512, 256)self.linear5 = nn.Linear(256, 128)self.out = nn.Linear(128, 10)
​# 定义前向传播def forward(self, x):# 第1层: 卷积+BN+激活+池化x = self.conv1(x)x = self.batch_norm_layer1(x)x = torch.relu(x)x = self.pool1(x)
​# 第2层: 卷积+BN+激活+池化x = self.conv2(x)x = self.batch_norm_layer2(x)x = torch.relu(x)x = self.pool2(x)
​# 第3层: 卷积+BN+激活+池化x = self.conv3(x)x = self.batch_norm_layer3(x)x = torch.relu(x)x = self.pool3(x)
​# 第4层: 卷积+BN+激活+池化x = self.conv4(x)x = self.batch_norm_layer4(x)x = torch.relu(x)x = self.pool4(x)
​# 第5层: 卷积+BN+激活+池化x = self.conv5(x)x = self.batch_norm_layer5(x)x = torch.rrelu(x)x = self.pool5(x)
​# 将特征图做成以为向量的形式:相当于特征向量x = x.reshape(x.size(0), -1)  # 将3维特征图转化为1维向量(1, n)
​# 全连接层x = torch.relu(self.linear1(x))x = torch.relu(self.linear2(x))x = torch.relu(self.linear3(x))x = torch.relu(self.linear4(x))x = torch.rrelu(self.linear5(x))# 返回输出结果return self.out(x)
​
​
# 训练
def train(model, train_dataset, epochs):torch.manual_seed(21)loss = nn.CrossEntropyLoss()opt = optim.Adam(model.parameters(), lr=1e-4)for epoch in range(epochs):dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)loss_total = 0iter = 0stat_time = time.time()for x, y in dataloader:output = model(x.to(device))loss_value = loss(output, y.to(device))opt.zero_grad()loss_value.backward()opt.step()loss_total += loss_value.item()iter += 1print(f'epoch:{epoch + 1:4d}, loss:{loss_total / iter:6.4f}, time:{time.time() - stat_time:.2f}s')torch.save(model.state_dict(), 'model/img_cls_model1.pth')
​
​
# 测试
def test(valid_dataset, model, batch_size):# 构建数据加载器dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
​# 计算精度total_correct = 0# 遍历每个batch的数据,获取预测结果,计算精度for x, y in dataloader:output = model(x.to(device))y_pred = torch.argmax(output, dim=-1)total_correct += (y_pred == y.to(device)).sum()# 打印精度print(f'Acc: {(total_correct.item() / len(valid_dataset))}')
​
​
if __name__ == '__main__':batch_size = 8device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 获取数据集train_data, test_data = create_dataset()
​# # 查看数据集# print(f'数据集类别: {train_data.class_to_idx}')# print(f'训练集: {train_data.data.shape}')# print(f'验证集: {test_data.data.shape}')# print(f'类别数量: {len(np.unique(train_data.targets))}')# # 展示图像# plt.figure(figsize=(8, 8))# plt.imshow(train_data.data[0])# plt.title(train_data.classes[train_data.targets[0]])# plt.show()
​# 实例化模型model = ImgCls().to(device)
​# 查看网络结构summary(model, (3, 32, 32), device='cuda', batch_size=batch_size)
​# 模型训练train(model, train_data, epochs=50)# 加载训练好的模型参数model.load_state_dict(torch.load('model/img_cls_model1.pth', weights_only=True))model.eval()# 模型评估test(test_data, model, batch_size=16)   # Acc: 0.7644
​

相关文章:

深度学习-卷积神经网络CNN

案例-图像分类 网络结构: 卷积BN激活池化 数据集介绍 CIFAR-10数据集5万张训练图像、1万张测试图像、10个类别、每个类别有6k个图像,图像大小32323。下图列举了10个类,每一类随机展示了10张图片: 特征图计算 在卷积层和池化层结束后, 将特征…...

241114.学习日志——[CSDIY] [Cpp]零基础速成 [03]

​ CSDIY:这是一个非科班学生的努力之路,从今天开始这个系列会长期更新,(最好做到日更),我会慢慢把自己目前对CS的努力逐一上传,帮助那些和我一样有着梦想的玩家取得胜利!&#xff0…...

大模型研究报告 | 2024年中国金融大模型产业发展洞察报告|附34页PDF文件下载

随着生成算法、预训练模型、多模态数据分析等AI技术的聚集融合,AIGC技术的实践效用迎来了行业级大爆发。通用大模型技术的成熟推动了新一轮行业生产力变革,在投入提升与政策扶植的双重作用下,以大模型技术为底座、结合专业化金融能力的金融大…...

数据库SQL——什么是实体-联系模型(E-R模型)?

目录 什么是实体-联系模型? 1.实体集 2.联系集 3.映射基数 一对一(1:1) 一对多(1:n) 多对一(n:1) 多对多(m:n) 全部参与: 4.主码 弱实体集&#xf…...

在 MySQL 8.0 中,SSL 解密失败,在使用 SSL 加密连接时出现了问题

在 MySQL 8.0 中,SSL 解密失败通常指的是在使用 SSL 加密连接时出现了问题,导致无法建立加密通信。这个问题可能由多种原因引起,下面是一些常见的原因和排查方法: 1. 证书配置问题 确保您在 MySQL 配置中使用了正确的 SSL 证书和…...

React Native 全栈开发实战班 - 第四部分:用户界面进阶之动画效果实现

在移动应用中,动画效果 是提升用户体验的重要手段。合理的动画设计可以增强应用的交互性、流畅性和视觉吸引力。React Native 提供了多种实现动画的方式,包括内置的 Animated API、LayoutAnimation 以及第三方库(如 react-native-reanimated&…...

【CICD】GitLab Runner 和执行器(Executor

GitLab Runner 和执行器(Executor)是 GitLab CI/CD 管道中的两个重要组成部分。理解它们之间的关系有助于更好地配置和使用 CI/CD 流水线。runer是gitlab的ci-agent对接gitlab,而执行器是接受runer下发的ci的任务来干活的。也就是说gitrunner…...

实用教程:如何无损修改MP4视频时长

如何在UltraEdit中搜索MP4文件中的“mvhd”关键字 引言 在视频编辑和分析领域,有时我们需要深入到视频文件的底层结构中去。UltraEdit(UE)和UEStudio作为强大的文本编辑器,允许我们以十六进制模式打开和搜索MP4文件。本文将指导…...

mysqldump命令搭配source命令完成数据库迁移备份

mysqldump 命令使用 需保证mysqld在运行中, 这个命令的目的是将数据库导出到文件中,例如 mysqldump -uusername -ppassword database > db.sql 注意该命令不是在MySQL客户端(即MySQL命令行)执行的,而是在系统命…...

生信:TCGA学习(R、RStudio安装与下载、常用语法与常用快捷键)

前置环境 macOS系统,已安装homebrew且会相关命令。 近期在整理草稿区,所以放出该贴。 R语言、RStudio、R包安装 R语言安装 brew install rRStudio安装 官网地址:https://posit.co/download/rstudio-desktop/ R包下载 注意R语言环境自带…...

十三、注解配置SpringMVC

文章目录 1. 创建初始化类,代替web.xml2. 创建SpringConfig配置类,代替spring的配置文件3. 创建WebConfig配置类,代替SpringMVC的配置文件4. 测试功能 1. 创建初始化类,代替web.xml 2. 创建SpringConfig配置类,代替spr…...

为什么海外服务器IP会被封

海外服务器因为免备案而备受用户欢迎,近年来租用海外服务器的用户也越来越多,自然也可能会出现一些问题。 如果服务器IP被封,在该服务器下的所有业务都无法访问,对自己和对用户来说都会有较大的影响。因此,我们应做好相…...

图像处理技术椒盐噪声

椒盐噪声,也称为脉冲噪声,是图像中经常见到的一种噪声。它是一种随机出现的白点或者黑点,可能是亮的区域有黑色像素或是在暗的区域有白色像素(或是两者皆有)。这些白点和黑点会在图像中随机分布,导致图像中…...

[笔记]L6599的极限工作条件考量

0.名词 OTP over tempature protect.OCP over current protectOVP over voltage protectBrownout Protection Undervoltage Protection可能需要考虑hysteresis response.因为要考虑一些高频干扰 1.基本的过流保护逻辑 参考:ST L6599 器件手册 LLC开关电源&#…...

机器学习基础04

目录 1.朴素贝叶斯-分类 1.1贝叶斯分类理论 1.2条件概率 1.3全概率公式 1.4贝叶斯推断 1.5朴素贝叶斯推断 1.6拉普拉斯平滑系数 1.7API 2.决策树-分类 2.1决策树 2.2基于信息增益的决策树建立 2.2.1信息熵 2.2.2信息增益 2.2.3信息增益决策树建立步骤 2.3基于基…...

Ubuntu 20.04 配置开发环境(持续更新)

搜狗输入法不能显示中文 sudo apt install libqt5qml5 libgsettings-qt1 sudo apt install libqt5qml5 libqt5quick5 libqt5quickwidgets5 qml-module-qtquick2 编译环境配置 sudo apt-get update #base tools of ubuntu sudo apt install net-tools gitk tree vim termina…...

Rocky9/Ubuntu使用pip安装python的库mysqlclient失败解决方式

# Rocky9 直接使用pip安装mysqlclient会出现缺少依赖,需要先安装mysql-devel相关依赖。由于rocky9用MariaDB替代了MySQL,所以我们可以通过安装mariadb-devel来安装所需要的依赖。 如果Rocky9已经开启了powertool repo可以直接使用下面命令安装 dnf in…...

探索 HTML 和 CSS 实现的 3D旋转相册

效果演示 这段HTML与CSS代码创建了一个包含10张卡片的3D旋转效果&#xff0c;每张卡片都有自己的边框颜色和图片。通过CSS的3D变换和动画&#xff0c;实现了一个动态的旋转展示效果 HTML <div class"wrapper"><div class"inner" style"-…...

OpenJudge_ 简单英文题_04:0/1 Knapsack

题目 描述 Given the weights and values of N items, put a subset of items into a knapsack of capacity C to get the maximum total value in the knapsack. The total weight of items in the knapsack does not exceed C. 输入 First line: two positive integers N (…...

深入探索离散 Hopfield 神经网络

一、离散 Hopfield 神经网络的起源与发展 离散 Hopfield 神经网络由约翰・霍普菲尔德在 1982 年提出&#xff0c;这一创新性的成果在当时引起了广泛关注&#xff0c;成为早期人工神经网络的重要代表之一。 在那个时期&#xff0c;人工神经网络的发展还处于相对初级的阶段。霍…...

2026年婚礼背景音乐素材下载网站TOP5:从版权、曲库到实用场景全面评测

引言&#xff1a;为什么婚礼背景音乐素材越来越需要“可商用、可溯源、可快速下载” 2026年&#xff0c;婚礼内容已经不再只是一支婚礼纪录片&#xff0c;而是拆分成婚礼预告片、接亲快剪、仪式短片、First Look、婚礼跟拍花絮、短视频平台竖版成片、婚庆公司案例展示等多个内…...

告别手动操作:用Python自动化COMSOL仿真的3个关键突破

告别手动操作&#xff1a;用Python自动化COMSOL仿真的3个关键突破 【免费下载链接】MPh Pythonic scripting interface for Comsol Multiphysics 项目地址: https://gitcode.com/gh_mirrors/mp/MPh 你是否也曾为COMSOL的重复性仿真任务感到疲惫&#xff1f;每天花费数小…...

告别警告与强制刷新:Unity聊天对话框自适应布局的纯净实现方案

1. 为什么需要纯净的自适应聊天对话框&#xff1f; 在Unity中实现一个聊天对话框看似简单&#xff0c;但要让它在各种情况下都能完美自适应却是个技术活。很多开发者都遇到过这样的困扰&#xff1a;明明按照教程加了Content Size Fitter和LayoutGroup&#xff0c;UI却总是出现奇…...

【Java杂项】为什么 b += 1 可以,但 b = b + 1 会报错?类型提升与复合赋值详解

【Java杂项】为什么 b 1 可以&#xff0c;但 b b 1 会报错&#xff1f;复合赋值与类型提升讲清楚前言一、先给结论&#xff1a;它不是简单的文本替换二、先看认知冲突2.1 普通赋值为什么报错2.2 复合赋值为什么能通过三、类型提升到底是什么3.1 常见类型提升结果3.2 为什么小…...

A-59F所有应用模式说明

A-59F 是一款高集成语音处理模组&#xff0c;一体化实现 AI ENC 降噪、AEC 回音消除、扩音防啸叫、BF 波束拾音 四大核心能力。支持模拟 / 数字麦克风、模拟 / I2S 数字音频接口&#xff0c;邮票孔 SMT 封装&#xff0c;体积小巧、易嵌入&#xff0c;可大幅简化音频电路&#x…...

【独家逆向分析】:Perplexity招聘页埋点数据如何被提取?附Python自动化脚本(限24小时领取)

更多请点击&#xff1a; https://kaifayun.com 第一章&#xff1a;Perplexity薪资数据查询 Perplexity 作为一家以 AI 原生搜索和研究工具著称的科技公司&#xff0c;其薪酬结构长期未公开披露&#xff0c;但可通过多源交叉验证方式获取合理估算。目前主流可信渠道包括 Levels…...

从PyCharm到ArcGIS工具箱:把你的Python地理处理脚本‘打包’成专业工具的保姆级指南

从PyCharm到ArcGIS工具箱&#xff1a;Python地理处理脚本的专业化封装实战 当你在PyCharm中完成了一个完美运行的地理处理脚本&#xff0c;接下来最自然的想法就是让它能被更多非技术同事直接使用。本文将带你跨越开发环境与生产环境的鸿沟&#xff0c;将一个孤立的Python脚本转…...

灰度发布与流量切换

Skeyevss FAQ&#xff1a;灰度发布与流量切换 试用安装包下载 | SMS | 在线演示 项目地址&#xff1a;https://github.com/openskeye/go-vss 1. 目标 新版本 先小流量验证&#xff0c;指标正常再全量&#xff1b;出问题 快速回滚。对 SIP 类系统&#xff0c;还要考虑 会话粘…...

别再死记硬背了!用这 5 个核心功能理解 Final Cut Pro 的设计哲学

Final Cut Pro 的设计哲学&#xff1a;5个核心功能如何重塑你的剪辑思维 当你第一次打开Final Cut Pro&#xff08;简称FCPX&#xff09;&#xff0c;可能会被它与其他剪辑软件截然不同的界面所困惑。这不是一个需要你适应传统时间线的工具&#xff0c;而是一个重新思考剪辑流程…...

为什么你的Perplexity查不到正确代码?——基于127个失败Query的日志审计报告(附修复清单)

更多请点击&#xff1a; https://codechina.net 第一章&#xff1a;为什么你的Perplexity查不到正确代码&#xff1f;——基于127个失败Query的日志审计报告&#xff08;附修复清单&#xff09; 我们对127条在Perplexity平台中返回空结果、过时答案或完全偏离编程意图的用户Qu…...