第四章.误差反向传播法—误差反向传播法实现手写数字识别神经网络
第四章.误差反向传播法
4.3 误差反向传播法实现手写数字识别神经网络
通过像组装乐高积木一样组装第四章中实现的层,来构建神经网络。
1.神经网络学习全貌图
1).前提:
- 神经网络存在合适的权重和偏置,调整权重和偏置以便拟合训练数据的过程称为“学习”,神经网络的学习分成下面4个步骤。
2).步骤1 (mini-batch):
- 从训练数据中随机选出一部分数据,这部分数据称为mini-batch,我们的目标是减少mini-batch损失函数的值。
3).步骤2 (计算梯度):
- 为了减少mini_batch损失函数的值,需要求出各个权重参数的梯度,梯度表示损失函数的值减少最多的方向。
4).步骤3 (更新参数):
- 将权重参数沿梯度方向进行微小更新
5).步骤4 (重复):
- 重复步骤1,步骤2,步骤3
2.手写数字识别神经网络的实现:(2层)
# 误差反向传播法实现手写数字识别神经网络import numpy as np
import matplotlib.pyplot as plt
import sys, ossys.path.append(os.pardir)
from dataset.mnist import load_mnist
from collections import OrderedDictclass Affine:def __init__(self, W, b):self.W = Wself.b = bself.x = Noneself.original_x_shape = None# 权重和偏置参数的导数self.dW = Noneself.db = None# 向前传播def forward(self, x):self.original_x_shape = x.shapex = x.reshape(x.shape[0], -1)self.x = xout = np.dot(self.x, self.W) + self.breturn out# 反向传播def backward(self, dout):dx = np.dot(dout, self.W.T)self.dW = np.dot(self.x.T, dout)self.db = np.sum(dout, axis=0)dx = dx.reshape(*self.original_x_shape) # 还原输入数据的形状(对应张量)return dxclass ReLU:def __init__(self):self.mask = Nonedef forward(self, x):self.mask = (x <= 0)out = x.copy()out[self.mask] = 0return outdef backward(self, dout):dout[self.mask] = 0dx = doutreturn dxclass SoftmaxWithLoss:def __init__(self):self.loss = Noneself.y = Noneself.t = None# 输出层函数:softmaxdef softmax(self, x):if x.ndim == 2:x = x.Tx = x - np.max(x, axis=0)y = np.exp(x) / np.sum(np.exp(x), axis=0)return y.Tx = x - np.max(x) # 溢出对策y = np.exp(x) / np.sum(np.exp(x))return y# 误差函数:交叉熵误差def cross_entropy_error(self, y, t):if y.ndim == 1:y = y.reshape(1, y.size)t = t.reshape(1, t.size)# 监督数据是one_hot_label的情况下,转换为正确解标签的索引if t.size == y.size:t = t.argmax(axis=1)batch_size = y.shape[0]return -np.sum(np.log(y[np.arange(batch_size), t] + 1e-7)) / batch_sizedef forward(self, x, t):self.t = tself.y = self.softmax(x)self.loss = self.cross_entropy_error(self.y, self.t)return self.lossdef backward(self, dout=1):batch_size = self.t.shape[0]if self.t.size == self.y.size:dx = (self.y - self.t) / batch_sizeelse:dx = self.y.copy()dx[np.arange(batch_size), self.t] -= 1dx = dx / batch_sizereturn dxclass TwoLayerNet:# 初始化def __init__(self, input_size, hidden_size, output_size, weight_init_std=0.01):# 初始化权重self.params = {}self.params['W1'] = weight_init_std * np.random.randn(input_size, hidden_size)self.params['b1'] = np.zeros(hidden_size)self.params['W2'] = weight_init_std * np.random.randn(hidden_size, output_size)self.params['b2'] = np.zeros(output_size)# 生成层self.layers = OrderedDict()self.layers['Affine1'] = Affine(self.params['W1'], self.params['b1'])self.layers['ReLU'] = ReLU()self.layers['Affine2'] = Affine(self.params['W2'], self.params['b2'])self.lastLayer = SoftmaxWithLoss()def predict(self, x):for layer in self.layers.values():x = layer.forward(x)return xdef loss(self, x, t):y = self.predict(x)loss = self.lastLayer.forward(y, t)return lossdef accuracy(self, x, t):y = self.predict(x)y = np.argmax(y, axis=1)if t.ndim != 1: t = np.argmax(t, axis=1)accuracy = np.sum(y == t) / float(t.shape[0])return accuracy# 微分函数def numerical_gradient1(self, f, x):h = 1e-4grad = np.zeros_like(x)it = np.nditer(x, flags=['multi_index'], op_flags=['readwrite'])while not it.finished:idx = it.multi_indextmp_val = x[idx]x[idx] = float(tmp_val) + hfxh1 = f(x) # f(x+h)x[idx] = tmp_val - hfxh2 = f(x) # f(x-h)grad[idx] = (fxh1 - fxh2) / (2 * h)x[idx] = tmp_val # 还原值it.iternext()return grad# 通过数值微分计算关于权重参数的梯度def numerical_gradient(self, x, t):loss_W = lambda W: self.loss(x, t)grad = {}grad['W1'] = self.numerical_gradient1(loss_W, self.params['W1'])grad['b1'] = self.numerical_gradient1(loss_W, self.params['b1'])grad['W2'] = self.numerical_gradient1(loss_W, self.params['W2'])grad['b2'] = self.numerical_gradient1(loss_W, self.params['b2'])return grad# 通过误差反向传播法计算权重参数的梯度误差def gradient(self, x, t):# 正向传播self.loss(x, t)# 反向传播dout = 1dout = self.lastLayer.backward(dout)layers = list(self.layers.values())layers.reverse()for layer in layers:dout = layer.backward(dout)# 设定grads = {}grads['W1'] = self.layers['Affine1'].dWgrads['b1'] = self.layers['Affine1'].dbgrads['W2'] = self.layers['Affine2'].dWgrads['b2'] = self.layers['Affine2'].dbreturn grads# 读入数据
def get_data():(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)return (x_train, t_train), (x_test, t_test)# 读入数据
(x_train, t_train), (x_test, t_test) = get_data()network = TwoLayerNet(input_size=784, hidden_size=50, output_size=10)iters_num = 10000
train_size = x_train.shape[0]
batch_size = 100
lr = 0.1
train_loss_list = []
train_acc_list = []
test_acc_list = []
iter_per_epoch = max(train_size / batch_size, 1)for i in range(iters_num):batch_mask = np.random.choice(train_size, batch_size)x_batch = x_train[batch_mask]t_batch = t_train[batch_mask]# 通过误差反向传播法求梯度grad = network.gradient(x_batch, t_batch)# 更新for key in ('W1', 'b1', 'W2', 'b2'):network.params[key] -= lr * grad[key]loss = network.loss(x_batch, t_batch)train_loss_list.append(loss)if i % iter_per_epoch == 0:train_acc = network.accuracy(x_train, t_train)train_acc_list.append(train_acc)test_acc = network.accuracy(x_test, t_test)test_acc_list.append(test_acc)print('train_acc,test_acc|', str(train_acc) + ',' + str(test_acc))# 绘制识别精度图像
plt.rcParams['font.sans-serif'] = ['SimHei'] # 解决中文乱码
plt.rcParams['axes.unicode_minus'] = False # 解决负号不显示的问题plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
x_data = np.arange(0, len(train_acc_list))
plt.plot(x_data, train_acc_list, 'b')
plt.plot(x_data, test_acc_list, 'r')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.ylim(0.0, 1.0)
plt.title('训练数据和测试数据的识别精度')
plt.legend(['train_acc', 'test_acc'])plt.subplot(1, 2, 2)
x_data = np.arange(0, len(train_loss_list))
plt.plot(x_data, train_loss_list, 'g')
plt.xlabel('iters_num')
plt.ylabel('loss')
plt.title('损失函数')
plt.show()
3.结果展示

相关文章:
第四章.误差反向传播法—误差反向传播法实现手写数字识别神经网络
第四章.误差反向传播法 4.3 误差反向传播法实现手写数字识别神经网络 通过像组装乐高积木一样组装第四章中实现的层,来构建神经网络。 1.神经网络学习全貌图 1).前提: 神经网络存在合适的权重和偏置,调整权重和偏置以便拟合训练数据的过程称…...
IB学习者的培养目标有哪些?
IB课程强调要培养年轻人的探究精神,在富有渊博知识的同时,更要勤于思考,敢于思考,尊重和理解跨文化的差异,坚持原则维护公平,让这个世界充满爱与和平,让这个世界变得更加美好。上一次我们为大家…...
C++类基础(十三)
类的继承 ● 通过类的继承(派生)来引入“是一个”的关系( 17.2 — Basic inheritance in C) – 通常采用 public 继承( struct V.S. class ) – 注意:继承部分不是类的声明 – 使用基类的指针…...
03 OpenCV图像运算
文章目录1 普通加法1 加号相加2 add函数2 加权相加3 按位运算1 按位与运算2 按位或运算、非运算4 掩膜1 普通加法 1 加号相加 在 OpenCV 中,图像加法可以使用加号运算符()来实现。例如,如果要将两幅图像相加,可以使用…...
【C语言学习笔记】:动态库
一、动态库 通过之前静态库那篇文章的介绍。发现静态库更容易使用和理解,也达到了代码复用的目的,那为什么还需要动态库呢? 1、为什么还需要动态库? 为什么需要动态库,其实也是静态库的特点导致。 ▶ 空间浪费是静…...
Zookeeper
zookeeper是一个分布式协调服务。所谓分布式协调主要是来解决分布式系统中多个进程之间的同步限制,防止出现脏读,例如我们常说的分布式锁。 zookeeper中的数据是存储在内存当中的,因此它的效率十分高效。它内部的存储方式十分类似于文件存储…...
wav转mp3,wav转换成mp3教程
很多使用音频文件的小伙伴,总会接触到不同类型的音频格式,根据需求不同需要做相关的处理。比如有人接触到了wav格式的音频,这是windows系统研发的一种标准数字音频文件,是一种占用磁盘体积超级大的音频格式,通常用于录…...
springboot项目配置文件加密
1背景: springboot项目中要求不能采用明文密码,故采用配置文件加密. 目前采用有密码的有redis nacos rabbitmq mysql 这些配置文件 2技术 2.1 redis nacos rabbitmq 配置文件加密 采用加密方式是jasypt 加密 2.1.1 加密步骤 2.1.2 引入maven依赖 …...
公司招聘:33岁以上的和两年一跳的不要,开出工资我还以为看错了...
导读:对于公司来说,肯定是希望花最少的钱招到最优秀的员工,但事实上这个想法是不太现实的,虽然如今互联网不太好找工作,但要员工降薪去入职,相信还是有很大难度的,很多人宁可在家休息࿰…...
【置顶】:文章合集系列
【置顶】:文章合集系列 必看 文章中的所有内容仅供做个人学习使用,所有环境都在本地搭建并验证,任何人使用文中方法进行未经授权的渗透行为都与文章与我本人无关,请各位大佬不要进行未经授权的渗透行为…… 前言 之前更新过一段…...
Go的web开发Gin框架1(八)——Gin
一、重点内容: 知识要点有哪些? 1、了解Gin框架 2、导入使用Gin框架 3、尝试配合GORM开发 4、整合html,css,js 二、详细知识点介绍: 1、Gin框架介绍 Gin是一个golang的微框架,封装比较优雅&…...
吴思进——复杂美创始人首席执行官
杭州复杂美科技有限公司创始人兼CEO, 本科毕业于浙江大学机械专业,辅修过多门管理课程;1997年获经济学硕士学位,有关对冲基金的毕业论文被评为优秀;2008年创办杭州复杂美科技有限公司。 吴思进 中国电子学会区块链委员会专家&…...
apk简单介绍(组成以及打包安装流程)
apk简单介绍APK 的组成apk安装流程app的启动过程apk打包流程AIDLAIDL介绍为什么要设计这门语言它有哪些语法?默认支持的数据类型包括什么是apk打包流程了解打包流程能做什么操作APK 的组成 APK 其实是一个 zip 类型的压缩包,而一个典型的 APK 通常都会包…...
ffmpeg学习笔记之SDL视频播放器
看了雷神的 100行代码实现最简单的基于FFMPEGSDL的视频播放器(SDL1.x) 后手痒难耐,决定将里面的代码重新建一个 首先建立一个空项目,新建一个Mysimplest.cpp的文件。在里面写代码 #include <stdio.h>extern "C" …...
【Git】合并多条 commit 注释信息
文章目录1、查看 commit 记录2、合并 commit 注释1、查看 commit 记录 # 3 指的是查看最近 3 次的 commit 记录,如果要查看多次的可以修改数字 # -3 不加,则表示查看所有 commit 记录,一般还是用数字去指定 git log -32、合并 commit 注释 …...
【gcc/g++】程序的翻译(.c -->.exe)
环境:centos7.6,腾讯云服务器Linux文章都放在了专栏:【Linux】欢迎支持订阅🌹前言我们在写完代码运行时会发现生成了一个.exe的可执行程序,那么该程序是如何形成的呢?本次章节将在linux下用编译器gcc进行一…...
电话号码的字母组合-力扣17-java
一、题目描述给定一个仅包含数字 2-9 的字符串,返回所有它能表示的字母组合。答案可以按 任意顺序 返回。给出数字到字母的映射如下(与电话按键相同)。注意 1 不对应任何字母。示例 1:输入:digits "23"输出…...
Archery-SQL审核查询平台
Archery-SQL审核查询平台 文章目录Archery-SQL审核查询平台一、功能列表介绍1.1、SQL审核MySQL实例非MySQL实例审核执行分离SQL工单自动审批、高危语句驳回快速上线其他实例定时执行1.2、SQL查询多类型数据库支持授权管理页面体验1.3、SQL优化慢日志管理SQL语句优化1.4、实例管…...
MySQL8.0安装教程
文章目录1.官网下载MySQL2.下载完记住解压的地址(一会用到)3.进入刚刚解压的文件夹下,创建data和my.ini在根目录下创建一个txt文件,名字叫my,文件后缀为ini,之后复制下面这个代码放在my.ini文件下ÿ…...
一文详解工业知识模型互联平台MoHub
1月8日,MWORKS 2023产品发布会落下帷幕。会上,同元软控隆重推出了云原生的工业知识模型互联平台MoHub,引起广泛关注。本文将从服务定位、架构方案、核心服务、持续运营等方面对MoHub平台进行全面介绍。1 MoHub平台的服务定位装备数字化的必要…...
el-tabs报错Cannot read properties of null (reading ‘insertBefore‘)
使用elementui-plus的tabs组件在开发中遇到的一个问题,分析了代码,发现逻辑没有任何问题,但是点击tab切换就会报错:Uncaught (in promise) TypeError: Cannot read properties of null (reading insertBefore)调试发现parent参数是…...
OpenClaw终端整合:QwQ-32B命令行操作增强方案
OpenClaw终端整合:QwQ-32B命令行操作增强方案 1. 为什么需要终端智能助手 作为开发者,我们每天要处理大量命令行操作。从简单的目录跳转、文件操作,到复杂的管道命令组合,再到调试报错信息,这些重复性工作消耗了大量…...
Windows 内网 Web 服务穿透方案推荐
Windows 内网 Web 服务穿透方案推荐 面向场景:内网机器为 Windows,需从公网或外网访问内网 HTTP/HTTPS Web 服务;优先选择相对不易被误报、来源清晰、可审计的方案。 关于「报毒」的说明 穿透类软件常被启发式引擎标为「风险/可疑」…...
ESP32嵌入式系统设计与实现指南
1. 项目概述1.1 系统架构本项目基于ESP32主控芯片设计,采用模块化架构实现多功能嵌入式系统。系统包含以下核心模块:主控单元:ESP32-WROOM-32D模组电源管理:TPS63020升降压转换器传感器接口:I2C/SPI多协议兼容设计人机…...
pyautocad:实现AutoCAD自动化流程的创新方法
pyautocad:实现AutoCAD自动化流程的创新方法 【免费下载链接】pyautocad AutoCAD Automation for Python ⛺ 项目地址: https://gitcode.com/gh_mirrors/py/pyautocad pyautocad作为开发者必备的效率工具,通过Python语言与AutoCAD的ActiveX接口无…...
终极指南:从NumPy到Pydantic的Claude-Code-Usage-Monitor依赖管理完整解析
终极指南:从NumPy到Pydantic的Claude-Code-Usage-Monitor依赖管理完整解析 【免费下载链接】Claude-Code-Usage-Monitor Real-time Claude Code usage monitor with predictions and warnings 项目地址: https://gitcode.com/gh_mirrors/cl/Claude-Code-Usage-Mon…...
图床项目(二) 接口设计
接口设计 1 . muduo 网络模型 该模型相较于普通的reactor模型复杂一点,其中包括mainReactor 和 多个 subReactor ,其中每一个 subReactor对应一个线程。 其中 mainReactor 负责处理新连接 , 并将连接均匀分配给 subReactor ,后续…...
Maxwell16.0实战:如何用实验电流数据搞定电机仿真(附.tab文件制作技巧)
Maxwell16.0实战:实验电流数据驱动电机仿真的全流程解析 电机仿真作为现代工业设计的重要环节,其准确性直接影响产品性能评估。而将实测电流数据融入仿真流程,往往是工程师突破"理想模型"局限的关键一步。本文将系统性地拆解从实验…...
League-Toolkit英雄联盟辅助工具完全指南:从配置到精通的高效使用手册
League-Toolkit英雄联盟辅助工具完全指南:从配置到精通的高效使用手册 【免费下载链接】League-Toolkit 兴趣使然的、简单易用的英雄联盟工具集。支持战绩查询、自动秒选等功能。基于 LCU API。 项目地址: https://gitcode.com/gh_mirrors/le/League-Toolkit …...
保姆级教程:在Jeecg-Vue3项目中快速集成SuperQuery高级查询组件(含完整配置代码)
Jeecg-Vue3项目实战:SuperQuery高级查询组件深度集成指南 在后台管理系统开发中,高效的数据筛选功能直接影响用户体验和操作效率。Jeecg-Vue3作为企业级快速开发框架,其内置的SuperQuery组件能够帮助开发者快速构建复杂的多条件查询面板。本文…...
