当前位置: 首页 > news >正文

python3+TensorFlow 2.x(四)反向传播

目录

反向传播算法

反向传播算法基本步骤:

反向中的参数变化

总结


反向传播算法

反向传播算法(Backpropagation)是训练人工神经网络时使用的一个重要算法,它是通过计算梯度并优化神经网络的权重来最小化误差。反向传播算法的核心是基于链式法则的梯度下降优化方法,通过计算误差对每个权重的偏导数来更新网络中的参数。

反向传播算法基本步骤:

前向传播:将输入数据传递通过神经网络的各层,计算每一层的输出。
计算损失:根据输出和实际标签计算损失(通常使用均方误差或交叉熵等作为损失函数)。
反向传播:根据损失函数对每个参数(如权重、偏置)计算梯度。梯度的计算通过链式法则进行反向传播,直到达到输入层。
更新权重:使用梯度下降算法来更新每一层的权重和偏置,使得损失函数最小化。

链式推到:https://blog.csdn.net/dingyahui123/category_6945552.html?spm=1001.2014.3001.5482

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import mnist
# 加载 MNIST 数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()# 归一化数据并将其形状调整为 (N, 784),因为每张图片是 28x28 像素
train_images = train_images.reshape(-1, 28*28) / 255.0
test_images = test_images.reshape(-1, 28*28) / 255.0# 转换标签为 one-hot 编码
train_labels = np.eye(10)[train_labels]
test_labels = np.eye(10)[test_labels]
# 定义激活函数
def sigmoid(x):return 1 / (1 + np.exp(-x))# 定义激活函数的导数
def sigmoid_derivative(x):return x * (1 - x)# 网络架构参数
input_size = 28 * 28  # 输入层的大小
hidden_size = 128     # 隐藏层的大小
output_size = 10      # 输出层的大小# 初始化权重和偏置
W1 = np.random.randn(input_size, hidden_size)  # 输入层到隐藏层的权重
b1 = np.zeros((1, hidden_size))  # 隐藏层的偏置
W2 = np.random.randn(hidden_size, output_size)  # 隐藏层到输出层的权重
b2 = np.zeros((1, output_size))  # 输出层的偏置
# 设置超参数
epochs = 20
learning_rate = 0.1
batch_size = 64# 训练过程
for epoch in range(epochs):for i in range(0, len(train_images), batch_size):# 选择当前batch的数据X_batch = train_images[i:i+batch_size]y_batch = train_labels[i:i+batch_size]# 前向传播z1 = np.dot(X_batch, W1) + b1a1 = sigmoid(z1)z2 = np.dot(a1, W2) + b2a2 = sigmoid(z2)# 计算损失的梯度output_error = a2 - y_batch  # 损失函数的梯度output_delta = output_error * sigmoid_derivative(a2)hidden_error = output_delta.dot(W2.T)hidden_delta = hidden_error * sigmoid_derivative(a1)# 更新权重和偏置W2 -= learning_rate * a1.T.dot(output_delta)b2 -= learning_rate * np.sum(output_delta, axis=0, keepdims=True)W1 -= learning_rate * X_batch.T.dot(hidden_delta)b1 -= learning_rate * np.sum(hidden_delta, axis=0, keepdims=True)# 每10轮输出一次损失if epoch % 10 == 0:loss = np.mean(np.square(a2 - y_batch))print(f"Epoch {epoch}, Loss: {loss}")
# 测试模型
z1 = np.dot(test_images, W1) + b1
a1 = sigmoid(z1)
z2 = np.dot(a1, W2) + b2
a2 = sigmoid(z2)# 计算准确率
predictions = np.argmax(a2, axis=1)
true_labels = np.argmax(test_labels, axis=1)
accuracy = np.mean(predictions == true_labels)print(f"Test Accuracy: {accuracy * 100:.2f}%")
# 可视化前5个测试图像及其预测结果
for i in range(5):plt.imshow(test_images[i].reshape(28, 28), cmap='gray')plt.title(f"Predicted: {predictions[i]}, Actual: {true_labels[i]}")plt.show()

 

反向中的参数变化

import numpy as np
import matplotlib.pyplot as plt
import imageio# 激活函数和其导数
def sigmoid(x):return 1 / (1 + np.exp(-x))def sigmoid_derivative(x):return x * (1 - x)# 生成一些示例数据
np.random.seed(0)
X = np.array([[0, 0],[0, 1],[1, 0],[1, 1]])
y = np.array([[0], [1], [1], [0]])  # XOR 问题# 初始化参数
input_layer_neurons = 2
hidden_layer_neurons = 2
output_neurons = 1
learning_rate = 0.5
epochs = 10000# 初始化权重
weights_input_hidden = np.random.uniform(size=(input_layer_neurons, hidden_layer_neurons))
weights_hidden_output = np.random.uniform(size=(hidden_layer_neurons, output_neurons))# 存储权重和图像
weights_history = []
losses = []
images = []# 训练过程
for epoch in range(epochs):# 前向传播hidden_layer_input = np.dot(X, weights_input_hidden)hidden_layer_output = sigmoid(hidden_layer_input)output_layer_input = np.dot(hidden_layer_output, weights_hidden_output)predicted_output = sigmoid(output_layer_input)loss = np.mean((y - predicted_output) ** 2)losses.append(loss)# 反向传播error = y - predicted_outputd_predicted_output = error * sigmoid_derivative(predicted_output)error_hidden_layer = d_predicted_output.dot(weights_hidden_output.T)d_hidden_layer = error_hidden_layer * sigmoid_derivative(hidden_layer_output)# 更新权重weights_hidden_output += hidden_layer_output.T.dot(d_predicted_output) * learning_rateweights_input_hidden += X.T.dot(d_hidden_layer) * learning_rate# 保存权重weights_history.append((weights_input_hidden.copy(), weights_hidden_output.copy()))# 每1000次迭代保存一次图像if epoch % 1000 == 0:plt.figure(figsize=(8, 6))plt.subplot(1, 2, 1)plt.title('Weights Input-Hidden')plt.imshow(weights_input_hidden, cmap='viridis', aspect='auto')plt.colorbar()plt.subplot(1, 2, 2)plt.title('Weights Hidden-Output')plt.imshow(weights_hidden_output, cmap='viridis', aspect='auto')plt.colorbar()# 保存图像plt.savefig(f'weights_epoch_{epoch}.png')plt.close()if epoch % 1000 == 0:plt.figure(figsize=(8, 6))plt.plot(losses, label='Loss')plt.title('Loss over epochs')plt.xlabel('Epochs')plt.ylabel('Loss')plt.xlim(0, epochs)plt.ylim(0, np.max(losses))plt.grid()plt.legend()# 保存图像plt.savefig(f'loss_epoch_{epoch}.png')plt.close()
# 创建 GIF
with imageio.get_writer('weights_update.gif', mode='I', duration=0.5) as writer:for epoch in range(0, epochs, 1000):image = imageio.imread(f'weights_epoch_{epoch}.png')writer.append_data(image)
# 创建 GIF
with imageio.get_writer('training_loss.gif', mode='I', duration=0.5) as writer:for epoch in range(0, epochs, 1000):image = imageio.imread(f'loss_epoch_{epoch}.png')writer.append_data(image)
# 清理生成的图像文件
import os
for epoch in range(0, epochs, 1000):os.remove(f'weights_epoch_{epoch}.png')os.remove(f'loss_epoch_{epoch}.png')print("GIF 已生成:training_loss.gif")
print("GIF 已生成:weights_update.gif")

 

总结

反向传播算法是神经网络训练中的核心技术,它通过计算损失函数相对于每个权重和偏置的梯度,利用梯度下降算法优化网络的参数。理解了反向传播的基本过程,可以进一步扩展到更复杂的网络结构,如卷积神经网络(CNN)和循环神经网络(RNN)。

相关文章:

python3+TensorFlow 2.x(四)反向传播

目录 反向传播算法 反向传播算法基本步骤: 反向中的参数变化 总结 反向传播算法 反向传播算法(Backpropagation)是训练人工神经网络时使用的一个重要算法,它是通过计算梯度并优化神经网络的权重来最小化误差。反向传播算法的核…...

Flutter 使用 flutter_inappwebview 加载 App 本地 HTML 文件

在 Flutter 开发中,加载本地 HTML 文件是一个常见的需求,尤其是在需要展示离线内容或自定义页面时。flutter_inappwebview 是一个功能强大的插件,支持加载本地文件和网络资源。本文将详细介绍如何使用 flutter_inappwebview 加载 App 本地 HT…...

Word常见问题:嵌入图片无法显示完整

场景:在Word中,嵌入式图片显示不全,一部分图片在文字下方。如: 问题原因:因段落行距导致 方法一 快捷方式 选中图片,通过"ctrl1"快捷调整为1倍行距 方法二 通过工具栏调整 选中图片&#xff0…...

为AI聊天工具添加一个知识系统 之68 详细设计 之9 三种中台和时间度量 之1

本文要点 要点 在维度0上 被分离出来 的业务中台 需求、技术中台要求、和数据中台请求 (分别在时间层/空间层/时空层上 对应一个不同种类槽的容器,分别表示业务特征Feature[3]/技术方面Aspect[3]/数据流Fluent[3]) 在维度1~3的运动过程中 从…...

On to OpenGL and 3D computer graphics

2. On to OpenGL and 3D computer graphics 声明:该代码来自:Computer Graphics Through OpenGL From Theory to Experiments,仅用作学习参考 2.1 First Program Square.cpp完整代码 /// // square.cpp // // OpenGL program to draw a squ…...

从曾国藩的经历看如何打破成长中的瓶颈

《曾国藩传》是一部充满智慧与人生哲理的传记,而曾国藩本人更是一个从“最笨”到“最智慧”的奇人。看他的成长与蜕变,不仅能感受到他如何超越自己的局限,也能从中获得关于人性、社会和历史的重要启示。曾国藩的一生让人深思,正是…...

JavaWeb学习-SpringBotWeb开发入门(HTTP协议)

(一)SpringBotWeb开发步骤 (1)创建springboot工程,并勾选开发相关依赖 (2)定义HelloController类,添加方法hello,并添加注解 (3)运行测试 (二)HTTP入门概述 创建请求页面 package com.itheima.demo3; /*请求处理类,加上注解标识为请求处理类*/import org.spr…...

数据库用户管理

数据库用户管理 1.创建用户 MySQL在安装是,会默认创建一个名位root的用户,该用户拥有超级权限,可以控制整个MySQL服务器。 在对MySQL的日常管理和操作中,通常创建一些具有适当权限的用户,尽可能的不用或少用root登录…...

BGP边界网关协议(Border Gateway Protocol)路由聚合详解

一、路由聚合 1、意义 在大规模的网络中,BGP路由表十分庞大,给设备造成了很大的负担,同时使发生路由振荡的几率也大大增加,影响网络的稳定性。 路由聚合是将多条路由合并的机制,它通过只向对等体发送聚合后的路由而…...

ASP.NET Core WebAPI的异步及返回值

目录 Action方法的异步 Action方法参数 捕捉URL占位符 捕捉QueryString的值 JSON报文体 其他方式 Action方法的异步 Action方法既可以同步也可以异步。异步Action方法的名字一般不需要以Async结尾。Web API中Action方法的返回值如果是普通数据类型,那么返回值…...

「 机器人 」仿生扑翼飞行器中的“被动旋转机制”概述

前言 在仿生扑翼飞行器的机翼设计中,模仿昆虫翼的被动旋转机制是一项关键技术。其核心思想在于:机翼旋转角度(攻角)并非完全通过主动伺服来控制,而是利用空气动力和惯性力的作用,自然地实现被动调节。以下对这种设计的背景、原理与优势进行详细说明。 1. 背景:昆虫的被动…...

「 机器人 」扑翼飞行器的数据驱动建模核心方法

前言 数据驱动建模可充分利用扑翼飞行器的已有运行数据,改进动力学模型与控制策略,并对未建模动态做出更精确的预测。在复杂的非线性飞行环境中,该方法能有效弥补传统解析建模的不足,具有较高的研究与应用价值。以下针对主要研究方向和实现步骤进行整理与阐述。 1. 数据驱动…...

个人网站搭建

搭建 LNMP环境搭建: LNMP环境指:Linux Nginx MySQL/MariaDB PHP,在debian上安装整体需要300MB的磁盘空间。MariaDB 是 MySQL 的一个分支,由 MySQL 的原开发者维护,通常在性能和优化上有所改进。由于其轻量化和与M…...

飞书项目流程入门指导手册

飞书项目流程入门指导手册 参考资料准备工作新建空间国际化配置新建工作项字段管理新建字段对接标识授权角色 流程管理基础说明流程节点配置流程节点的布局配置页面上布局按钮布局配置 流程节点驳回流程图展示自动化字段修改 局限性 参考资料 飞书官方参考文档:飞书…...

xss靶场

xss-labs下载地址&#xff1a;GitHub - do0dl3/xss-labs: xss 跨站漏洞平台 xss常见触发标签&#xff1a;XSS跨站脚本攻击实例与防御策略-CSDN博客 level-1 首先查看网页的源代码发现get传参的name的值test插入了html里头&#xff0c;还回显了payload的长度。 <!DOCTYPE …...

XML实体注入漏洞攻与防

JAVA中的XXE攻防 回显型 无回显型 cve-2014-3574...

switch组件的功能与用法

文章目录 1 概念介绍2 使用方法3 示例代码 我们在上一章回中介绍了PageView这个Widget,本章回中将介绍Switch Widget.闲话休提&#xff0c;让我们一起Talk Flutter吧。 1 概念介绍 我们在这里介绍的Switch是指左右滑动的开关&#xff0c;常用来表示某项设置是打开还是关闭。Fl…...

cursor重构谷粒商城05——docker容器化技术快速入门【番外篇】

前言&#xff1a;这个系列将使用最前沿的cursor作为辅助编程工具&#xff0c;来快速开发一些基础的编程项目。目的是为了在真实项目中&#xff0c;帮助初级程序员快速进阶&#xff0c;以最快的速度&#xff0c;效率&#xff0c;快速进阶到中高阶程序员。 本项目将基于谷粒商城…...

高等数学学习笔记 ☞ 微分方程

1. 微分方程的基本概念 1. 微分方程的基本概念&#xff1a; &#xff08;1&#xff09;微分方程&#xff1a;含有未知函数及其导数或微分的方程。 举例说明微分方程&#xff1a;&#xff1b;。 &#xff08;2&#xff09;微分方程的阶&#xff1a;指微分方程中未知函数的导数…...

【探索 Kali Linux】渗透测试与网络安全的终极操作系统

探索 Kali Linux&#xff1a;渗透测试与网络安全的终极操作系统 在网络安全领域&#xff0c;Kali Linux 无疑是最受欢迎的操作系统之一。无论是专业的渗透测试人员、安全研究人员&#xff0c;还是对网络安全感兴趣的初学者&#xff0c;Kali Linux 都提供了强大的工具和灵活的环…...

以下是对华为 HarmonyOS NETX 5属性动画(ArkTS)文档的结构化整理,通过层级标题、表格和代码块提升可读性:

一、属性动画概述NETX 作用&#xff1a;实现组件通用属性的渐变过渡效果&#xff0c;提升用户体验。支持属性&#xff1a;width、height、backgroundColor、opacity、scale、rotate、translate等。注意事项&#xff1a; 布局类属性&#xff08;如宽高&#xff09;变化时&#…...

家政维修平台实战20:权限设计

目录 1 获取工人信息2 搭建工人入口3 权限判断总结 目前我们已经搭建好了基础的用户体系&#xff0c;主要是分成几个表&#xff0c;用户表我们是记录用户的基础信息&#xff0c;包括手机、昵称、头像。而工人和员工各有各的表。那么就有一个问题&#xff0c;不同的角色&#xf…...

oracle与MySQL数据库之间数据同步的技术要点

Oracle与MySQL数据库之间的数据同步是一个涉及多个技术要点的复杂任务。由于Oracle和MySQL的架构差异&#xff0c;它们的数据同步要求既要保持数据的准确性和一致性&#xff0c;又要处理好性能问题。以下是一些主要的技术要点&#xff1a; 数据结构差异 数据类型差异&#xff…...

华为云Flexus+DeepSeek征文|DeepSeek-V3/R1 商用服务开通全流程与本地部署搭建

华为云FlexusDeepSeek征文&#xff5c;DeepSeek-V3/R1 商用服务开通全流程与本地部署搭建 前言 如今大模型其性能出色&#xff0c;华为云 ModelArts Studio_MaaS大模型即服务平台华为云内置了大模型&#xff0c;能助力我们轻松驾驭 DeepSeek-V3/R1&#xff0c;本文中将分享如何…...

JS设计模式(4):观察者模式

JS设计模式(4):观察者模式 一、引入 在开发中&#xff0c;我们经常会遇到这样的场景&#xff1a;一个对象的状态变化需要自动通知其他对象&#xff0c;比如&#xff1a; 电商平台中&#xff0c;商品库存变化时需要通知所有订阅该商品的用户&#xff1b;新闻网站中&#xff0…...

SQL慢可能是触发了ring buffer

简介 最近在进行 postgresql 性能排查的时候,发现 PG 在某一个时间并行执行的 SQL 变得特别慢。最后通过监控监观察到并行发起得时间 buffers_alloc 就急速上升,且低水位伴随在整个慢 SQL,一直是 buferIO 的等待事件,此时也没有其他会话的争抢。SQL 虽然不是高效 SQL ,但…...

NPOI操作EXCEL文件 ——CAD C# 二次开发

缺点:dll.版本容易加载错误。CAD加载插件时&#xff0c;没有加载所有类库。插件运行过程中用到某个类库&#xff0c;会从CAD的安装目录找&#xff0c;找不到就报错了。 【方案2】让CAD在加载过程中把类库加载到内存 【方案3】是发现缺少了哪个库&#xff0c;就用插件程序加载进…...

为什么要创建 Vue 实例

核心原因:Vue 需要一个「控制中心」来驱动整个应用 你可以把 Vue 实例想象成你应用的**「大脑」或「引擎」。它负责协调模板、数据、逻辑和行为,将它们变成一个活的、可交互的应用**。没有这个实例,你的代码只是一堆静态的 HTML、JavaScript 变量和函数,无法「活」起来。 …...

MFE(微前端) Module Federation:Webpack.config.js文件中每个属性的含义解释

以Module Federation 插件详为例&#xff0c;Webpack.config.js它可能的配置和含义如下&#xff1a; 前言 Module Federation 的Webpack.config.js核心配置包括&#xff1a; name filename&#xff08;定义应用标识&#xff09; remotes&#xff08;引用远程模块&#xff0…...

WEB3全栈开发——面试专业技能点P7前端与链上集成

一、Next.js技术栈 ✅ 概念介绍 Next.js 是一个基于 React 的 服务端渲染&#xff08;SSR&#xff09;与静态网站生成&#xff08;SSG&#xff09; 框架&#xff0c;由 Vercel 开发。它简化了构建生产级 React 应用的过程&#xff0c;并内置了很多特性&#xff1a; ✅ 文件系…...