当前位置: 首页 > news >正文

卷积神经网络(CNN)简单原理与简单代码实现

卷积神经网络(CNN)简单原理与简单代码实现

  • 卷积神经网络(CNN)简单原理
    • 基本原理
      • 卷积层(Convolutional Layer):
      • 激活层(Activation Layer):
      • 池化层(Pooling Layer):
      • 全连接层(Fully Connected Layer):
        • 主要特点
      • 简单代码实现

卷积神经网络(CNN)简单原理

卷积神经网络(Convolutional Neural Network, CNN)是深度学习领域中一种重要的网络结构,特别适用于处理具有网格结构的数据,如图像。其基本原理和主要特点如下:

基本原理

卷积层(Convolutional Layer):

卷积层是CNN的核心,它通过卷积操作从输入数据中提取特征。
卷积操作是使用一个或多个可学习的滤波器(或称卷积核)在输入数据上滑动,计算滤波器与输入数据对应区域的点积,生成特征图(Feature Map)。
每个滤波器都可以提取输入数据的一种特征,因此,使用多个滤波器可以提取多种特征。

激活层(Activation Layer):

激活层通常紧跟在卷积层之后,用于增加模型的非线性。
常用的激活函数包括ReLU(Rectified Linear Unit)、Sigmoid、Tanh等。其中,ReLU函数因其简单有效而被广泛使用。

池化层(Pooling Layer):

池化层用于降低特征图的维度,减少计算量,并增强模型的鲁棒性。
常见的池化操作有最大池化(Max Pooling)和平均池化(Average Pooling)。最大池化取池化区域内的最大值作为输出,而平均池化则取平均值。

全连接层(Fully Connected Layer):

全连接层通常位于CNN的末端,用于将前面提取到的特征映射到样本的类别上。
在全连接层中,每个神经元都与前一层的所有神经元相连。

主要特点

局部连接:卷积层中的每个神经元仅与输入数据的一个局部区域相连,这有助于捕捉图像的局部特征。
参数共享:同一个卷积核在输入数据的不同位置共享相同的参数,这大大减少了模型的参数数量。
平移不变性:由于池化层的存在,CNN对输入数据的平移变换具有一定的不变性。

简单代码实现

以下是一个使用PyTorch框架实现的简单CNN模型,用于手写数字识别(MNIST数据集):

import torch  
import torch.nn as nn  
import torch.nn.functional as F  
import torchvision  
import torchvision.transforms as transforms  # 数据预处理  
transform = transforms.Compose([  transforms.ToTensor(),  transforms.Normalize((0.5,), (0.5,))  
])  trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)  
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)  testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)  
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)  # 定义CNN模型  
class ConvNet(nn.Module):  def __init__(self):  super(ConvNet, self).__init__()  self.conv1 = nn.Conv2d(1, 10, 5)  # 输入通道数为1,输出通道数为10,卷积核大小为5x5  self.pool = nn.MaxPool2d(2, 2)    # 池化窗口大小为2x2,步长为2  self.conv2 = nn.Conv2d(10, 20, 5) # 输入通道数为10,输出通道数为20,卷积核大小为5x5  self.fc = nn.Linear(320, 10)      # 全连接层,输入特征维度为320,输出类别数为10  def forward(self, x):  x = self.pool(F.relu(self.conv1(x)))  x = self.pool(F.relu(self.conv2(x)))  x = x.view(-1, 320)  # 展平操作  x = self.fc(x)  return x  # 初始化模型、损失函数和优化器  
model = ConvNet()  
criterion = nn.CrossEntropyLoss()  
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)  # 训练模型  
for epoch in range(10):  # 假设训练10个epoch  running_loss = 0.0  for i, data in enumerate(trainloader, 0):  inputs, labels = data  optimizer.zero_grad()  outputs = model(inputs)  loss = criterion(outputs, labels)  loss.backward()  optimizer.step()  running_loss += loss.item()  print(f'Epoch {epoch+1}, Loss: {running_loss/len(trainloader)}')  # 测试模型(代码略,通常包括关闭梯度计算、遍历测试集、计算准确率等步骤)

这段代码首先定义了数据预处理步骤,然后定义了一个简单的CNN模型,该模型包含两个卷积层、两个池化层和一个全连接层。接着,初始化了模型、损失函数和优化器,并展示了训练模型的基本流程。需要注意的是,测试模型的代码部分在这里被省略了,但通常包括关闭梯度计算、遍历测试集、计算模型输出与真实标签之间的损失或准确率等步骤。

在TensorFlow中实现一个简单的卷积神经网络(CNN)通常涉及以下几个步骤:定义模型结构、编译模型、训练模型以及评估模型。以下是一个使用TensorFlow 2(及其高级API Keras)实现的简单CNN示例,该示例用于手写数字识别(MNIST数据集)。

import tensorflow as tf  
from tensorflow.keras import datasets, layers, models  
import numpy as np  # 加载并预处理数据  
(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()  # 将图像从整数转换为浮点数,并归一化到0到1的范围内  
train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255  
test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255  # 定义模型结构  
model = models.Sequential()  
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))  
model.add(layers.MaxPooling2D((2, 2)))  
model.add(layers.Conv2D(64, (3, 3), activation='relu'))  
model.add(layers.MaxPooling2D((2, 2)))  
model.add(layers.Conv2D(64, (3, 3), activation='relu'))  
model.add(layers.Flatten())  
model.add(layers.Dense(64, activation='relu'))  
model.add(layers.Dense(10))  # 编译模型  
model.compile(optimizer='adam',  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),  metrics=['accuracy'])  # 训练模型  
history = model.fit(train_images, train_labels, epochs=10,   validation_data=(test_images, test_labels))  # 评估模型  
test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)  
print('\nTest accuracy:', test_acc)

相关文章:

卷积神经网络(CNN)简单原理与简单代码实现

卷积神经网络(CNN)简单原理与简单代码实现 卷积神经网络(CNN)简单原理基本原理卷积层(Convolutional Layer):激活层(Activation Layer):池化层(Po…...

实时数仓分层架构详解

首先,我们从数据仓库说起。 数据仓库的概念可以追溯到20世纪80年代,当时IBM的研究人员提出了商业数据仓库的概念。数据仓库概念的提出,是为了解决和数据流相关的各种问题,特别是多重数据复制带来的高成本问题。 数据仓库之父Bill …...

计算机“八股文”在实际工作中是助力、阻力还是空谈?

“八股文”在实际工作中是助力、阻力还是空谈? 作为现在各类大中小企业面试程序员时的必问内容,“八股文”似乎是很重要的存在。但“八股文”是否能在实际工作中发挥它“敲门砖”应有的作用呢?有IT人士不禁发出疑问:程序员面试考…...

新160个crackme - 022-CM_2

运行分析 需破解Name和Serial,输入的小写字母都会变为大写字母 PE分析 C程序,32位,无壳 静态分析&动态调试 发现关键字符串 ida动态调试,发现Name和Serial长度需要大于5,且Serial前6位明文爆出,6287-A …...

在.c和.h 文件里定义数组的区别

在C语言开发中,掌握如何在.c文件和.h文件中合理定义数组,对于维护代码的模块化和避免不必要的编译错误至关重要。本文将探讨在这两种类型的文件中定义数组时需要注意的几个关键方面,包括定义性质、作用域、重复定义问题以及外部可见性等&…...

使用Step Functions运行AWS Backup时必备的权限要点

引言 在尝试从Step Functions执行AWS Backup的按需备份时,我在权限方面遇到了一些困难。为了备忘,我将这些经验写成这篇文章。 概述 从Step Functions执行AWS Backup时,需要分配以下权限: AWS Backup相关权限 执行备份的权限…...

强化JS基础水平的10个单行代码来喽!(必看)

目录 生成数组 数组简单数据去重 多数组取交集 重新加载当前页面 滚动到页面顶部 查找最大值索引 进制转换 文本粘贴 删除无效属性 随机颜色生成 生成数组 当你需要要生成一个0-99的数组 // 生成一个0-99的数组 // 方案一 const createArr n > Array.from(new A…...

大模型学习笔记 - 大纲

LLM 大纲 LLM 大纲 1. LLM 模型架构 LLM 技术细节 - 注意力机制LLM 技术细节 - 位置编码 2. LLM 预训练3. LLM 指令微调 LLM 高效微调技术 4. LLM 人类对齐 LLM InstructGPTLLM PPO算法LLM DPO 算法 5. LLM 解码与部署6. LLM 模型LLaMA 系列7. LLM RAG 1. LLM 模型架构 大模…...

苹果电脑可以玩什么小游戏 适合Mac电脑玩的休闲游戏推荐

对于游戏爱好者而言,Mac似乎并不是游戏体验的首选平台。这主要是因为相较于Windows系统,Mac上的游戏资源显得相对有限。不过,这并不意味着Mac用户就与游戏世界绝缘。实际上,Mac平台上有着一系列小巧精致且趣味横生的小游戏&#x…...

浅谈KMP算法(c++)

目录 前缀函数应用【模板】KMP题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1 提示样例 1 解释数据规模与约定 思路AC代码 本质不同子串数 例题讲解[NOI2014] 动物园题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1 提示思路AC代码 [POI2006] OKR-Periods of …...

关于C++编程注意点(竞赛)

赛前准备 多复习 重中之重, 多刷题 确保手感 参加几场模拟赛,熟悉流程 熟悉 Linux 系统,否则你将会手忙脚乱 放松心情,调整心态,分数 实力 心态 赛中注意 输入输出方面 在数据范围超过 时尽量使用 scanf pr…...

Markdown文本编辑器:Typora for Mac/win 中文版

Markdown 是一种轻量级的标记语言,它允许用户使用易读易写的纯文本格式编写文档。Typora 支持且仅支持 Markdown 语法的文本编辑,生成的文档后缀名为 .md。 这款软件的特点包括: 实时预览:Typora 的一个显著特点是实时预览&#x…...

Mysql-窗口函数一

文章目录 1. 窗口函数概述1.1 介绍1.2 作用 2. 场景说明2.1 准备工作2.2 场景说明2.3 分析2.4 实现2.4.1 非窗口函数方式实现2.4.2 窗口函数方式实现 3. 窗口函数分类4. 窗口函数基础用法:OVER关键字4.1 语法4.2 场景一 :计算每个值和整体平均值的差值4.2.1 需求4.2…...

Python3 爬虫 数据抓包

一、数据抓包 所谓抓包(Package Capture),简单来说,就是在网络数据传输的过程中对数据包进行截获、查看、修改或转发的过程。如果把网络上发送与接收的数据包理解为快递包裹,那么在快递运输的过程中查看里面的内容&…...

js强制刷新

在JavaScript中触发强制刷新通常指的是强制浏览器重新加载页面,忽略缓存。以下是几种实现强制刷新的方法: ### 使用 location.reload() 这是最简单的方法,它会重新加载当前页面。 javascript location.reload(true); // 传入true参数表示强制…...

yolov5 part2

two-stage (两阶段):Faster-rcnn Mask-Rcnn系列 one-stage (单阶段):YOLO系列 最核心的优势:速度非常快,适合实时监测任务。但是缺点也有,效果可能不好 速度较慢在2018…...

Hive3:表操作常用语句-内部表、外部表

一、内部表 1、基本介绍 (CREATE TABLE table_name ......) 未被external关键字修饰的即是内部表, 即普通表。 内部表又称管理表,内部表数据存储的位置由hive.metastore.warehouse.dir参数决定(默认:/user/hive/ware…...

【PXE+kickstart】linux网络服务之自动装机

PXE: 简介:PXE(Preboot execute environment 是一种能够让计算机通过网络启动的引导方式,只要网卡支持PXE协议即可使用Kickstart 是一种无人值守的安装方式,工作原理就是预先把原本需要运维人员手工填写的参数保存成一个 ks.cfg 文…...

vmware ubuntu虚拟机网络联网配置

介绍vmware虚拟机配置基础网络环境,同时连接外网(通过桥接模式),以及ubuntu下输入法等基础工具安装。 本文基于ubuntu22.04,前提虚拟机已经完成安装。本文更多是针对vmware虚拟机的设置,之前有一篇针对ubun…...

Vue3_对接声网实时音视频_多人视频会议

目录 一、声网 1.注册账号 2.新建项目 二、实时音视频集成 1.声网CDN集成 2.iframe嵌入html 3.自定义UI集成 4.提高进入房间速度 web项目需要实现一个多人会议,对接的声网的灵动课堂。在这里说一下对接流程。 一、声网 声网成立于2014年,是全球…...

[特殊字符] 智能合约中的数据是如何在区块链中保持一致的?

🧠 智能合约中的数据是如何在区块链中保持一致的? 为什么所有区块链节点都能得出相同结果?合约调用这么复杂,状态真能保持一致吗?本篇带你从底层视角理解“状态一致性”的真相。 一、智能合约的数据存储在哪里&#xf…...

rknn优化教程(二)

文章目录 1. 前述2. 三方库的封装2.1 xrepo中的库2.2 xrepo之外的库2.2.1 opencv2.2.2 rknnrt2.2.3 spdlog 3. rknn_engine库 1. 前述 OK,开始写第二篇的内容了。这篇博客主要能写一下: 如何给一些三方库按照xmake方式进行封装,供调用如何按…...

Redis相关知识总结(缓存雪崩,缓存穿透,缓存击穿,Redis实现分布式锁,如何保持数据库和缓存一致)

文章目录 1.什么是Redis?2.为什么要使用redis作为mysql的缓存?3.什么是缓存雪崩、缓存穿透、缓存击穿?3.1缓存雪崩3.1.1 大量缓存同时过期3.1.2 Redis宕机 3.2 缓存击穿3.3 缓存穿透3.4 总结 4. 数据库和缓存如何保持一致性5. Redis实现分布式…...

STM32+rt-thread判断是否联网

一、根据NETDEV_FLAG_INTERNET_UP位判断 static bool is_conncected(void) {struct netdev *dev RT_NULL;dev netdev_get_first_by_flags(NETDEV_FLAG_INTERNET_UP);if (dev RT_NULL){printf("wait netdev internet up...");return false;}else{printf("loc…...

AtCoder 第409​场初级竞赛 A~E题解

A Conflict 【题目链接】 原题链接:A - Conflict 【考点】 枚举 【题目大意】 找到是否有两人都想要的物品。 【解析】 遍历两端字符串,只有在同时为 o 时输出 Yes 并结束程序,否则输出 No。 【难度】 GESP三级 【代码参考】 #i…...

条件运算符

C中的三目运算符(也称条件运算符,英文:ternary operator)是一种简洁的条件选择语句,语法如下: 条件表达式 ? 表达式1 : 表达式2• 如果“条件表达式”为true,则整个表达式的结果为“表达式1”…...

将对透视变换后的图像使用Otsu进行阈值化,来分离黑色和白色像素。这句话中的Otsu是什么意思?

Otsu 是一种自动阈值化方法,用于将图像分割为前景和背景。它通过最小化图像的类内方差或等价地最大化类间方差来选择最佳阈值。这种方法特别适用于图像的二值化处理,能够自动确定一个阈值,将图像中的像素分为黑色和白色两类。 Otsu 方法的原…...

3403. 从盒子中找出字典序最大的字符串 I

3403. 从盒子中找出字典序最大的字符串 I 题目链接:3403. 从盒子中找出字典序最大的字符串 I 代码如下: class Solution { public:string answerString(string word, int numFriends) {if (numFriends 1) {return word;}string res;for (int i 0;i &…...

高效线程安全的单例模式:Python 中的懒加载与自定义初始化参数

高效线程安全的单例模式:Python 中的懒加载与自定义初始化参数 在软件开发中,单例模式(Singleton Pattern)是一种常见的设计模式,确保一个类仅有一个实例,并提供一个全局访问点。在多线程环境下,实现单例模式时需要注意线程安全问题,以防止多个线程同时创建实例,导致…...

Spring是如何解决Bean的循环依赖:三级缓存机制

1、什么是 Bean 的循环依赖 在 Spring框架中,Bean 的循环依赖是指多个 Bean 之间‌互相持有对方引用‌,形成闭环依赖关系的现象。 多个 Bean 的依赖关系构成环形链路,例如: 双向依赖:Bean A 依赖 Bean B,同时 Bean B 也依赖 Bean A(A↔B)。链条循环: Bean A → Bean…...