第四章.误差反向传播法—误差反向传播法实现手写数字识别神经网络
第四章.误差反向传播法
4.3 误差反向传播法实现手写数字识别神经网络
通过像组装乐高积木一样组装第四章中实现的层,来构建神经网络。
1.神经网络学习全貌图
1).前提:
- 神经网络存在合适的权重和偏置,调整权重和偏置以便拟合训练数据的过程称为“学习”,神经网络的学习分成下面4个步骤。
2).步骤1 (mini-batch):
- 从训练数据中随机选出一部分数据,这部分数据称为mini-batch,我们的目标是减少mini-batch损失函数的值。
3).步骤2 (计算梯度):
- 为了减少mini_batch损失函数的值,需要求出各个权重参数的梯度,梯度表示损失函数的值减少最多的方向。
4).步骤3 (更新参数):
- 将权重参数沿梯度方向进行微小更新
5).步骤4 (重复):
- 重复步骤1,步骤2,步骤3
2.手写数字识别神经网络的实现:(2层)
# 误差反向传播法实现手写数字识别神经网络import numpy as np
import matplotlib.pyplot as plt
import sys, ossys.path.append(os.pardir)
from dataset.mnist import load_mnist
from collections import OrderedDictclass Affine:def __init__(self, W, b):self.W = Wself.b = bself.x = Noneself.original_x_shape = None# 权重和偏置参数的导数self.dW = Noneself.db = None# 向前传播def forward(self, x):self.original_x_shape = x.shapex = x.reshape(x.shape[0], -1)self.x = xout = np.dot(self.x, self.W) + self.breturn out# 反向传播def backward(self, dout):dx = np.dot(dout, self.W.T)self.dW = np.dot(self.x.T, dout)self.db = np.sum(dout, axis=0)dx = dx.reshape(*self.original_x_shape) # 还原输入数据的形状(对应张量)return dxclass ReLU:def __init__(self):self.mask = Nonedef forward(self, x):self.mask = (x <= 0)out = x.copy()out[self.mask] = 0return outdef backward(self, dout):dout[self.mask] = 0dx = doutreturn dxclass SoftmaxWithLoss:def __init__(self):self.loss = Noneself.y = Noneself.t = None# 输出层函数:softmaxdef softmax(self, x):if x.ndim == 2:x = x.Tx = x - np.max(x, axis=0)y = np.exp(x) / np.sum(np.exp(x), axis=0)return y.Tx = x - np.max(x) # 溢出对策y = np.exp(x) / np.sum(np.exp(x))return y# 误差函数:交叉熵误差def cross_entropy_error(self, y, t):if y.ndim == 1:y = y.reshape(1, y.size)t = t.reshape(1, t.size)# 监督数据是one_hot_label的情况下,转换为正确解标签的索引if t.size == y.size:t = t.argmax(axis=1)batch_size = y.shape[0]return -np.sum(np.log(y[np.arange(batch_size), t] + 1e-7)) / batch_sizedef forward(self, x, t):self.t = tself.y = self.softmax(x)self.loss = self.cross_entropy_error(self.y, self.t)return self.lossdef backward(self, dout=1):batch_size = self.t.shape[0]if self.t.size == self.y.size:dx = (self.y - self.t) / batch_sizeelse:dx = self.y.copy()dx[np.arange(batch_size), self.t] -= 1dx = dx / batch_sizereturn dxclass TwoLayerNet:# 初始化def __init__(self, input_size, hidden_size, output_size, weight_init_std=0.01):# 初始化权重self.params = {}self.params['W1'] = weight_init_std * np.random.randn(input_size, hidden_size)self.params['b1'] = np.zeros(hidden_size)self.params['W2'] = weight_init_std * np.random.randn(hidden_size, output_size)self.params['b2'] = np.zeros(output_size)# 生成层self.layers = OrderedDict()self.layers['Affine1'] = Affine(self.params['W1'], self.params['b1'])self.layers['ReLU'] = ReLU()self.layers['Affine2'] = Affine(self.params['W2'], self.params['b2'])self.lastLayer = SoftmaxWithLoss()def predict(self, x):for layer in self.layers.values():x = layer.forward(x)return xdef loss(self, x, t):y = self.predict(x)loss = self.lastLayer.forward(y, t)return lossdef accuracy(self, x, t):y = self.predict(x)y = np.argmax(y, axis=1)if t.ndim != 1: t = np.argmax(t, axis=1)accuracy = np.sum(y == t) / float(t.shape[0])return accuracy# 微分函数def numerical_gradient1(self, f, x):h = 1e-4grad = np.zeros_like(x)it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])while not it.finished:idx = it.multi_indextmp_val = x[idx]x[idx] = float(tmp_val) + hfxh1 = f(x) # f(x+h)x[idx] = tmp_val - hfxh2 = f(x) # f(x-h)grad[idx] = (fxh1 - fxh2) / (2 * h)x[idx] = tmp_val # 还原值it.iternext()return grad# 通过数值微分计算关于权重参数的梯度def numerical_gradient(self, x, t):loss_W = lambda W: self.loss(x, t)grad = {}grad['W1'] = self.numerical_gradient1(loss_W, self.params['W1'])grad['b1'] = self.numerical_gradient1(loss_W, self.params['b1'])grad['W2'] = self.numerical_gradient1(loss_W, self.params['W2'])grad['b2'] = self.numerical_gradient1(loss_W, self.params['b2'])return grad# 通过误差反向传播法计算权重参数的梯度误差def gradient(self, x, t):# 正向传播self.loss(x, t)# 反向传播dout = 1dout = self.lastLayer.backward(dout)layers = list(self.layers.values())layers.reverse()for layer in layers:dout = layer.backward(dout)# 设定grads = {}grads['W1'] = self.layers['Affine1'].dWgrads['b1'] = self.layers['Affine1'].dbgrads['W2'] = self.layers['Affine2'].dWgrads['b2'] = self.layers['Affine2'].dbreturn grads# 读入数据
def get_data():(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)return (x_train, t_train), (x_test, t_test)# 读入数据
(x_train, t_train), (x_test, t_test) = get_data()network = TwoLayerNet(input_size=784, hidden_size=50, output_size=10)iters_num = 10000
train_size = x_train.shape[0]
batch_size = 100
lr = 0.1
train_loss_list = []
train_acc_list = []
test_acc_list = []
iter_per_epoch = max(train_size / batch_size, 1)for i in range(iters_num):batch_mask = np.random.choice(train_size, batch_size)x_batch = x_train[batch_mask]t_batch = t_train[batch_mask]# 通过误差反向传播法求梯度grad = network.gradient(x_batch, t_batch)# 更新for key in ('W1', 'b1', 'W2', 'b2'):network.params[key] -= lr * grad[key]loss = network.loss(x_batch, t_batch)train_loss_list.append(loss)if i % iter_per_epoch == 0:train_acc = network.accuracy(x_train, t_train)train_acc_list.append(train_acc)test_acc = network.accuracy(x_test, t_test)test_acc_list.append(test_acc)print('train_acc,test_acc|', str(train_acc) + ',' + str(test_acc))# 绘制识别精度图像
plt.rcParams['font.sans-serif'] = ['SimHei'] # 解决中文乱码
plt.rcParams['axes.unicode_minus'] = False # 解决负号不显示的问题plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
x_data = np.arange(0, len(train_acc_list))
plt.plot(x_data, train_acc_list, 'b')
plt.plot(x_data, test_acc_list, 'r')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.ylim(0.0, 1.0)
plt.title('训练数据和测试数据的识别精度')
plt.legend(['train_acc', 'test_acc'])plt.subplot(1, 2, 2)
x_data = np.arange(0, len(train_loss_list))
plt.plot(x_data, train_loss_list, 'g')
plt.xlabel('iters_num')
plt.ylabel('loss')
plt.title('损失函数')
plt.show()
3.结果展示
相关文章:

第四章.误差反向传播法—误差反向传播法实现手写数字识别神经网络
第四章.误差反向传播法 4.3 误差反向传播法实现手写数字识别神经网络 通过像组装乐高积木一样组装第四章中实现的层,来构建神经网络。 1.神经网络学习全貌图 1).前提: 神经网络存在合适的权重和偏置,调整权重和偏置以便拟合训练数据的过程称…...

IB学习者的培养目标有哪些?
IB课程强调要培养年轻人的探究精神,在富有渊博知识的同时,更要勤于思考,敢于思考,尊重和理解跨文化的差异,坚持原则维护公平,让这个世界充满爱与和平,让这个世界变得更加美好。上一次我们为大家…...

C++类基础(十三)
类的继承 ● 通过类的继承(派生)来引入“是一个”的关系( 17.2 — Basic inheritance in C) – 通常采用 public 继承( struct V.S. class ) – 注意:继承部分不是类的声明 – 使用基类的指针…...

03 OpenCV图像运算
文章目录1 普通加法1 加号相加2 add函数2 加权相加3 按位运算1 按位与运算2 按位或运算、非运算4 掩膜1 普通加法 1 加号相加 在 OpenCV 中,图像加法可以使用加号运算符()来实现。例如,如果要将两幅图像相加,可以使用…...

【C语言学习笔记】:动态库
一、动态库 通过之前静态库那篇文章的介绍。发现静态库更容易使用和理解,也达到了代码复用的目的,那为什么还需要动态库呢? 1、为什么还需要动态库? 为什么需要动态库,其实也是静态库的特点导致。 ▶ 空间浪费是静…...

Zookeeper
zookeeper是一个分布式协调服务。所谓分布式协调主要是来解决分布式系统中多个进程之间的同步限制,防止出现脏读,例如我们常说的分布式锁。 zookeeper中的数据是存储在内存当中的,因此它的效率十分高效。它内部的存储方式十分类似于文件存储…...

wav转mp3,wav转换成mp3教程
很多使用音频文件的小伙伴,总会接触到不同类型的音频格式,根据需求不同需要做相关的处理。比如有人接触到了wav格式的音频,这是windows系统研发的一种标准数字音频文件,是一种占用磁盘体积超级大的音频格式,通常用于录…...

springboot项目配置文件加密
1背景: springboot项目中要求不能采用明文密码,故采用配置文件加密. 目前采用有密码的有redis nacos rabbitmq mysql 这些配置文件 2技术 2.1 redis nacos rabbitmq 配置文件加密 采用加密方式是jasypt 加密 2.1.1 加密步骤 2.1.2 引入maven依赖 …...

公司招聘:33岁以上的和两年一跳的不要,开出工资我还以为看错了...
导读:对于公司来说,肯定是希望花最少的钱招到最优秀的员工,但事实上这个想法是不太现实的,虽然如今互联网不太好找工作,但要员工降薪去入职,相信还是有很大难度的,很多人宁可在家休息࿰…...
【置顶】:文章合集系列
【置顶】:文章合集系列 必看 文章中的所有内容仅供做个人学习使用,所有环境都在本地搭建并验证,任何人使用文中方法进行未经授权的渗透行为都与文章与我本人无关,请各位大佬不要进行未经授权的渗透行为…… 前言 之前更新过一段…...

Go的web开发Gin框架1(八)——Gin
一、重点内容: 知识要点有哪些? 1、了解Gin框架 2、导入使用Gin框架 3、尝试配合GORM开发 4、整合html,css,js 二、详细知识点介绍: 1、Gin框架介绍 Gin是一个golang的微框架,封装比较优雅&…...

吴思进——复杂美创始人首席执行官
杭州复杂美科技有限公司创始人兼CEO, 本科毕业于浙江大学机械专业,辅修过多门管理课程;1997年获经济学硕士学位,有关对冲基金的毕业论文被评为优秀;2008年创办杭州复杂美科技有限公司。 吴思进 中国电子学会区块链委员会专家&…...
apk简单介绍(组成以及打包安装流程)
apk简单介绍APK 的组成apk安装流程app的启动过程apk打包流程AIDLAIDL介绍为什么要设计这门语言它有哪些语法?默认支持的数据类型包括什么是apk打包流程了解打包流程能做什么操作APK 的组成 APK 其实是一个 zip 类型的压缩包,而一个典型的 APK 通常都会包…...
ffmpeg学习笔记之SDL视频播放器
看了雷神的 100行代码实现最简单的基于FFMPEGSDL的视频播放器(SDL1.x) 后手痒难耐,决定将里面的代码重新建一个 首先建立一个空项目,新建一个Mysimplest.cpp的文件。在里面写代码 #include <stdio.h>extern "C" …...

【Git】合并多条 commit 注释信息
文章目录1、查看 commit 记录2、合并 commit 注释1、查看 commit 记录 # 3 指的是查看最近 3 次的 commit 记录,如果要查看多次的可以修改数字 # -3 不加,则表示查看所有 commit 记录,一般还是用数字去指定 git log -32、合并 commit 注释 …...

【gcc/g++】程序的翻译(.c -->.exe)
环境:centos7.6,腾讯云服务器Linux文章都放在了专栏:【Linux】欢迎支持订阅🌹前言我们在写完代码运行时会发现生成了一个.exe的可执行程序,那么该程序是如何形成的呢?本次章节将在linux下用编译器gcc进行一…...

电话号码的字母组合-力扣17-java
一、题目描述给定一个仅包含数字 2-9 的字符串,返回所有它能表示的字母组合。答案可以按 任意顺序 返回。给出数字到字母的映射如下(与电话按键相同)。注意 1 不对应任何字母。示例 1:输入:digits "23"输出…...

Archery-SQL审核查询平台
Archery-SQL审核查询平台 文章目录Archery-SQL审核查询平台一、功能列表介绍1.1、SQL审核MySQL实例非MySQL实例审核执行分离SQL工单自动审批、高危语句驳回快速上线其他实例定时执行1.2、SQL查询多类型数据库支持授权管理页面体验1.3、SQL优化慢日志管理SQL语句优化1.4、实例管…...

MySQL8.0安装教程
文章目录1.官网下载MySQL2.下载完记住解压的地址(一会用到)3.进入刚刚解压的文件夹下,创建data和my.ini在根目录下创建一个txt文件,名字叫my,文件后缀为ini,之后复制下面这个代码放在my.ini文件下ÿ…...

一文详解工业知识模型互联平台MoHub
1月8日,MWORKS 2023产品发布会落下帷幕。会上,同元软控隆重推出了云原生的工业知识模型互联平台MoHub,引起广泛关注。本文将从服务定位、架构方案、核心服务、持续运营等方面对MoHub平台进行全面介绍。1 MoHub平台的服务定位装备数字化的必要…...

第19节 Node.js Express 框架
Express 是一个为Node.js设计的web开发框架,它基于nodejs平台。 Express 简介 Express是一个简洁而灵活的node.js Web应用框架, 提供了一系列强大特性帮助你创建各种Web应用,和丰富的HTTP工具。 使用Express可以快速地搭建一个完整功能的网站。 Expre…...
PHP和Node.js哪个更爽?
先说结论,rust完胜。 php:laravel,swoole,webman,最开始在苏宁的时候写了几年php,当时觉得php真的是世界上最好的语言,因为当初活在舒适圈里,不愿意跳出来,就好比当初活在…...
pam_env.so模块配置解析
在PAM(Pluggable Authentication Modules)配置中, /etc/pam.d/su 文件相关配置含义如下: 配置解析 auth required pam_env.so1. 字段分解 字段值说明模块类型auth认证类模块,负责验证用户身份&am…...

SpringCloudGateway 自定义局部过滤器
场景: 将所有请求转化为同一路径请求(方便穿网配置)在请求头内标识原来路径,然后在将请求分发给不同服务 AllToOneGatewayFilterFactory import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; impor…...

【JVM】Java虚拟机(二)——垃圾回收
目录 一、如何判断对象可以回收 (一)引用计数法 (二)可达性分析算法 二、垃圾回收算法 (一)标记清除 (二)标记整理 (三)复制 (四ÿ…...
探索Selenium:自动化测试的神奇钥匙
目录 一、Selenium 是什么1.1 定义与概念1.2 发展历程1.3 功能概述 二、Selenium 工作原理剖析2.1 架构组成2.2 工作流程2.3 通信机制 三、Selenium 的优势3.1 跨浏览器与平台支持3.2 丰富的语言支持3.3 强大的社区支持 四、Selenium 的应用场景4.1 Web 应用自动化测试4.2 数据…...

MySQL的pymysql操作
本章是MySQL的最后一章,MySQL到此完结,下一站Hadoop!!! 这章很简单,完整代码在最后,详细讲解之前python课程里面也有,感兴趣的可以往前找一下 一、查询操作 我们需要打开pycharm …...

相关类相关的可视化图像总结
目录 一、散点图 二、气泡图 三、相关图 四、热力图 五、二维密度图 六、多模态二维密度图 七、雷达图 八、桑基图 九、总结 一、散点图 特点 通过点的位置展示两个连续变量之间的关系,可直观判断线性相关、非线性相关或无相关关系,点的分布密…...
CppCon 2015 学习:REFLECTION TECHNIQUES IN C++
关于 Reflection(反射) 这个概念,总结一下: Reflection(反射)是什么? 反射是对类型的自我检查能力(Introspection) 可以查看类的成员变量、成员函数等信息。反射允许枚…...

基于stm32F10x 系列微控制器的智能电子琴(附完整项目源码、详细接线及讲解视频)
注:文章末尾网盘链接中自取成品使用演示视频、项目源码、项目文档 所用硬件:STM32F103C8T6、无源蜂鸣器、44矩阵键盘、flash存储模块、OLED显示屏、RGB三色灯、面包板、杜邦线、usb转ttl串口 stm32f103c8t6 面包板 …...