CNN代码实战
CNN的原理
从 DNN 到 CNN
(1)卷积层与汇聚
⚫ 深度神经网络 DNN 中,相邻层的所有神经元之间都有连接,这叫全连接;卷积神经网络 CNN 中,新增了卷积层(Convolution)与汇聚(Pooling)。
⚫ DNN 的全连接层对应 CNN 的卷积层,汇聚是与激活函数类似的附件;单个卷积层的结构是:卷积层-激活函数-(汇聚),其中汇聚可省略。
(2)CNN:专攻多维数据
在深度神经网络 DNN 课程的最后一章,使用 DNN 进行了手写数字的识别。但是,图像至少就有二维,向全连接层输入时,需要多维数据拉平为 1 维数据,这样一来,图像的形状就被忽视了,很多特征是隐藏在空间属性里的,而卷积层可以保持输入数据的维数不变,当输入数据是二维图像时,卷积层会以多维数据的形式接收输入数据,并同样以多维数据的形式输出至下一层

导包
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import matplotlib.pyplot as plt
制作数据集
# 制作数据集
# 数据集转换参数
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(0.1307, 0.3081)
])
# 下载训练集与测试集
train_Data = datasets.MNIST(
root = 'D:/Postgraduate/CNN', # 下载路径
train = True, # 是 train 集
download = True, # 如果该路径没有该数据集,就下载
transform = transform # 数据集转换参数
)
test_Data = datasets.MNIST(
root = 'D:/Postgraduate/CNN', # 下载路径
train = False, # 是 test 集
download = True, # 如果该路径没有该数据集,就下载
transform = transform # 数据集转换参数
)
# 批次加载器
train_loader = DataLoader(train_Data, shuffle=True, batch_size=256)
test_loader = DataLoader(test_Data, shuffle=False, batch_size=256)
训练网络
class CNN(nn.Module):def __init__(self):super(CNN,self).__init__()self.net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Tanh(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Tanh(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(16, 120, kernel_size=5), nn.Tanh(),nn.Flatten(),nn.Linear(120, 84), nn.Tanh(),nn.Linear(84, 10)
)def forward(self, x):y = self.net(x)return y
# 创建子类的实例,并搬到 GPU 上
model = CNN().to('cuda:0')
# 训练网络
# 损失函数的选择
loss_fn = nn.CrossEntropyLoss() # 自带 softmax 激活函数
# 优化算法的选择
learning_rate = 0.9 # 设置学习率
optimizer = torch.optim.SGD(model.parameters(),lr = learning_rate,
)
# 训练网络
epochs = 5
losses = [] # 记录损失函数变化的列表
for epoch in range(epochs):for (x, y) in train_loader: # 获取小批次的 x 与 yx, y = x.to('cuda:0'), y.to('cuda:0')Pred = model(x) # 一次前向传播(小批量)loss = loss_fn(Pred, y) # 计算损失函数losses.append(loss.item()) # 记录损失函数的变化optimizer.zero_grad() # 清理上一轮滞留的梯度loss.backward() # 一次反向传播optimizer.step() # 优化内部参数
Fig = plt.figure()
plt.plot(range(len(losses)), losses)
plt.show()
测试网络
# 测试网络
correct = 0
total = 0
with torch.no_grad(): # 该局部关闭梯度计算功能for (x, y) in test_loader: # 获取小批次的 x 与 yx, y = x.to('cuda:0'), y.to('cuda:0')Pred = model(x) # 一次前向传播(小批量)_, predicted = torch.max(Pred.data, dim=1)correct += torch.sum( (predicted == y) )total += y.size(0)
print(f'测试集精准度: {100*correct/total} %')
使用网络
# 保存网络
torch.save(model, 'CNN.path')
new_model = torch.load('CNN.path')
完整代码
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import matplotlib.pyplot as plt# 制作数据集
# 数据集转换参数
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(0.1307, 0.3081)
])
# 下载训练集与测试集
train_Data = datasets.MNIST(
root = 'D:/Postgraduate/python_project/CNN', # 下载路径
train = True, # 是 train 集
download = True, # 如果该路径没有该数据集,就下载
transform = transform # 数据集转换参数
)
test_Data = datasets.MNIST(
root = 'D:/Postgraduate/python_project/CNN', # 下载路径
train = False, # 是 test 集
download = True, # 如果该路径没有该数据集,就下载
transform = transform # 数据集转换参数
)
# 批次加载器
train_loader = DataLoader(train_Data, shuffle=True, batch_size=256)
test_loader = DataLoader(test_Data, shuffle=False, batch_size=256)class CNN(nn.Module):def __init__(self):super(CNN,self).__init__()self.net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Tanh(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Tanh(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(16, 120, kernel_size=5), nn.Tanh(),nn.Flatten(),nn.Linear(120, 84), nn.Tanh(),nn.Linear(84, 10)
)def forward(self, x):y = self.net(x)return y
# 创建子类的实例,并搬到 GPU 上
model = CNN().to('cuda:0')
# 训练网络
# 损失函数的选择
loss_fn = nn.CrossEntropyLoss() # 自带 softmax 激活函数
# 优化算法的选择
learning_rate = 0.9 # 设置学习率
optimizer = torch.optim.SGD(model.parameters(),lr = learning_rate,
)
# 训练网络
epochs = 5
losses = [] # 记录损失函数变化的列表
for epoch in range(epochs):for (x, y) in train_loader: # 获取小批次的 x 与 yx, y = x.to('cuda:0'), y.to('cuda:0')Pred = model(x) # 一次前向传播(小批量)loss = loss_fn(Pred, y) # 计算损失函数losses.append(loss.item()) # 记录损失函数的变化optimizer.zero_grad() # 清理上一轮滞留的梯度loss.backward() # 一次反向传播optimizer.step() # 优化内部参数
Fig = plt.figure()
plt.plot(range(len(losses)), losses)
plt.show()# 测试网络
correct = 0
total = 0
with torch.no_grad(): # 该局部关闭梯度计算功能for (x, y) in test_loader: # 获取小批次的 x 与 yx, y = x.to('cuda:0'), y.to('cuda:0')Pred = model(x) # 一次前向传播(小批量)_, predicted = torch.max(Pred.data, dim=1)correct += torch.sum( (predicted == y) )total += y.size(0)
print(f'测试集精准度: {100*correct/total} %')# 保存网络
torch.save(model, 'CNN.path')
new_model = torch.load('CNN.path')
运行截图


相关文章:
CNN代码实战
CNN的原理 从 DNN 到 CNN (1)卷积层与汇聚 ⚫ 深度神经网络 DNN 中,相邻层的所有神经元之间都有连接,这叫全连接;卷积神经网络 CNN 中,新增了卷积层(Convolution)与汇聚(…...
迁移学习代码复现
一、前言 说来可能令人难以置信,迁移学习技术在实践中是非常简单的,我们仅需要保留训练好的神经网络整体或者部分网络,再在使用迁移学习的情况下把保留的模型重新加载到内存中,就完成了迁移的过程。之后,我们就可以像训练普通神经网络那样训练迁移过来的神经网络了。 我们…...
Elasticsearch(ES)常用命令
常用运维命令 一、基本命令1.1、查看集群的健康状态1.2、查看节点信息1.3、查看索引列表1.4、创建索引1.5、删除索引1.6、关闭索引1.7、打开索引1.8、查看集群资源使用情况(各个节点的状态,包括磁盘,heap,ram的使用情况࿰…...
C/C++ 不定参函数
C语言不定参函数 函数用法总结 Va_list 作用:类型定义,生命一个变量,该变量被用来访问传递给不定参函数的可变参数列表用法:供后续函数进调用,通过该变量访问参数列表 typedefchar* va_list; va_start 作用ÿ…...
C语言——函数专题
1.概念 在C语言中引入函数的概念,有些翻译为子程序。C语言中的函数就是一个完成某项特定任务的一小段代码,这个代码是有特殊的写法和调用方法的。一般我们可以分为两种函数:库函数和自定义函数。 2.库函数 C语言国际标准ANSIC规定了一些常…...
springboot打可执行jar包
1. pom文件如下 <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"><m…...
【SQL】科目种类
目录 题目 分析 代码 题目 表: Teacher ------------------- | Column Name | Type | ------------------- | teacher_id | int | | subject_id | int | | dept_id | int | ------------------- 在 SQL 中,(subject_id, dept_id) 是该表的主键。 该表…...
【深度学习】【语音】TTS,最新TTS模型概览,扩散模型TTS,MeloTTS、StyleTTS2、Matcha-TTS
文章目录 基础介绍对比基础介绍 MeloTTS: MeloTTS 是 MyShell.ai 开发的一个多语言语音合成模型,支持包括英语、西班牙语、法语、中文、日语和韩语等多种语言。它以高质量的语音合成为特色,尤其擅长处理中英混合内容。该模型优化了在 CPU 上的实时推理能力,使其在多种应用场…...
【论文笔记】LION: Linear Group RNN for 3D Object Detection in Point Clouds
原文链接:https://arxiv.org/abs/2407.18232 简介:Transformer在3D点云感知任务中有二次复杂度,难以进行长距离关系建模。线性RNN则计算复杂度较低,适合进行长距离关系建模。本文提出基于窗口的网络线性组RNN(即对分组…...
打造高可用集群的基石:深度解析Keepalived实践与优化
高可用集群 集群类型 集群类型主要分为负载均衡集群(LB)、高可用集群(HA)和高性能计算集群(HPC)三大类。每种集群类型都有其特定的应用场景和优势。 1. 负载均衡集群(LB) 负载均衡集…...
Web大学生网页作业成品——环保主题介绍网页网站设计与实现(HTML+CSS)(5个页面)
🎉🎉🎉 常见网页设计作业题材有**汽车、环保、明星、文化、国家、抗疫、景点、人物、体育、植物、公益、图书、节日、游戏、商城、旅游、家乡、学校、电影、动漫、非遗、动物、个人、企业、美食、婚纱、其他**等网页设计题目, 可满足大学生网…...
Qt登录窗口设计
widget.h #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include <QIcon> //图标类 #include <QPushButton> #include <QLineEdit> //行编辑 #include <QLabel> #include <QTextEdit> #include <QMovie>class Widge…...
探索数据矿藏:我的AI大模型与数据挖掘实战经验分享
🚀 探索数据矿藏:我的AI大模型与数据挖掘实战经验分享 💖 前言:数据的金矿,AI的翅膀 在人工智能的浪潮中,我有幸作为项目负责人,带领团队深入挖掘数据的潜力,利用AI大模型的力量&am…...
linux C语言strcat函数及相关函数
讯飞星火 strcat函数是C语言中的一个字符串处理函数,用于将一个字符串追加到另一个字符串的末尾。它的原型如下: char *strcat(char *dest, const char *src);参数说明: dest:目标字符串,即要将src字符串追加到这个字符串的末尾。src:源字符串,即要追加到dest字符串末…...
使用 sort 进行文本文件处理
sort 使用经验笔记 一、简介 sort 是一个命令行工具,用于对文本文件中的行进行排序。它广泛应用于 Unix/Linux 系统中,可以用来对文件的内容进行简单的排序操作,也可以与其他命令结合使用来完成更复杂的任务。 二、基本用法 排序文件: sor…...
HarmonyOS笔记4:从云数据库获取数据
移动应用获取数据的方式主要有: 1.从网络中获取数据接口API。 2.从华为云数据库获取云数据库的资源。 3.从移动终端直接获取本地的数据 在HarmonyOS笔记3中已经完成了方式一从网络中获取数据接口API的方式。在本篇笔记中,将讨论从云数据库中获取数据。 因…...
QT5生成独立运行的exe文件
目录 1 生成独立运行的exe文件1.1 设置工程Release版本可执行文件存储路径1.2 将工程编译成Release版本 2 使用QT5自带的windeployqt拷贝软件运行依赖项3 将程序打包成一个独立的可执行软件exe4 解决QT5 This application failed to start because no Qt platform plugin could…...
LabVIEW光纤水听器闭环系统
开发了一种利用LabVIEW软件开发的干涉型光纤水听器闭环工作点控制系统。该系统通过调节光源频率和非平衡干涉仪的光程差,实现了工作点的精确控制,从而提高系统的稳定性和检测精度,避免了使用压电陶瓷,使操作更加简便。 项目背景 …...
Shell——流程控制语句(if、case、for、while等)
在 Shell 编程中,流程控制语句用于控制脚本的执行顺序和逻辑。这些语句包括 if、case、for、while 等,它们的使用可以使脚本实现更复杂的逻辑。以下是它们的详细说明和语法结构: 1. if 语句 if 语句用于条件判断,执行符合条件的…...
【redis的大key问题】
在使用 Redis 的过程中,如果未能及时发现并处理 Big keys(下文称为“大Key”),可能会导致服务性能下降、用户体验变差,甚至引发大面积故障。 本文将介绍大Key产生的原因、其可能引发的问题及如何快速找出大Key并将其优…...
高频面试之3Zookeeper
高频面试之3Zookeeper 文章目录 高频面试之3Zookeeper3.1 常用命令3.2 选举机制3.3 Zookeeper符合法则中哪两个?3.4 Zookeeper脑裂3.5 Zookeeper用来干嘛了 3.1 常用命令 ls、get、create、delete、deleteall3.2 选举机制 半数机制(过半机制࿰…...
在四层代理中还原真实客户端ngx_stream_realip_module
一、模块原理与价值 PROXY Protocol 回溯 第三方负载均衡(如 HAProxy、AWS NLB、阿里 SLB)发起上游连接时,将真实客户端 IP/Port 写入 PROXY Protocol v1/v2 头。Stream 层接收到头部后,ngx_stream_realip_module 从中提取原始信息…...
第25节 Node.js 断言测试
Node.js的assert模块主要用于编写程序的单元测试时使用,通过断言可以提早发现和排查出错误。 稳定性: 5 - 锁定 这个模块可用于应用的单元测试,通过 require(assert) 可以使用这个模块。 assert.fail(actual, expected, message, operator) 使用参数…...
spring:实例工厂方法获取bean
spring处理使用静态工厂方法获取bean实例,也可以通过实例工厂方法获取bean实例。 实例工厂方法步骤如下: 定义实例工厂类(Java代码),定义实例工厂(xml),定义调用实例工厂ÿ…...
在鸿蒙HarmonyOS 5中使用DevEco Studio实现录音机应用
1. 项目配置与权限设置 1.1 配置module.json5 {"module": {"requestPermissions": [{"name": "ohos.permission.MICROPHONE","reason": "录音需要麦克风权限"},{"name": "ohos.permission.WRITE…...
什么是Ansible Jinja2
理解 Ansible Jinja2 模板 Ansible 是一款功能强大的开源自动化工具,可让您无缝地管理和配置系统。Ansible 的一大亮点是它使用 Jinja2 模板,允许您根据变量数据动态生成文件、配置设置和脚本。本文将向您介绍 Ansible 中的 Jinja2 模板,并通…...
【Linux】Linux 系统默认的目录及作用说明
博主介绍:✌全网粉丝23W,CSDN博客专家、Java领域优质创作者,掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域✌ 技术范围:SpringBoot、SpringCloud、Vue、SSM、HTML、Nodejs、Python、MySQL、PostgreSQL、大数据、物…...
【 java 虚拟机知识 第一篇 】
目录 1.内存模型 1.1.JVM内存模型的介绍 1.2.堆和栈的区别 1.3.栈的存储细节 1.4.堆的部分 1.5.程序计数器的作用 1.6.方法区的内容 1.7.字符串池 1.8.引用类型 1.9.内存泄漏与内存溢出 1.10.会出现内存溢出的结构 1.内存模型 1.1.JVM内存模型的介绍 内存模型主要分…...
人工智能--安全大模型训练计划:基于Fine-tuning + LLM Agent
安全大模型训练计划:基于Fine-tuning LLM Agent 1. 构建高质量安全数据集 目标:为安全大模型创建高质量、去偏、符合伦理的训练数据集,涵盖安全相关任务(如有害内容检测、隐私保护、道德推理等)。 1.1 数据收集 描…...
从“安全密码”到测试体系:Gitee Test 赋能关键领域软件质量保障
关键领域软件测试的"安全密码":Gitee Test如何破解行业痛点 在数字化浪潮席卷全球的今天,软件系统已成为国家关键领域的"神经中枢"。从国防军工到能源电力,从金融交易到交通管控,这些关乎国计民生的关键领域…...
