当前位置: 首页 > news >正文

卷积神经网络(CNN)简单原理与简单代码实现

卷积神经网络(CNN)简单原理与简单代码实现

  • 卷积神经网络(CNN)简单原理
    • 基本原理
      • 卷积层(Convolutional Layer):
      • 激活层(Activation Layer):
      • 池化层(Pooling Layer):
      • 全连接层(Fully Connected Layer):
        • 主要特点
      • 简单代码实现

卷积神经网络(CNN)简单原理

卷积神经网络(Convolutional Neural Network, CNN)是深度学习领域中一种重要的网络结构,特别适用于处理具有网格结构的数据,如图像。其基本原理和主要特点如下:

基本原理

卷积层(Convolutional Layer):

卷积层是CNN的核心,它通过卷积操作从输入数据中提取特征。
卷积操作是使用一个或多个可学习的滤波器(或称卷积核)在输入数据上滑动,计算滤波器与输入数据对应区域的点积,生成特征图(Feature Map)。
每个滤波器都可以提取输入数据的一种特征,因此,使用多个滤波器可以提取多种特征。

激活层(Activation Layer):

激活层通常紧跟在卷积层之后,用于增加模型的非线性。
常用的激活函数包括ReLU(Rectified Linear Unit)、Sigmoid、Tanh等。其中,ReLU函数因其简单有效而被广泛使用。

池化层(Pooling Layer):

池化层用于降低特征图的维度,减少计算量,并增强模型的鲁棒性。
常见的池化操作有最大池化(Max Pooling)和平均池化(Average Pooling)。最大池化取池化区域内的最大值作为输出,而平均池化则取平均值。

全连接层(Fully Connected Layer):

全连接层通常位于CNN的末端,用于将前面提取到的特征映射到样本的类别上。
在全连接层中,每个神经元都与前一层的所有神经元相连。

主要特点

局部连接:卷积层中的每个神经元仅与输入数据的一个局部区域相连,这有助于捕捉图像的局部特征。
参数共享:同一个卷积核在输入数据的不同位置共享相同的参数,这大大减少了模型的参数数量。
平移不变性:由于池化层的存在,CNN对输入数据的平移变换具有一定的不变性。

简单代码实现

以下是一个使用PyTorch框架实现的简单CNN模型,用于手写数字识别(MNIST数据集):

import torch  
import torch.nn as nn  
import torch.nn.functional as F  
import torchvision  
import torchvision.transforms as transforms  # 数据预处理  
transform = transforms.Compose([  transforms.ToTensor(),  transforms.Normalize((0.5,), (0.5,))  
])  trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)  
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)  testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)  
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)  # 定义CNN模型  
class ConvNet(nn.Module):  def __init__(self):  super(ConvNet, self).__init__()  self.conv1 = nn.Conv2d(1, 10, 5)  # 输入通道数为1,输出通道数为10,卷积核大小为5x5  self.pool = nn.MaxPool2d(2, 2)    # 池化窗口大小为2x2,步长为2  self.conv2 = nn.Conv2d(10, 20, 5) # 输入通道数为10,输出通道数为20,卷积核大小为5x5  self.fc = nn.Linear(320, 10)      # 全连接层,输入特征维度为320,输出类别数为10  def forward(self, x):  x = self.pool(F.relu(self.conv1(x)))  x = self.pool(F.relu(self.conv2(x)))  x = x.view(-1, 320)  # 展平操作  x = self.fc(x)  return x  # 初始化模型、损失函数和优化器  
model = ConvNet()  
criterion = nn.CrossEntropyLoss()  
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)  # 训练模型  
for epoch in range(10):  # 假设训练10个epoch  running_loss = 0.0  for i, data in enumerate(trainloader, 0):  inputs, labels = data  optimizer.zero_grad()  outputs = model(inputs)  loss = criterion(outputs, labels)  loss.backward()  optimizer.step()  running_loss += loss.item()  print(f'Epoch {epoch+1}, Loss: {running_loss/len(trainloader)}')  # 测试模型(代码略,通常包括关闭梯度计算、遍历测试集、计算准确率等步骤)

这段代码首先定义了数据预处理步骤,然后定义了一个简单的CNN模型,该模型包含两个卷积层、两个池化层和一个全连接层。接着,初始化了模型、损失函数和优化器,并展示了训练模型的基本流程。需要注意的是,测试模型的代码部分在这里被省略了,但通常包括关闭梯度计算、遍历测试集、计算模型输出与真实标签之间的损失或准确率等步骤。

在TensorFlow中实现一个简单的卷积神经网络(CNN)通常涉及以下几个步骤:定义模型结构、编译模型、训练模型以及评估模型。以下是一个使用TensorFlow 2(及其高级API Keras)实现的简单CNN示例,该示例用于手写数字识别(MNIST数据集)。

import tensorflow as tf  
from tensorflow.keras import datasets, layers, models  
import numpy as np  # 加载并预处理数据  
(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()  # 将图像从整数转换为浮点数,并归一化到0到1的范围内  
train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255  
test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255  # 定义模型结构  
model = models.Sequential()  
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))  
model.add(layers.MaxPooling2D((2, 2)))  
model.add(layers.Conv2D(64, (3, 3), activation='relu'))  
model.add(layers.MaxPooling2D((2, 2)))  
model.add(layers.Conv2D(64, (3, 3), activation='relu'))  
model.add(layers.Flatten())  
model.add(layers.Dense(64, activation='relu'))  
model.add(layers.Dense(10))  # 编译模型  
model.compile(optimizer='adam',  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),  metrics=['accuracy'])  # 训练模型  
history = model.fit(train_images, train_labels, epochs=10,   validation_data=(test_images, test_labels))  # 评估模型  
test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)  
print('\nTest accuracy:', test_acc)

相关文章:

卷积神经网络(CNN)简单原理与简单代码实现

卷积神经网络(CNN)简单原理与简单代码实现 卷积神经网络(CNN)简单原理基本原理卷积层(Convolutional Layer):激活层(Activation Layer):池化层(Po…...

实时数仓分层架构详解

首先,我们从数据仓库说起。 数据仓库的概念可以追溯到20世纪80年代,当时IBM的研究人员提出了商业数据仓库的概念。数据仓库概念的提出,是为了解决和数据流相关的各种问题,特别是多重数据复制带来的高成本问题。 数据仓库之父Bill …...

计算机“八股文”在实际工作中是助力、阻力还是空谈?

“八股文”在实际工作中是助力、阻力还是空谈? 作为现在各类大中小企业面试程序员时的必问内容,“八股文”似乎是很重要的存在。但“八股文”是否能在实际工作中发挥它“敲门砖”应有的作用呢?有IT人士不禁发出疑问:程序员面试考…...

新160个crackme - 022-CM_2

运行分析 需破解Name和Serial,输入的小写字母都会变为大写字母 PE分析 C程序,32位,无壳 静态分析&动态调试 发现关键字符串 ida动态调试,发现Name和Serial长度需要大于5,且Serial前6位明文爆出,6287-A …...

在.c和.h 文件里定义数组的区别

在C语言开发中,掌握如何在.c文件和.h文件中合理定义数组,对于维护代码的模块化和避免不必要的编译错误至关重要。本文将探讨在这两种类型的文件中定义数组时需要注意的几个关键方面,包括定义性质、作用域、重复定义问题以及外部可见性等&…...

使用Step Functions运行AWS Backup时必备的权限要点

引言 在尝试从Step Functions执行AWS Backup的按需备份时,我在权限方面遇到了一些困难。为了备忘,我将这些经验写成这篇文章。 概述 从Step Functions执行AWS Backup时,需要分配以下权限: AWS Backup相关权限 执行备份的权限…...

强化JS基础水平的10个单行代码来喽!(必看)

目录 生成数组 数组简单数据去重 多数组取交集 重新加载当前页面 滚动到页面顶部 查找最大值索引 进制转换 文本粘贴 删除无效属性 随机颜色生成 生成数组 当你需要要生成一个0-99的数组 // 生成一个0-99的数组 // 方案一 const createArr n > Array.from(new A…...

大模型学习笔记 - 大纲

LLM 大纲 LLM 大纲 1. LLM 模型架构 LLM 技术细节 - 注意力机制LLM 技术细节 - 位置编码 2. LLM 预训练3. LLM 指令微调 LLM 高效微调技术 4. LLM 人类对齐 LLM InstructGPTLLM PPO算法LLM DPO 算法 5. LLM 解码与部署6. LLM 模型LLaMA 系列7. LLM RAG 1. LLM 模型架构 大模…...

苹果电脑可以玩什么小游戏 适合Mac电脑玩的休闲游戏推荐

对于游戏爱好者而言,Mac似乎并不是游戏体验的首选平台。这主要是因为相较于Windows系统,Mac上的游戏资源显得相对有限。不过,这并不意味着Mac用户就与游戏世界绝缘。实际上,Mac平台上有着一系列小巧精致且趣味横生的小游戏&#x…...

浅谈KMP算法(c++)

目录 前缀函数应用【模板】KMP题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1 提示样例 1 解释数据规模与约定 思路AC代码 本质不同子串数 例题讲解[NOI2014] 动物园题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1 提示思路AC代码 [POI2006] OKR-Periods of …...

关于C++编程注意点(竞赛)

赛前准备 多复习 重中之重, 多刷题 确保手感 参加几场模拟赛,熟悉流程 熟悉 Linux 系统,否则你将会手忙脚乱 放松心情,调整心态,分数 实力 心态 赛中注意 输入输出方面 在数据范围超过 时尽量使用 scanf pr…...

Markdown文本编辑器:Typora for Mac/win 中文版

Markdown 是一种轻量级的标记语言,它允许用户使用易读易写的纯文本格式编写文档。Typora 支持且仅支持 Markdown 语法的文本编辑,生成的文档后缀名为 .md。 这款软件的特点包括: 实时预览:Typora 的一个显著特点是实时预览&#x…...

Mysql-窗口函数一

文章目录 1. 窗口函数概述1.1 介绍1.2 作用 2. 场景说明2.1 准备工作2.2 场景说明2.3 分析2.4 实现2.4.1 非窗口函数方式实现2.4.2 窗口函数方式实现 3. 窗口函数分类4. 窗口函数基础用法:OVER关键字4.1 语法4.2 场景一 :计算每个值和整体平均值的差值4.2.1 需求4.2…...

Python3 爬虫 数据抓包

一、数据抓包 所谓抓包(Package Capture),简单来说,就是在网络数据传输的过程中对数据包进行截获、查看、修改或转发的过程。如果把网络上发送与接收的数据包理解为快递包裹,那么在快递运输的过程中查看里面的内容&…...

js强制刷新

在JavaScript中触发强制刷新通常指的是强制浏览器重新加载页面,忽略缓存。以下是几种实现强制刷新的方法: ### 使用 location.reload() 这是最简单的方法,它会重新加载当前页面。 javascript location.reload(true); // 传入true参数表示强制…...

yolov5 part2

two-stage (两阶段):Faster-rcnn Mask-Rcnn系列 one-stage (单阶段):YOLO系列 最核心的优势:速度非常快,适合实时监测任务。但是缺点也有,效果可能不好 速度较慢在2018…...

Hive3:表操作常用语句-内部表、外部表

一、内部表 1、基本介绍 (CREATE TABLE table_name ......) 未被external关键字修饰的即是内部表, 即普通表。 内部表又称管理表,内部表数据存储的位置由hive.metastore.warehouse.dir参数决定(默认:/user/hive/ware…...

【PXE+kickstart】linux网络服务之自动装机

PXE: 简介:PXE(Preboot execute environment 是一种能够让计算机通过网络启动的引导方式,只要网卡支持PXE协议即可使用Kickstart 是一种无人值守的安装方式,工作原理就是预先把原本需要运维人员手工填写的参数保存成一个 ks.cfg 文…...

vmware ubuntu虚拟机网络联网配置

介绍vmware虚拟机配置基础网络环境,同时连接外网(通过桥接模式),以及ubuntu下输入法等基础工具安装。 本文基于ubuntu22.04,前提虚拟机已经完成安装。本文更多是针对vmware虚拟机的设置,之前有一篇针对ubun…...

Vue3_对接声网实时音视频_多人视频会议

目录 一、声网 1.注册账号 2.新建项目 二、实时音视频集成 1.声网CDN集成 2.iframe嵌入html 3.自定义UI集成 4.提高进入房间速度 web项目需要实现一个多人会议,对接的声网的灵动课堂。在这里说一下对接流程。 一、声网 声网成立于2014年,是全球…...

慧灵科技:创新引领自动化未来

在智能制造与自动化生产日益成为主流趋势的今天,慧灵科技凭借其卓越的技术创新能力和产品优势,在机器人领域崭露头角。 自2015年在深圳成立以来,慧灵科技专注于核心技术的研发与产品创新,‌为各行业提供性价比极高的机器人产品及自…...

【TiDB 社区智慧合集】TiDB 在核心场景的实战应用

作者: 社区小助手 原文来源: https://tidb.net/blog/5cc4ec70 杭州银行 杭州银行采用 TiDB 作为其核心系统数据库,标志着银行资产规模和业务复杂性的大幅增长。通过"分布式透明化"的思考,杭州银行实现了从传统 Orac…...

JetBrains:XML tag has empty body警告

在xml文件中配置时,因为标签内容为空,出现黄色警告影响观感。 通过IDE配置关闭告警...

XMLDecoder反序列化

XMLDecoder反序列化 基础知识 就简单讲讲吧,就是为了解析xml内容的 一般我们的xml都是标签属性这样的写法 比如person对象以xml的形式存储在文件中 在decode反序列化方法后,控制台成功打印出反序列化的对象。 就是可以根据我们的标签识别是什么成分…...

C# 高级数据处理:深入解析数据分区 Join 与 GroupJoin 操作的应用与实例演示

文章目录 一、概述二. 数据分区 (Partitioning)三、Join 操作符1. Join 操作符的基本用法2. Join 操作符示例 四、GroupJoin 操作符1. GroupJoin 操作符的基本用法2. GroupJoin 操作符示例 总结 在数据处理中,联接(Join)操作是一种非常常见的…...

数据库典型例题2-ER图转换关系模型

1.question solution: 2.做题步骤 一些解释&#xff1a; <1弱实体把强属性的主键写进去&#xff0c;指向强属性。eg:E6_A13指向E5_A13 <21:1&#xff0c;1:n&#xff0c;m:n&#xff1a;将完全参与的一方&#xff08;双线&#xff09;指向另一方&#xff0c;并将对方的…...

Java:设计模式(单例,工厂,代理,命令,桥接,观察者)

模式是一条由三部分组成的通用规则&#xff1a;它代表了一个特定环境、一类问题和一个解决方案之间的关系。每一个模式描述了一个不断重复发生的问题&#xff0c;以及该问题解决方案的核心设计。 软件领域的设计模式定义&#xff1a;设计模式是对处于特定环境下&#xff0c;经常…...

【算法】KMP算法

应用场景 有一个字符串 str1 "BBA ABCA ABCDAB ABCDABD"&#xff0c;和一个子串 str2 "ABCDABD"现在要判断 str1 是否含有 str2&#xff0c;如果含有&#xff0c;就返回第一次出现的位置&#xff0c;如果不含有&#xff0c;则返回 -1 我们很容易想到暴力…...

nginx续1:

八、虚拟主机配置 基于域名的虚拟主机 [rootserver2 ~]# ps -au|grep nginx //查看进程 修改Nginx服务配置&#xff0c;添加相关虚拟主机配置如下 1. [rootproxy ~]# vim /usr/local/nginx/conf/nginx.conf 2. .. .. 3. server { 4. listen …...

循环队列和阻塞有什么关系?和生产者消费者模型又有什么关系?阻塞队列和异步日志又有什么关系

### 循环队列和阻塞队列 #### 循环队列 - **定义**: 一个固定大小的数组&#xff0c;通过两个指针&#xff08;front 和 back&#xff09;管理队列的头部和尾部元素。 - **特点**: - **循环性**: 当指针到达数组的末尾时&#xff0c;可以回绕到数组的开头&#xff0c;从而利…...