当前位置: 首页 > news >正文

神经网络构建原理(以MINIST为例)

神经网络构建原理(以MINIST为例)

在 MNIST 手写数字识别任务中,构建神经网络并训练模型来进行分类是经典的深度学习应用。MNIST 数据集包含 28x28 像素的手写数字图像(0-9),任务是构建一个神经网络,能够根据输入的图像预测对应的数字。本文将通过该案例详细介绍神经网络的逻辑框架和具体的计算流程。

神经网络构建框架

1.数据预处理

  • 将输入数据进行标准化处理(归一化),并将标签转换为适合模型的格式( one-hot 编码)。

2.模型构建

  • a.输入层:定义输入的大小(将 28x28 的图像展平为 784 维向量)。
  • b.隐藏层:添加一个或多个隐藏层,每层包含一定数量的神经元,并应用激活函数(如 ReLU)[实际上神经元相当于维数,该过程即是对原特征维数进行扩维或降维]。
  • c.输出层:定义输出层的神经元数量(与分类类别一致),通常使用 Softmax 函数将输出转换为概率。

3.前向传播

  • 执行输入层到隐藏层、再到输出层的矩阵乘法和激活函数,计算输出值。

4.损失函数计算

  • 使用交叉熵等损失函数,计算预测输出与真实标签之间的误差。

5.反向传播

  • 通过链式法则,计算损失函数对每个参数的偏导数,并更新权重和偏置项。

6.优化器更新

  • 使用优化器(如 SGD、Adam)基于计算的梯度更新模型参数,降低损失值。

7.迭代训练

  • 不断重复前向传播、损失计算和反向传播,直到损失收敛或达到设定的训练轮次。

8.模型评估与预测

  • 训练完成后,用测试数据评估模型性能,并进行新数据的预测。
    在这里插入图片描述
    在这里插入图片描述

神经网络的具体计算流程

接下来以MINIST手写数字识别为例,模拟神经网络构建的具体计算过程。

假设该网络包含 两个隐藏层,每个隐藏层有 25 个神经元,最后的输出层为 10 个神经元

1. 前向传播(Forward Propagation)

1.1输入层到第一个隐藏层:
  • 输入大小:假设输入图像是 28x28 的像素矩阵,展平成 784 维的向量。 x ∈ R 784 × 1 x \in \mathbb{R}^{784 \times 1} xR784×1

  • 权重矩阵:连接输入层到第一个隐藏层的权重矩阵,大小为 25x784,[因为特征向量是列向量,所以需要转置]。 W 1 ∈ R 25 × 784 W_1 \in \mathbb{R}^{25 \times 784} W1R25×784

  • 偏置项:第一个隐藏层的偏置项,大小为 25x1。 b 1 ∈ R 25 × 1 b_1 \in \mathbb{R}^{25 \times 1} b1R25×1

  • 激活函数:使用 ReLU 激活函数。

计算步骤

  • 执行矩阵乘法 z 1 = W 1 ⋅ x + b 1 z_1 = W_1 \cdot x + b_1 z1=W1x+b1

z 1 z_1 z1 的维度是 25x1。

  • 应用 ReLU 激活函数: h 1 = ReLU ( z 1 ) h_1 = \text{ReLU}(z_1) h1=ReLU(z1)

其中: ReLU ( z 1 ) = max ⁡ ( 0 , z 1 ) \text{ReLU}(z_1) = \max(0, z_1) ReLU(z1)=max(0,z1)

  • 结果:第一个隐藏层的输出 h 1 h_1 h1 是 25 维的向量。
    在这里插入图片描述
1.2第一个隐藏层到第二个隐藏层:
  • 权重矩阵:连接第一个隐藏层到第二个隐藏层的权重矩阵,大小为 25x25。 W 2 ∈ R 25 × 25 W_2 \in \mathbb{R}^{25 \times 25} W2R25×25

  • 偏置项:第二个隐藏层的偏置项,大小为 25x1。 b 2 ∈ R 25 × 1 b_2 \in \mathbb{R}^{25 \times 1} b2R25×1

计算步骤

  • 执行矩阵乘法 z 2 = W 2 ⋅ h 1 + b 2 z_2 = W_2 \cdot h_1 + b_2 z2=W2h1+b2

z 2 z_2 z2的维度是 25x1。

  • 应用 ReLU 激活函数: h 2 = ReLU ( z 2 ) h_2 = \text{ReLU}(z_2) h2=ReLU(z2)

  • 结果:第二个隐藏层的输出 h 2 h_2 h2 仍然是 25 维的向量。

1.3第二个隐藏层到输出层:
  • 权重矩阵:连接第二个隐藏层到输出层的权重矩阵,大小为 10x25。 W 3 ∈ R 10 × 25 W_3 \in \mathbb{R}^{10 \times 25} W3R10×25

  • 偏置项:输出层的偏置项,大小为 10x1。 b 3 ∈ R 10 × 1 b_3 \in \mathbb{R}^{10 \times 1} b3R10×1

计算步骤

  • 执行矩阵乘法 z 3 = W 3 ⋅ h 2 + b 3 z_3 = W_3 \cdot h_2 + b_3 z3=W3h2+b3

z 3 z_3 z3的维度是 10x1。

  • 应用 Softmax 函数将输出转换为概率: Softmax ( z 3 ) i = e z 3 i ∑ j = 1 10 e z 3 j \text{Softmax}(z_3)_i = \frac{e^{z_{3i}}}{\sum_{j=1}^{10} e^{z_{3j}}} Softmax(z3)i=j=110ez3jez3i

Softmax 输出是 10 维的概率向量,表示输入属于 0-9 的概率。

2. 损失计算

使用交叉熵损失函数来计算预测输出与真实标签之间的误差,假设真实标签是 one-hot 编码的向量 y ∈ R 10 y \in \mathbb{R}^{10} yR10,其中,
y i = 1 y_i = 1 yi=1 表示真实类别, p i = Softmax ( z 3 ) i p_i = \text{Softmax}(z_3)_i pi=Softmax(z3)i 表示模型对类别 i i i 的预测概率。
在这里插入图片描述

交叉熵损失公式 L = − ∑ i = 1 10 y i log ⁡ ( p i ) L = -\sum_{i=1}^{10} y_i \log(p_i) L=i=110yilog(pi)

损失计算步骤

  • 对于每一个样本,计算预测类别对应的概率 p i p_i pi 的对数,然后计算损失 L L L

3. 反向传播(Backward Propagation)

反向传播的目标是通过链式法则计算损失函数对每层权重的偏导数,并更新权重矩阵。

3.1输出层到第二个隐藏层:
  • 计算损失对输出层的导数 ∂ L ∂ z 3 = Softmax ( z 3 ) − y \frac{\partial L}{\partial z_3} = \text{Softmax}(z_3) - y z3L=Softmax(z3)y

  • 计算损失对 W 3 W_3 W3的导数 ∂ L ∂ W 3 = ∂ L ∂ z 3 ⋅ h 2 T \frac{\partial L}{\partial W_3} = \frac{\partial L}{\partial z_3} \cdot h_2^T W3L=z3Lh2T

  • 计算损失对 b 3 b_3 b3的导数 ∂ L ∂ b 3 = ∂ L ∂ z 3 \frac{\partial L}{\partial b_3} = \frac{\partial L}{\partial z_3} b3L=z3L

3.2第二个隐藏层到第一个隐藏层:
  • 损失传播到第二层的输出 h 2 h_2 h2 ∂ L ∂ h 2 = W 3 T ⋅ ∂ L ∂ z 3 \frac{\partial L}{\partial h_2} = W_3^T \cdot \frac{\partial L}{\partial z_3} h2L=W3Tz3L

  • 计算 ReLU 激活函数的导数 ∂ L ∂ z 2 = ∂ L ∂ h 2 ⋅ ReLU ′ ( z 2 ) \frac{\partial L}{\partial z_2} = \frac{\partial L}{\partial h_2} \cdot \text{ReLU}'(z_2) z2L=h2LReLU(z2)

其中: ReLU ′ ( z 2 ) = { 1 if  z 2 > 0 0 if  z 2 ≤ 0 \text{ReLU}'(z_2) = \begin{cases} 1 & \text{if } z_2 > 0 \\ 0 & \text{if } z_2 \leq 0 \end{cases} ReLU(z2)={10if z2>0if z20

  • 计算损失对 W 2 W_2 W2的导数 ∂ L ∂ W 2 = ∂ L ∂ z 2 ⋅ h 1 T \frac{\partial L}{\partial W_2} = \frac{\partial L}{\partial z_2} \cdot h_1^T W2L=z2Lh1T

  • 计算损失对 b 2 b_2 b2的导数 ∂ L ∂ b 2 = ∂ L ∂ z 2 \frac{\partial L}{\partial b_2} = \frac{\partial L}{\partial z_2} b2L=z2L

3.3第一个隐藏层到输入层:
  • 损失传播到第一层的输出 h 1 h_1 h1 ∂ L ∂ h 1 = W 2 T ⋅ ∂ L ∂ z 2 \frac{\partial L}{\partial h_1} = W_2^T \cdot \frac{\partial L}{\partial z_2} h1L=W2Tz2L

  • 计算 ReLU 激活函数的导数 ∂ L ∂ z 1 = ∂ L ∂ h 1 ⋅ ReLU ′ ( z 1 ) \frac{\partial L}{\partial z_1} = \frac{\partial L}{\partial h_1} \cdot \text{ReLU}'(z_1) z1L=h1LReLU(z1)

  • 计算损失对 h 1 h_1 h1的导数 ∂ L ∂ W 1 = ∂ L ∂ z 1 ⋅ x T \frac{\partial L}{\partial W_1} = \frac{\partial L}{\partial z_1} \cdot x^T W1L=z1LxT

  • 计算损失对 b 1 b_1 b1的导数 ∂ L ∂ b 1 = ∂ L ∂ z 1 \frac{\partial L}{\partial b_1} = \frac{\partial L}{\partial z_1} b1L=z1L

4. 权重更新

使用梯度下降算法或 Adam 优化器来更新权重。

更新公式 W = W − η ⋅ ∂ L ∂ W W = W - \eta \cdot \frac{\partial L}{\partial W} W=WηWL

  • W W W是权重矩阵, η η η是学习率, ∂ L ∂ W \frac{\partial L}{\partial W} WL 是损失函数关于权重的梯度。

权重更新步骤在每一层执行:

  • 更新 W 1 W_1 W1, W 2 W_2 W2, W 3 W_3 W3和对应的偏置项 b 1 b_1 b1, b 2 b_2 b2, b 3 b_3 b3

5. 优化器(Adam)介绍

Adam 优化器通过结合动量和自适应学习率进行参数更新。详细的更新公式在上面的回答中已经给出。

1.一阶动量估计:
计算当前梯度 ∇ θ L \nabla_{\theta}L θL的加权平均,用来估计梯度的期望。这个一阶动量主要是累积之前的梯度,使得更新方向更加平滑。

m t = β 1 m t − 1 + ( 1 − β 1 ) ∇ θ L m_t = \beta_1 m_{t-1} + (1 - \beta_1) \nabla_{\theta}L mt=β1mt1+(1β1)θL
β 1 \beta_1 β1是一阶动量的衰减率,通常取值为 0.9。
m t m_t mt是当前的动量(梯度的指数加权平均)。

2.二阶矩估计:
计算当前梯度平方的加权平均,估计梯度的方差,用来调节学习率,避免更新步长过大。
v t = β 2 v t − 1 + ( 1 − β 2 ) ( ∇ θ L ) 2 v_t = \beta_2 v_{t-1} + (1 - \beta_2) (\nabla_{\theta}L)^2 vt=β2vt1+(1β2)(θL)2
β 2 \beta_2 β2是二阶动量的衰减率,通常取值为 0.999。
v t v_t vt是梯度平方的指数加权平均。

3.偏差修正:
由于 m ^ t \hat{m}_t m^t v ^ t \hat{v}_t v^t在前几步可能会有较大的偏差,Adam 引入了偏差修正,减少估计的偏差。
m ^ t = m t 1 − β 1 t , v ^ t = v t 1 − β 2 t \hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} m^t=1β1tmt,v^t=1β2tvt
m ^ t \hat{m}_t m^t v ^ t \hat{v}_t v^t是偏差修正后的动量和二阶矩估计.

4.参数更新:
使用修正后的动量和方差来更新参数。Adam 的更新方式是自适应的,能根据梯度的历史动态调整学习率。
W t + 1 = W t − η m ^ t v ^ t + ϵ W_{t+1} = W_t - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} Wt+1=Wtηv^t +ϵm^t
η η η是学习率,通常取值在 0.001 左右。
$ \epsilon$是一个小的平滑项,避免除以零,通常为 1 0 − 8 10^{-8} 108


Reference:

  1. TensorFlow Documentation
  2. CS231n Convolutional Neural Networks for Visual Recognition
  3. Adam Optimizer Paper
  4. Gradient Descent and Backpropagation Overview
  5. https://www.deeplearningbook.org/
  6. http://cs231n.github.io/optimization-2/

相关文章:

神经网络构建原理(以MINIST为例)

神经网络构建原理(以MINIST为例) 在 MNIST 手写数字识别任务中,构建神经网络并训练模型来进行分类是经典的深度学习应用。MNIST 数据集包含 28x28 像素的手写数字图像(0-9),任务是构建一个神经网络,能够根据输入的图像…...

【ArcGIS微课1000例】0123:数据库中要素类批量转为shapefile

除了ArcGIS之外的其他GIS平台,想要打开ArcGIS数据库,可能无法直接打开,为了便于使用shp,建议直接将数据库中要素类批量转为shapefile。 文章目录 一、连接至数据库二、要素批量转shp一、连接至数据库 打开ArcMap,或者打开ArcCatalog,找到数据库连接,如下图: 数据库为个…...

【Elasticsearch系列十九】评分机制详解

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…...

神经网络通俗理解学习笔记(3)注意力神经网络

Tansformer 什么是注意力机制注意力的计算键值对注意力和多头注意力自注意力机制注意力池化及代码实现Transformer模型Transformer代码实现BERT 模型GPT 系列模型GPT-1模型思想GPT-2模型思想GPT-3 模型思想 T5模型ViT模型Swin Transformer模型GPT模型代码实现 什么是注意力机制…...

【C#】 EventWaitHandle的用法

EventWaitHandle 是 C# 中用于线程间同步的一个类,它提供了对共享资源的访问控制,以及线程间的同步机制。EventWaitHandle 类位于 System.Threading 命名空间下,主要用于实现互斥访问、信号量控制等场景。 创建 EventWaitHandle 创建一个 E…...

设计模式之结构型模式例题

答案:A 知识点 创建型 结构型 行为型模式 工厂方法模式 抽象工厂模式 原型模式 单例模式 构建器模式 适配器模式 桥接模式 组合模式 装饰模式 外观模式 享元模式 代理模式 模板方法模式 职责链模式 命令模式 迭代器模式 中介者模式 解释器模式 备忘录模式 观…...

camtasia2024绿色免费安装包win+mac下载含2024最新激活密钥

Hey, hey, hey!亲爱的各位小伙伴,今天我要给大家带来的是Camtasia2024中文版本,这款软件简直是视频制作爱好者的福音啊! camtasia2024绿色免费安装包winmac下载,点击链接即可保存。 先说说这个版本新加的功能吧&#…...

如何导入一个Vue并成功运行

注意1:要确保自己已经成功创建了一个Vue项目,创建项目教程在如何创建Vue项目 注意2:以下操作均在VS Code,教程在VS Code安装教程 一、Vue项目导入VS Code 1.点击文件,然后点击将文件添加到工作区 2. 选择自己的vue项…...

封装svg图片

前言 项目中有大量svg图片,为了方便引入,所以对svg进行了处理 一、svg是什么? svg是可缩放矢量图形,是一种图片格式 二、使用步骤 1.创建icons文件夹 将icons文件夹放进src中,并创建一个svg文件夹和index.js&…...

tomcat的Catalinalog和localhostlog乱码

找到tomcat安装目录的loging文件 乱码这两个由UTF-8改为GBK...

行人持刀检测数据集 voc yolo

行人持刀检测数据集 9000张 持刀检测 带标注 voc yolo 行人持刀检测数据集 数据集描述 该数据集旨在用于行人持刀行为的检测任务,涵盖了多种场景下的行人图像,特别是那些携带刀具的行人。数据集包含大量的图像及其对应的标注信息,可用于训练…...

基于51单片机的汽车倒车防撞报警器系统

目录 一、主要功能 二、硬件资源 三、程序编程 四、实现现象 一、主要功能 本课题基于微控制器控制器, 设计一款汽车倒车防撞报警器系统。 要求: 要求:1.配有距离, 用于把车和障碍物之间的距离信号送入控制器。 2.配有报警系…...

NLP 文本匹配任务核心梳理

定义 本质上是做了意图的识别 判断两个内容的含义(包括相似、矛盾、支持度等)侠义 给定一组文本,判断语义是否相似Yi 分值形式给出相似度 广义 给定一组文本,计算某种自定义的关联度Text Entailment 判断文本是否能支持或反驳这个…...

FastAPI 的隐藏宝石:自动生成 TypeScript 客户端

在现代 Web 开发中,前后端分离已成为标准做法。这种架构允许前端和后端独立开发和扩展,但同时也带来了如何高效交互的问题。FastAPI,作为一个新兴的 Python Web 框架,提供了一个优雅的解决方案:自动生成客户端代码。本…...

了解云容器实例云容器实例(Cloud Container Instance)

1.什么是云容器实例? 云容器实例(Cloud Container Instance, CCI)服务提供 Serverless Container(无服务器容器)引擎,让您无需创建和管理服务器集群即可直接运行容器。 Serverless是一种架构理念…...

OpenStack Yoga版安装笔记(十三)neutron安装

1、官方文档 OpenStack Installation Guidehttps://docs.openstack.org/install-guide/ 本次安装是在Ubuntu 22.04上进行,基本按照OpenStack Installation Guide顺序执行,主要内容包括: 环境安装 (已完成)OpenStack…...

[系列]参数估计与贝叶斯推断

系列 点估计极大似然估计贝叶斯估计(统计学)——最小均方估计和最大后验概率估计贝叶斯估计(模式识别)线性最小均方估计最小二乘估计极大似然估计&贝叶斯估计极大似然估计&最大后验概率估计线性最小均方估计&最小二乘…...

【Pyside】pycharm2024配置conda虚拟环境

知识拓展 Pycharm 是一个由 JetBrains 开发的集成开发环境(IDE),它主要用于 Python 编程语言的开发。Pycharm 提供了代码编辑、调试、版本控制、测试等多种功能,以提高 Python 开发者的效率。 Pycharm 与 Python 的关系 Pycharm 是…...

【RabbitMQ 项目】服务端:数据管理模块之消息队列管理

文章目录 一.编写思路二.代码实践 一.编写思路 定义消息队列 名字是否持久化 定义队列持久化类(持久化到 sqlite3) 构造函数(只能成功,不能失败) 如果数据库(文件)不存在则创建打开数据库打开 msg_queue_table 数据库表 插入队列移除队列将数据库中的队列恢复到内存…...

SDKMAN!软件开发工具包管理器

认识一下SDKMAN!(The Software Development Kit Manager)是您在Unix系统上轻松管理多个软件开发工具包的可靠伴侣。想象一下,有不同版本的SDK,需要一种无感知的方式在它们之间切换。SDKMAN拥有易于使用的命令行界面(CLI)和API。其…...

《使用 LangChain 进行大模型应用开发》学习笔记(四)

前言 本文是 Harrison Chase (LangChain 创建者)和吴恩达(Andrew Ng)的视频课程《LangChain for LLM Application Development》(使用 LangChain 进行大模型应用开发)的学习笔记。由于原课程为全英文视频课…...

gbase8s数据库常见的索引扫描方式

1 顺序扫描(Sequential scan):数据库服务器按照物理顺序读取表中的所有记录。 常发生在表上无索引或者数据量很少或者一些无法使用索引的sql语句中 2 索引扫描(Index scan):数据库服务器读取索引页&#…...

边缘智能-大模型架构初探

R2Cloud接口 机器人注册 请求和应答 注册是一个简单的 HTTP 接口,根据机器人/用户信息注册,创建一个新机器人。 请求 URL URLhttp://ip/robot/regTypePOSTHTTP Version1.1Content-Typeapplication/json 请求参数 Param含义Rule是否必须缺省roboti…...

《python语言程序设计》2018版第8章18题几何circle2D类(上部)

一、利用第7章的内容来做前5个点 第一章之1--从各种角度来测量第一章之2--各种结果第二章之1--建立了针对比对点在圆内的几段第二章之2--利用建立的对比代码,得出的第2点位置 第一章之1–从各种角度来测量 class Circle2D:def __init__(self, x, y, radius):self._…...

nginx upstream转发连接错误情况研究

本次测试用到3台服务器: 192.168.10.115:转发服务器A 192.168.10.209:upstream下服务器1 192.168.10.210:upstream下服务器2 1台客户端:192.168.10.112 服务器A中nginx主要配置如下: log_format main…...

alias 后门从入门到应急响应

目录 1. alias 后门介绍 2. alias 后门注入方式 2.1 方式一(以函数的方式执行) 2.2 方式二(执行python脚本) 3.应急响应 3.1 查看所有连接 3.2 通过PID查看异常连接的进程,以及该进程正在执行的命令行命令 3.3 查看别名 3.4 其他情况 3.5 那么检查这些…...

【远程调用PythonAPI-flask】

文章目录 前言一、Pycharm创建flask项目1.创建虚拟环境2.创建flask项目 二、远程调用PythonAPI——SpringBoot项目集成1.修改PyCharm的host配置2.防火墙设置3.SpringBoot远程调用PythonAPI 前言 解决Pycharm运行Flask指定ip、端口更改无效的问题 首先先创建一个新的flask项目&…...

[今日Arxiv] 思维迭代:利用内心对话进行自主大型语言模型推理

思维迭代:利用内心对话进行自主大型语言模型推理 Iteration of Thought: Leveraging Inner Dialogue for Autonomous Large Language Model Reasoning URL:https://arxiv.org/abs/2409.12618 注:翻译可能存在误差,详细内容建议…...

glTF格式:WebGL应用的3D资产优化解决方案

摘要 glTF作为一种高效的3D资产格式,为WebGL、OpenGL ES和OpenGL运行时的应用提供了强有力的支持。它不仅简化了3D模型的传输与加载流程,还通过优化资产大小,使得打包、解包更加便捷。本文将深入探讨glTF格式的优势,并提供实用的代…...

Unity3D入门(一) : 第一个Unity3D项目,实现矩形自动旋转,并导出到Android运行

1. Unity3D介绍 Unity3D是虚拟现实行业中,使用率较高的一款软件。 它有着强大的功能,是让玩家轻松创建三维视频游戏、建筑可视化、实时三维动画等互动内容的多平台、综合型 虚拟现实开发工具。是一个全面整合的专业引擎。 2. Unity安装 官网 : Unity…...