自己动手实现一个深度学习算法——二、神经网络的实现
文章目录
- 1. 神经网络概述
- 1)表示
- 2)激活函数
- 3)sigmoid函数
- 4)阶跃函数的实现
- 5)sigmoid函数的实现
- 6)sigmoid函数和阶跃函数的比较
- 7)非线性函数
- 8)ReLU函数
- 2.三层神经网络的实现
- 1)结构
- 2)代码实现
- 3.输出层的设计
- 1)概述
- 2)softmax函数
- 3)实现softmax函数时的注意事项
- 4)softmax函数的特征
- 5)输出层的神经元数量
- 4.手写数字识别
- 1)MNIST数据集
- 2)实现
- 3)批处理
神经网络可以自动地从数据中学习到合适的权重参数。
1. 神经网络概述
1)表示
神经网络信号传递类似于感知机。最左边的一列称为输入层,最右边的一列称为输出层,中间的一列称为中间层。中间层有时也称为隐藏层。实现中,输入层到输出层依次称为第 0层、第1 层、第 2 层
2)激活函数
h(x)函数会将输入信号的总和转换为输出信号,这种函数一般称为激活函数(activation function)。如下:
y = h(b + w1x1+ w2x2)
如果激活函数如下,即以阈值为界,一旦输入超过阈值,就切换输出。这样的函数称为“阶跃函数”。因此,可以说感知机中使用了阶跃函数作为
激活函数。
3)sigmoid函数
神经网络中经常使用的一个激活函数就是sigmoid函数(sigmoid function)。表达式如下:
神经网络中用sigmoid函数作为激活函数,进行信号的转换,转换后的信号被传送给下一个神经元。
感知机和神经网络的主要区别就在于这个激活函数。
4)阶跃函数的实现
# coding: utf-8
import numpy as np
import matplotlib.pylab as pltdef step_function(x):# return np.array(x > 0, dtype=np.int)return np.array(x > 0, dtype=int)X = np.arange(-5.0, 5.0, 0.1)
Y = step_function(X)
plt.plot(X, Y)
plt.ylim(-0.1, 1.1) # 指定图中绘制的y轴的范围
plt.show()
5)sigmoid函数的实现
# coding: utf-8
import numpy as np
import matplotlib.pylab as pltdef sigmoid(x):return 1 / (1 + np.exp(-x)) X = np.arange(-5.0, 5.0, 0.1)
Y = sigmoid(X)
plt.plot(X, Y)
plt.ylim(-0.1, 1.1)
plt.show()
6)sigmoid函数和阶跃函数的比较
sigmoid函数是一条平滑的曲线,输出随着输入发生连续性的变化。sigmoid函数的平滑性对神经网络的学习具有重要意义。
当输入信号为重要信息时,阶跃函数和sigmoid函数都会输出较大的值;当输入信号为不重要的信息时,两者都输出较小的值。
不管输入信号有多小,或者有多大,输出信号的值都在0到1之间。
# coding: utf-8
import numpy as np
import matplotlib.pylab as pltdef sigmoid(x):return 1 / (1 + np.exp(-x)) def step_function(x):return np.array(x > 0, dtype=np.int)x = np.arange(-5.0, 5.0, 0.1)
y1 = sigmoid(x)
y2 = step_function(x)plt.plot(x, y1)
plt.plot(x, y2, 'k--')
plt.ylim(-0.1, 1.1) #指定图中绘制的y轴的范围
plt.show()
7)非线性函数
阶跃函数和sigmoid函数还有其他共同点,就是两者均为非线性函数。
神经网络的激活函数必须使用非线性函数。线性函数的问题在于,不管如何加深层数,总是存在与之等效的“无隐藏层的神经网络”
为了发挥叠加层所带来的优势,激活函数必须使用非线性函数。
8)ReLU函数
sigmoid函数很早就开始被使用了,而最近则主要使用ReLU(Rectified Linear Unit)函数。
ReLU 函数也是一种激活函数,可以表示为下面的式
ReLU函数的实现如下,
# coding: utf-8
import numpy as np
import matplotlib.pylab as pltdef relu(x):return np.maximum(0, x)x = np.arange(-5.0, 5.0, 0.1)
y = relu(x)
plt.plot(x, y)
plt.ylim(-1.0, 5.5)
plt.show()
2.三层神经网络的实现
1)结构
3层神经网络:输入层(第0层)有2个神经元,第1个隐藏层(第1层)有3个神经元,第2个隐藏层(第2层)有2个神经元,输出层(第3层)有2个神经元,结构
如下,
2)代码实现
# coding: utf-8
import numpy as np
from common.functions import sigmoid,identity_functiondef init_network():network = {}network['W1'] = np.array([[0.1, 0.3, 0.5], [0.2, 0.4, 0.6]])network['b1'] = np.array([0.1, 0.2, 0.3])network['W2'] = np.array([[0.1, 0.4], [0.2, 0.5], [0.3, 0.6]])network['b2'] = np.array([0.1, 0.2])network['W3'] = np.array([[0.1, 0.3], [0.2, 0.4]])network['b3'] = np.array([0.1, 0.2])return network
def forward(network, x):W1, W2, W3 = network['W1'], network['W2'], network['W3']b1, b2, b3 = network['b1'], network['b2'], network['b3']a1 = np.dot(x, W1) + b1z1 = sigmoid(a1)a2 = np.dot(z1, W2) + b2z2 = sigmoid(a2)a3 = np.dot(z2, W3) + b3y = identity_function(a3)return y
network = init_network()
x = np.array([1.0, 0.5])
y = forward(network, x)
print(y) # [ 0.31682708 0.69627909]
3.输出层的设计
1)概述
机器学习的问题大致可以分为分类问题和回归问题。分类问题是数据属于哪一个类别的问题。比如,区分图像中的人是男性还是女性的问题就是分类问题。而回归问题是根据某个输入预测一个(连续的)数值的问题。比如,根据一个人的图像预测这个人的体重的问题就是回归问题(类似“57.4kg”这样的预测)。
输出层的激活函数用σ()表示,不同于隐藏层的激活函数h()(σ读作sigma)。
输出层所用的激活函数,要根据求解问题的性质决定。一般地,回归问题可以使用恒等函数,二元分类问题可以使用sigmoid函数,多元分类问题可以使用softmax函数。
恒等函数会将输入按原样输出,对于输入的信息,不加以任何改动地直接输出。
2)softmax函数
分类问题中使用的softmax函数可以用下面的式表示。
softmax 函数的分子是输入信号 ak的指数函数,分母是所有输入信号的指数函数的和。输出层的各个神经元都受到所有输入信号的影响。
3)实现softmax函数时的注意事项
softmax函数的实现中要进行指数函数的运算,但是此时指数函数的值很容易变得非常大。结果可能会返回一个表示无穷大的inf。如果在这些超大值之间进行除法运算,结果会出现“不确定”的情况。这个问题称为溢出。
解决方式如下:
def softmax(a):#通过减去输入信号中的最大值c = np.max(a)exp_a = np.exp(a - c) # 溢出对策sum_exp_a = np.sum(exp_a)y = exp_a / sum_exp_areturn y
4)softmax函数的特征
softmax函数的输出是0.0到1.0之间的实数。并且,softmax函数的输出值的总和是1。
一般而言,神经网络只把输出值最大的神经元所对应的类别作为识别结果。并且,即便使用softmax函数,输出值最大的神经元的位置也不会变。
**因此,神经网络在进行分类时,输出层的softmax函数可以省略。**在实际的问题中,由于指数函数的运算需要一定的计算机运算量,因此输出层的softmax函数
一般会被省略。
在输出层使用softmax函数是因为它和神经网络的学习有关系。
5)输出层的神经元数量
输出层的神经元数量需要根据待解决的问题来决定。对于分类问题,输出层的神经元数量一般设定为类别的数量。
4.手写数字识别
假设学习已经全部结束,我们使用学习到的参数,先实现神经网络的“推理处理”。这个推理处理也称为神经网络的前向传播(forward propagation)。
1)MNIST数据集
MNIST是机器学习领域最有名的数据集之一,被应用于从简单的实验到发表的论文研究等各种场合。
MNIST 数据集是由 0 到 9 的数字图像构成的(图 3-24)。训练图像有 6 万张,测试图像有1万张,这些图像可以用于学习和推理。MNIST数据集的一般使用方法是,先用训练图像进行学习,再用学习到的模型度量能在多大程度上对测试图像进行正确的分类。
显示图形代码
# coding: utf-8
import sys, os
sys.path.append(os.pardir) # 为了导入父目录的文件而进行的设定
import numpy as np
from dataset.mnist import load_mnist
from PIL import Imagedef img_show(img):pil_img = Image.fromarray(np.uint8(img))pil_img.show()(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)
print(x_train.shape)
print(t_train.shape)
img = x_train[0]
label = t_train[0]
print(label) # 5
print(img.shape)
print(img.shape) # (784,)
img = img.reshape(28, 28) # 把图像的形状变为原来的尺寸
print(img.shape) # (28, 28)img_show(img)
load_mnist 函数以“(训练图像,训练标签),(测试图像,测试标签)”的形式返回读入的MNIST数据。此外,还可以像load_mnist(normalize=True, flatten=True, one_hot_label=False) 这 样,设 置 3 个 参 数。第 1 个 参 数normalize 设置是否将输入图像正规化为 0.0~1.0 的值。如果将该参数设置为False,则输入图像的像素会保持原来的0~255。第2个参数flatten设置是否展开输入图像(变成一维数组)。如果将该参数设置为False,则输入图像为1×28×28 的三维数组;若设置为 True,则输入图像会保存为由 784 个元素构成的一维数组。第3个参数one_hot_label设置是否将标签保存为one-hot 表示(one-hot representation)onehot 表示是仅正确解标签为 1,其余皆为0的数组,就像[0,0,1,0,0,0,0,0,0,0]这样。当one_hot_label为False时,只是像7、2这样简单保存正确解标签;one_hot_label为True时,标签则保存为one-hot表示。
2)实现
神经网络的输入层有784个神经元,输出层有10个神经元。输入层的784这个数字来源于图像大小的28×28 = 784,输出层的 10 这个数字来源于 10 类别分类(数
字0到9,共10类别)。此外,这个神经网络有2个隐藏层,第1个隐藏层有50 个神经元,第 2 个隐藏层有 100 个神经元。
# coding: utf-8
import sys, os
sys.path.append(os.pardir) # 为了导入父目录的文件而进行的设定
import numpy as np
import pickle
from dataset.mnist import load_mnist
from common.functions import sigmoid, softmax#读入写字数据集,进行了归一化处理的一维数组,保存了正确解的标签
def get_data():(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)return x_test, t_test#读入保存在 pickle 文件 sample_weight.pkl 中的学习到的权重参数.这个文件中以字典变量的形式保存了权重和偏置参数。
def init_network():with open("sample_weight.pkl", 'rb') as f:network = pickle.load(f)return networkdef predict(network, x):W1, W2, W3 = network['W1'], network['W2'], network['W3']b1, b2, b3 = network['b1'], network['b2'], network['b3']a1 = np.dot(x, W1) + b1z1 = sigmoid(a1)a2 = np.dot(z1, W2) + b2z2 = sigmoid(a2)a3 = np.dot(z2, W3) + b3y = softmax(a3)return yx, t = get_data()
network = init_network()
accuracy_cnt = 0
for i in range(len(x)):y = predict(network, x[i])p= np.argmax(y) # 获取概率最高的元素的索引if p == t[i]:accuracy_cnt += 1print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
将 normalize 设置成 True 后,函数内部会进行转换,将图像的各个像素值除以255,使得数据的值在0.0~1.0的范围内。像这样把数据限定到某个范围内的处理称为正规化(normalization)。此外,对神经网络的输入数据进行某种既定的转换称为预处理(pre-processing)
3)批处理
# coding: utf-8
import sys, os
sys.path.append(os.pardir) # 为了导入父目录的文件而进行的设定
import numpy as np
import pickle
from dataset.mnist import load_mnist
from common.functions import sigmoid, softmaxdef get_data():(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)return x_test, t_testdef init_network():with open("sample_weight.pkl", 'rb') as f:network = pickle.load(f)return networkdef predict(network, x):w1, w2, w3 = network['W1'], network['W2'], network['W3']b1, b2, b3 = network['b1'], network['b2'], network['b3']a1 = np.dot(x, w1) + b1z1 = sigmoid(a1)a2 = np.dot(z1, w2) + b2z2 = sigmoid(a2)a3 = np.dot(z2, w3) + b3y = softmax(a3)return yx, t = get_data()
network = init_network()batch_size = 100 # 批数量
accuracy_cnt = 0#按照batch_size间隔,从0获取元素
for i in range(0, len(x), batch_size):x_batch = x[i:i+batch_size]y_batch = predict(network, x_batch)#按照1维取最大值,即按行取最大值p = np.argmax(y_batch, axis=1)accuracy_cnt += np.sum(p == t[i:i+batch_size])print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
相关文章:

自己动手实现一个深度学习算法——二、神经网络的实现
文章目录 1. 神经网络概述1)表示2)激活函数3)sigmoid函数4)阶跃函数的实现5)sigmoid函数的实现6)sigmoid函数和阶跃函数的比较7)非线性函数8)ReLU函数 2.三层神经网络的实现1)结构2&…...

gRPC源码剖析-Builder模式
一、Builder模式 1、定义 将一个复杂对象的构建与表示分离,使得同样的构建过程可以创建不同的的表示。 2、适用场景 当创建复杂对象的算法应独立于该对象的组成部分以及它们的装配方式时。 当构造过程必须允许被构造的对象有不同的表示时。 说人话:…...
ARM传输数据以及移位操作
3.2.2 数据传送指令 LDR/STR指令用来在寄存器和内存之间输送数据。如果我们想要在寄存器之间传送数据,则可以使用MOV指令。MOV指令的格式如下。 MOV {cond} {s}Rd, oprand2 MOV {cond} {s}Rd, oprand2 其中,{cond}为条件指令可选项,{s}用来表…...

06、如何将对象数组里 obj 的 key 值变成动态的(即:每一个对象对应的 key 值都不同)
1、数据情况: 其一、从后端拿到的数据为: let arr [1,3,6,10,11,23,24] 其二、目标数据为: [{vlan_1: 1, value: 1}, {vlan_3: 3, value: 1}, {vlan_6: 6, value: 1}, {vlan_10: 10, value: 1}, {vlan_11: 11, value: 1}, {vlan_23: 23, v…...

ngx_http_request_s
/* 罗剑锋老师的注释参考: https://github.com/chronolaw/annotated_nginx/blob/master/nginx/src/http/ngx_http_request.h */struct ngx_http_request_s {uint32_t signature; /* "HTTP" */ngx_connection_t …...

Docker 学习路线 2:底层技术
了解驱动Docker的核心技术将让您更深入地了解Docker的工作原理,并有助于您更有效地使用该平台。 Linux容器(LXC) Linux容器(LXC)是Docker的基础。 LXC是一种轻量级的虚拟化解决方案,允许多个隔离的Linux系…...
UEFI实战——显示图片
一、准备工作 1.1 BMP格式图片 参考:BMP格式详解获取“BMP格式详解”文档里的图片,命名为Logo.bmp将Logo.bmp图片放到U盘里,U盘格式FAT32二、实例代码 2.1 代码结构 TextPkg/ ├── Display.c ├── GetFile.c ├── Test.c ├── Test.dsc ├── Test.h └── Tes…...

Ansible中的playbook
目录 一、playbook简介 二、playbook的语法 三、playbook的核心组件 四、playbook的执行命令 五、vim 设定技巧 六、基本示例 一、playbook简介 1、playbook与ad-hoc相比,是一种完全不同的运用。 2、playbook是一种简单的配置管理系统与多机器部署系统的基础…...

怎样去除视频中的杂音,保留人声部分?
怎样去除视频中的杂音,保留人声部分?这个简单嘛!两种办法可以搞定:一是进行音频降噪,把无用的杂音消除掉;二是提取人声,将要保留的人声片段提取出来。 这就将两种实用的办公都分享出来…...

基于Qt QTreeView|QTreeWidget控件使用简单版
头文件解析: 这是一个C++代码文件,定义了一个名为MainWindow的类。以下是对每一句的详细解释: ```cpp #ifndef MAINWINDOW_H #define MAINWINDOW_H ``` 这是一个条件编译指令,用于避免头文件的重复包含。`MAINWINDOW_H`是一个宏定义,用于唯一标识这个头文件。 ```cpp #…...

edge浏览器的隐藏功能
1. edge://version 查看版本信息 2. edge://flags 特性界面 具体到某一特性:edge://flags/#overlay-scrollbars 3. edge://settings设置界面 详情可参考chrome: 4. edge://extensions 扩展程序页面 5. edge://net-internals 网络事件信息 6. edge://component…...

安卓抓包之小黄鸟
下载安装 下载地址: https://download.csdn.net/download/yijianxiangde100/88496463 安装apk 即可。 证书配置:...

Django中的FBV和CBV
一、两者的区别 1、在我们日常学习Django中,都是用的FBV(function base views)方式,就是在视图中用函数处理各种请求。而CBV(class base view)则是通过类来处理请求。 2、Python是一个面向对象的编程语言…...

信息泄露--
大唐电信AC简介 大唐电信科技股份有限公司是电信科学技术研究院(大唐电信科技产业集团)控股的的高科技企业,大唐电信已形成集成电路设计、软件与应用、终端设计、移动互联网四大产业板块。 大唐电信AC集中管理平台存在弱口令及敏感信息泄漏漏…...
C#WPF文本格式化模式实例
本文演示C#WPF文本格式化模式实例 WPF 文本渲染优缺点 WPF中的文本渲染和旧式的基于 GDI的应用程序的文本染有很大区别。很大一部分区 别是由于 WPF 的设备无关显示系统造成的,但 WPF 中的文本染也得到了显著增强,能更清晰地显示文本,在 LCD 监视器上尤其如此。 然而,W…...
嵌入式云平台一些基础概念的理解
1.SDK SDK是Software Development Kit的缩写,译为”软件开发工具包”,通常是为辅助开发某类软件而编写的特定软件包,框架集合等,SDK一般包含相关文档,范例和工具。 我自己的理解就似乎,SDK也就是软件开发工具包,他会为其使用者提供一些封装好的接口&…...

【项目管理】生命周期风险评估
规划阶段目标:识别系统的业务战略,以支撑系统的安全需求及安全战略 规划阶段评估重点:1、本阶段不需要识别资产和脆弱性;2、应根据被评估对象的应用对象、应用环境、业务状况、操作要求等方面识别威胁; 设计阶段目标…...

力扣 搜索旋转排序数组 二分
👨🏫 33. 搜索旋转排序数组 class Solution {public int search(int[] nums, int target){int l 0;int r nums.length - 1;while (l < r){int m l r >> 1;//else大法,把无序段抛给else,if只处理有序段 // 需要特…...

【软件测试】个人博客项目测试报告
目录 1.报告概要 2、测试环境 3、手动测试用例编写 4、自动化测试用例 1.报告概要 测试对象:基于SSM项目的博客系统。 测试目的:检测博客项目是否符合预期,并且对测试知识进行练习和巩固。 测试点:主要针对常用的功能进行测…...

Express框架开发接口之今日推荐等模块
1.初始化 const handleDB require(../handleDB/index) // 获取全部模块 exports.allModule (req, res) > {(async function () {})() } // 更新或者添加模块 exports.upModule (req, res) > {(async function () {})() } // 根据id删除模块 exports.delModule (req, …...

19c补丁后oracle属主变化,导致不能识别磁盘组
补丁后服务器重启,数据库再次无法启动 ORA01017: invalid username/password; logon denied Oracle 19c 在打上 19.23 或以上补丁版本后,存在与用户组权限相关的问题。具体表现为,Oracle 实例的运行用户(oracle)和集…...

LeetCode - 394. 字符串解码
题目 394. 字符串解码 - 力扣(LeetCode) 思路 使用两个栈:一个存储重复次数,一个存储字符串 遍历输入字符串: 数字处理:遇到数字时,累积计算重复次数左括号处理:保存当前状态&a…...

关于iview组件中使用 table , 绑定序号分页后序号从1开始的解决方案
问题描述:iview使用table 中type: "index",分页之后 ,索引还是从1开始,试过绑定后台返回数据的id, 这种方法可行,就是后台返回数据的每个页面id都不完全是按照从1开始的升序,因此百度了下,找到了…...

SpringTask-03.入门案例
一.入门案例 启动类: package com.sky;import lombok.extern.slf4j.Slf4j; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.cache.annotation.EnableCach…...

vue3+vite项目中使用.env文件环境变量方法
vue3vite项目中使用.env文件环境变量方法 .env文件作用命名规则常用的配置项示例使用方法注意事项在vite.config.js文件中读取环境变量方法 .env文件作用 .env 文件用于定义环境变量,这些变量可以在项目中通过 import.meta.env 进行访问。Vite 会自动加载这些环境变…...

如何理解 IP 数据报中的 TTL?
目录 前言理解 前言 面试灵魂一问:说说对 IP 数据报中 TTL 的理解?我们都知道,IP 数据报由首部和数据两部分组成,首部又分为两部分:固定部分和可变部分,共占 20 字节,而即将讨论的 TTL 就位于首…...

有限自动机到正规文法转换器v1.0
1 项目简介 这是一个功能强大的有限自动机(Finite Automaton, FA)到正规文法(Regular Grammar)转换器,它配备了一个直观且完整的图形用户界面,使用户能够轻松地进行操作和观察。该程序基于编译原理中的经典…...
JS手写代码篇----使用Promise封装AJAX请求
15、使用Promise封装AJAX请求 promise就有reject和resolve了,就不必写成功和失败的回调函数了 const BASEURL ./手写ajax/test.jsonfunction promiseAjax() {return new Promise((resolve, reject) > {const xhr new XMLHttpRequest();xhr.open("get&quo…...
【前端异常】JavaScript错误处理:分析 Uncaught (in promise) error
在前端开发中,JavaScript 异常是不可避免的。随着现代前端应用越来越多地使用异步操作(如 Promise、async/await 等),开发者常常会遇到 Uncaught (in promise) error 错误。这个错误是由于未正确处理 Promise 的拒绝(r…...
Leetcode33( 搜索旋转排序数组)
题目表述 整数数组 nums 按升序排列,数组中的值 互不相同 。 在传递给函数之前,nums 在预先未知的某个下标 k(0 < k < nums.length)上进行了 旋转,使数组变为 [nums[k], nums[k1], …, nums[n-1], nums[0], nu…...