当前位置: 首页 > news >正文

CNN代码实战

CNN的原理

从 DNN 到 CNN
(1)卷积层与汇聚
⚫ 深度神经网络 DNN 中,相邻层的所有神经元之间都有连接,这叫全连接;卷积神经网络 CNN 中,新增了卷积层(Convolution)与汇聚(Pooling)。
⚫ DNN 的全连接层对应 CNN 的卷积层,汇聚是与激活函数类似的附件;单个卷积层的结构是:卷积层-激活函数-(汇聚),其中汇聚可省略。
(2)CNN:专攻多维数据
在深度神经网络 DNN 课程的最后一章,使用 DNN 进行了手写数字的识别。但是,图像至少就有二维,向全连接层输入时,需要多维数据拉平为 1 维数据,这样一来,图像的形状就被忽视了,很多特征是隐藏在空间属性里的,而卷积层可以保持输入数据的维数不变,当输入数据是二维图像时,卷积层会以多维数据的形式接收输入数据,并同样以多维数据的形式输出至下一层

导包

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import matplotlib.pyplot as plt

制作数据集

# 制作数据集
# 数据集转换参数
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(0.1307, 0.3081)
])
# 下载训练集与测试集
train_Data = datasets.MNIST(
root = 'D:/Postgraduate/CNN', # 下载路径
train = True, # 是 train 集
download = True, # 如果该路径没有该数据集,就下载
transform = transform # 数据集转换参数
)
test_Data = datasets.MNIST(
root = 'D:/Postgraduate/CNN', # 下载路径
train = False, # 是 test 集
download = True, # 如果该路径没有该数据集,就下载
transform = transform # 数据集转换参数
)
# 批次加载器
train_loader = DataLoader(train_Data, shuffle=True, batch_size=256)
test_loader = DataLoader(test_Data, shuffle=False, batch_size=256)

训练网络

class CNN(nn.Module):def __init__(self):super(CNN,self).__init__()self.net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Tanh(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Tanh(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(16, 120, kernel_size=5), nn.Tanh(),nn.Flatten(),nn.Linear(120, 84), nn.Tanh(),nn.Linear(84, 10)
)def forward(self, x):y = self.net(x)return y
# 创建子类的实例,并搬到 GPU 上
model = CNN().to('cuda:0')
# 训练网络
# 损失函数的选择
loss_fn = nn.CrossEntropyLoss() # 自带 softmax 激活函数
# 优化算法的选择
learning_rate = 0.9 # 设置学习率
optimizer = torch.optim.SGD(model.parameters(),lr = learning_rate,
)
# 训练网络
epochs = 5
losses = [] # 记录损失函数变化的列表
for epoch in range(epochs):for (x, y) in train_loader: # 获取小批次的 x 与 yx, y = x.to('cuda:0'), y.to('cuda:0')Pred = model(x) # 一次前向传播(小批量)loss = loss_fn(Pred, y) # 计算损失函数losses.append(loss.item()) # 记录损失函数的变化optimizer.zero_grad() # 清理上一轮滞留的梯度loss.backward() # 一次反向传播optimizer.step() # 优化内部参数
Fig = plt.figure()
plt.plot(range(len(losses)), losses)
plt.show()

测试网络

# 测试网络
correct = 0
total = 0
with torch.no_grad(): # 该局部关闭梯度计算功能for (x, y) in test_loader: # 获取小批次的 x 与 yx, y = x.to('cuda:0'), y.to('cuda:0')Pred = model(x) # 一次前向传播(小批量)_, predicted = torch.max(Pred.data, dim=1)correct += torch.sum( (predicted == y) )total += y.size(0)
print(f'测试集精准度: {100*correct/total} %')

使用网络

# 保存网络
torch.save(model, 'CNN.path')
new_model = torch.load('CNN.path')

完整代码

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import matplotlib.pyplot as plt# 制作数据集
# 数据集转换参数
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(0.1307, 0.3081)
])
# 下载训练集与测试集
train_Data = datasets.MNIST(
root = 'D:/Postgraduate/python_project/CNN', # 下载路径
train = True, # 是 train 集
download = True, # 如果该路径没有该数据集,就下载
transform = transform # 数据集转换参数
)
test_Data = datasets.MNIST(
root = 'D:/Postgraduate/python_project/CNN', # 下载路径
train = False, # 是 test 集
download = True, # 如果该路径没有该数据集,就下载
transform = transform # 数据集转换参数
)
# 批次加载器
train_loader = DataLoader(train_Data, shuffle=True, batch_size=256)
test_loader = DataLoader(test_Data, shuffle=False, batch_size=256)class CNN(nn.Module):def __init__(self):super(CNN,self).__init__()self.net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Tanh(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Tanh(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(16, 120, kernel_size=5), nn.Tanh(),nn.Flatten(),nn.Linear(120, 84), nn.Tanh(),nn.Linear(84, 10)
)def forward(self, x):y = self.net(x)return y
# 创建子类的实例,并搬到 GPU 上
model = CNN().to('cuda:0')
# 训练网络
# 损失函数的选择
loss_fn = nn.CrossEntropyLoss() # 自带 softmax 激活函数
# 优化算法的选择
learning_rate = 0.9 # 设置学习率
optimizer = torch.optim.SGD(model.parameters(),lr = learning_rate,
)
# 训练网络
epochs = 5
losses = [] # 记录损失函数变化的列表
for epoch in range(epochs):for (x, y) in train_loader: # 获取小批次的 x 与 yx, y = x.to('cuda:0'), y.to('cuda:0')Pred = model(x) # 一次前向传播(小批量)loss = loss_fn(Pred, y) # 计算损失函数losses.append(loss.item()) # 记录损失函数的变化optimizer.zero_grad() # 清理上一轮滞留的梯度loss.backward() # 一次反向传播optimizer.step() # 优化内部参数
Fig = plt.figure()
plt.plot(range(len(losses)), losses)
plt.show()# 测试网络
correct = 0
total = 0
with torch.no_grad(): # 该局部关闭梯度计算功能for (x, y) in test_loader: # 获取小批次的 x 与 yx, y = x.to('cuda:0'), y.to('cuda:0')Pred = model(x) # 一次前向传播(小批量)_, predicted = torch.max(Pred.data, dim=1)correct += torch.sum( (predicted == y) )total += y.size(0)
print(f'测试集精准度: {100*correct/total} %')# 保存网络
torch.save(model, 'CNN.path')
new_model = torch.load('CNN.path')

运行截图

相关文章:

CNN代码实战

CNN的原理 从 DNN 到 CNN (1)卷积层与汇聚 ⚫ 深度神经网络 DNN 中,相邻层的所有神经元之间都有连接,这叫全连接;卷积神经网络 CNN 中,新增了卷积层(Convolution)与汇聚&#xff08…...

迁移学习代码复现

一、前言 说来可能令人难以置信,迁移学习技术在实践中是非常简单的,我们仅需要保留训练好的神经网络整体或者部分网络,再在使用迁移学习的情况下把保留的模型重新加载到内存中,就完成了迁移的过程。之后,我们就可以像训练普通神经网络那样训练迁移过来的神经网络了。 我们…...

Elasticsearch(ES)常用命令

常用运维命令 一、基本命令1.1、查看集群的健康状态1.2、查看节点信息1.3、查看索引列表1.4、创建索引1.5、删除索引1.6、关闭索引1.7、打开索引1.8、查看集群资源使用情况(各个节点的状态,包括磁盘,heap,ram的使用情况&#xff0…...

C/C++ 不定参函数

C语言不定参函数 函数用法总结 Va_list 作用:类型定义,生命一个变量,该变量被用来访问传递给不定参函数的可变参数列表用法:供后续函数进调用,通过该变量访问参数列表 typedefchar* va_list; va_start 作用&#xff…...

C语言——函数专题

1.概念 在C语言中引入函数的概念,有些翻译为子程序。C语言中的函数就是一个完成某项特定任务的一小段代码,这个代码是有特殊的写法和调用方法的。一般我们可以分为两种函数:库函数和自定义函数。 2.库函数 C语言国际标准ANSIC规定了一些常…...

springboot打可执行jar包

1. pom文件如下 <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"><m…...

【SQL】科目种类

目录 题目 分析 代码 题目 表: Teacher ------------------- | Column Name | Type | ------------------- | teacher_id | int | | subject_id | int | | dept_id | int | ------------------- 在 SQL 中&#xff0c;(subject_id, dept_id) 是该表的主键。 该表…...

【深度学习】【语音】TTS,最新TTS模型概览,扩散模型TTS,MeloTTS、StyleTTS2、Matcha-TTS

文章目录 基础介绍对比基础介绍 MeloTTS: MeloTTS 是 MyShell.ai 开发的一个多语言语音合成模型,支持包括英语、西班牙语、法语、中文、日语和韩语等多种语言。它以高质量的语音合成为特色,尤其擅长处理中英混合内容。该模型优化了在 CPU 上的实时推理能力,使其在多种应用场…...

【论文笔记】LION: Linear Group RNN for 3D Object Detection in Point Clouds

原文链接&#xff1a;https://arxiv.org/abs/2407.18232 简介&#xff1a;Transformer在3D点云感知任务中有二次复杂度&#xff0c;难以进行长距离关系建模。线性RNN则计算复杂度较低&#xff0c;适合进行长距离关系建模。本文提出基于窗口的网络线性组RNN&#xff08;即对分组…...

打造高可用集群的基石:深度解析Keepalived实践与优化

高可用集群 集群类型 集群类型主要分为负载均衡集群&#xff08;LB&#xff09;、高可用集群&#xff08;HA&#xff09;和高性能计算集群&#xff08;HPC&#xff09;三大类。每种集群类型都有其特定的应用场景和优势。 1. 负载均衡集群&#xff08;LB&#xff09; 负载均衡集…...

Web大学生网页作业成品——环保主题介绍网页网站设计与实现(HTML+CSS)(5个页面)

&#x1f389;&#x1f389;&#x1f389; 常见网页设计作业题材有**汽车、环保、明星、文化、国家、抗疫、景点、人物、体育、植物、公益、图书、节日、游戏、商城、旅游、家乡、学校、电影、动漫、非遗、动物、个人、企业、美食、婚纱、其他**等网页设计题目, 可满足大学生网…...

Qt登录窗口设计

widget.h #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include <QIcon> //图标类 #include <QPushButton> #include <QLineEdit> //行编辑 #include <QLabel> #include <QTextEdit> #include <QMovie>class Widge…...

探索数据矿藏:我的AI大模型与数据挖掘实战经验分享

&#x1f680; 探索数据矿藏&#xff1a;我的AI大模型与数据挖掘实战经验分享 &#x1f496; 前言&#xff1a;数据的金矿&#xff0c;AI的翅膀 在人工智能的浪潮中&#xff0c;我有幸作为项目负责人&#xff0c;带领团队深入挖掘数据的潜力&#xff0c;利用AI大模型的力量&am…...

linux C语言strcat函数及相关函数

讯飞星火 strcat函数是C语言中的一个字符串处理函数,用于将一个字符串追加到另一个字符串的末尾。它的原型如下: char *strcat(char *dest, const char *src);参数说明: dest:目标字符串,即要将src字符串追加到这个字符串的末尾。src:源字符串,即要追加到dest字符串末…...

使用 sort 进行文本文件处理

sort 使用经验笔记 一、简介 sort 是一个命令行工具&#xff0c;用于对文本文件中的行进行排序。它广泛应用于 Unix/Linux 系统中&#xff0c;可以用来对文件的内容进行简单的排序操作&#xff0c;也可以与其他命令结合使用来完成更复杂的任务。 二、基本用法 排序文件: sor…...

HarmonyOS笔记4:从云数据库获取数据

移动应用获取数据的方式主要有&#xff1a; 1.从网络中获取数据接口API。 2.从华为云数据库获取云数据库的资源。 3.从移动终端直接获取本地的数据 在HarmonyOS笔记3中已经完成了方式一从网络中获取数据接口API的方式。在本篇笔记中&#xff0c;将讨论从云数据库中获取数据。 因…...

QT5生成独立运行的exe文件

目录 1 生成独立运行的exe文件1.1 设置工程Release版本可执行文件存储路径1.2 将工程编译成Release版本 2 使用QT5自带的windeployqt拷贝软件运行依赖项3 将程序打包成一个独立的可执行软件exe4 解决QT5 This application failed to start because no Qt platform plugin could…...

LabVIEW光纤水听器闭环系统

开发了一种利用LabVIEW软件开发的干涉型光纤水听器闭环工作点控制系统。该系统通过调节光源频率和非平衡干涉仪的光程差&#xff0c;实现了工作点的精确控制&#xff0c;从而提高系统的稳定性和检测精度&#xff0c;避免了使用压电陶瓷&#xff0c;使操作更加简便。 项目背景 …...

Shell——流程控制语句(if、case、for、while等)

在 Shell 编程中&#xff0c;流程控制语句用于控制脚本的执行顺序和逻辑。这些语句包括 if、case、for、while 等&#xff0c;它们的使用可以使脚本实现更复杂的逻辑。以下是它们的详细说明和语法结构&#xff1a; 1. if 语句 if 语句用于条件判断&#xff0c;执行符合条件的…...

【redis的大key问题】

在使用 Redis 的过程中&#xff0c;如果未能及时发现并处理 Big keys&#xff08;下文称为“大Key”&#xff09;&#xff0c;可能会导致服务性能下降、用户体验变差&#xff0c;甚至引发大面积故障。 本文将介绍大Key产生的原因、其可能引发的问题及如何快速找出大Key并将其优…...

多模态2025:技术路线“神仙打架”,视频生成冲上云霄

文&#xff5c;魏琳华 编&#xff5c;王一粟 一场大会&#xff0c;聚集了中国多模态大模型的“半壁江山”。 智源大会2025为期两天的论坛中&#xff0c;汇集了学界、创业公司和大厂等三方的热门选手&#xff0c;关于多模态的集中讨论达到了前所未有的热度。其中&#xff0c;…...

k8s从入门到放弃之Ingress七层负载

k8s从入门到放弃之Ingress七层负载 在Kubernetes&#xff08;简称K8s&#xff09;中&#xff0c;Ingress是一个API对象&#xff0c;它允许你定义如何从集群外部访问集群内部的服务。Ingress可以提供负载均衡、SSL终结和基于名称的虚拟主机等功能。通过Ingress&#xff0c;你可…...

通过Wrangler CLI在worker中创建数据库和表

官方使用文档&#xff1a;Getting started Cloudflare D1 docs 创建数据库 在命令行中执行完成之后&#xff0c;会在本地和远程创建数据库&#xff1a; npx wranglerlatest d1 create prod-d1-tutorial 在cf中就可以看到数据库&#xff1a; 现在&#xff0c;您的Cloudfla…...

QMC5883L的驱动

简介 本篇文章的代码已经上传到了github上面&#xff0c;开源代码 作为一个电子罗盘模块&#xff0c;我们可以通过I2C从中获取偏航角yaw&#xff0c;相对于六轴陀螺仪的yaw&#xff0c;qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...

为什么需要建设工程项目管理?工程项目管理有哪些亮点功能?

在建筑行业&#xff0c;项目管理的重要性不言而喻。随着工程规模的扩大、技术复杂度的提升&#xff0c;传统的管理模式已经难以满足现代工程的需求。过去&#xff0c;许多企业依赖手工记录、口头沟通和分散的信息管理&#xff0c;导致效率低下、成本失控、风险频发。例如&#…...

新能源汽车智慧充电桩管理方案:新能源充电桩散热问题及消防安全监管方案

随着新能源汽车的快速普及&#xff0c;充电桩作为核心配套设施&#xff0c;其安全性与可靠性备受关注。然而&#xff0c;在高温、高负荷运行环境下&#xff0c;充电桩的散热问题与消防安全隐患日益凸显&#xff0c;成为制约行业发展的关键瓶颈。 如何通过智慧化管理手段优化散…...

Python爬虫(一):爬虫伪装

一、网站防爬机制概述 在当今互联网环境中&#xff0c;具有一定规模或盈利性质的网站几乎都实施了各种防爬措施。这些措施主要分为两大类&#xff1a; 身份验证机制&#xff1a;直接将未经授权的爬虫阻挡在外反爬技术体系&#xff1a;通过各种技术手段增加爬虫获取数据的难度…...

让AI看见世界:MCP协议与服务器的工作原理

让AI看见世界&#xff1a;MCP协议与服务器的工作原理 MCP&#xff08;Model Context Protocol&#xff09;是一种创新的通信协议&#xff0c;旨在让大型语言模型能够安全、高效地与外部资源进行交互。在AI技术快速发展的今天&#xff0c;MCP正成为连接AI与现实世界的重要桥梁。…...

AGain DB和倍数增益的关系

我在设置一款索尼CMOS芯片时&#xff0c;Again增益0db变化为6DB&#xff0c;画面的变化只有2倍DN的增益&#xff0c;比如10变为20。 这与dB和线性增益的关系以及传感器处理流程有关。以下是具体原因分析&#xff1a; 1. dB与线性增益的换算关系 6dB对应的理论线性增益应为&…...

IP如何挑?2025年海外专线IP如何购买?

你花了时间和预算买了IP&#xff0c;结果IP质量不佳&#xff0c;项目效率低下不说&#xff0c;还可能带来莫名的网络问题&#xff0c;是不是太闹心了&#xff1f;尤其是在面对海外专线IP时&#xff0c;到底怎么才能买到适合自己的呢&#xff1f;所以&#xff0c;挑IP绝对是个技…...