python3+TensorFlow 2.x(四)反向传播
目录
反向传播算法
反向传播算法基本步骤:
反向中的参数变化
总结
反向传播算法
反向传播算法(Backpropagation)是训练人工神经网络时使用的一个重要算法,它是通过计算梯度并优化神经网络的权重来最小化误差。反向传播算法的核心是基于链式法则的梯度下降优化方法,通过计算误差对每个权重的偏导数来更新网络中的参数。
反向传播算法基本步骤:
前向传播:将输入数据传递通过神经网络的各层,计算每一层的输出。
计算损失:根据输出和实际标签计算损失(通常使用均方误差或交叉熵等作为损失函数)。
反向传播:根据损失函数对每个参数(如权重、偏置)计算梯度。梯度的计算通过链式法则进行反向传播,直到达到输入层。
更新权重:使用梯度下降算法来更新每一层的权重和偏置,使得损失函数最小化。
链式推到:https://blog.csdn.net/dingyahui123/category_6945552.html?spm=1001.2014.3001.5482
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import mnist
# 加载 MNIST 数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()# 归一化数据并将其形状调整为 (N, 784),因为每张图片是 28x28 像素
train_images = train_images.reshape(-1, 28*28) / 255.0
test_images = test_images.reshape(-1, 28*28) / 255.0# 转换标签为 one-hot 编码
train_labels = np.eye(10)[train_labels]
test_labels = np.eye(10)[test_labels]
# 定义激活函数
def sigmoid(x):return 1 / (1 + np.exp(-x))# 定义激活函数的导数
def sigmoid_derivative(x):return x * (1 - x)# 网络架构参数
input_size = 28 * 28 # 输入层的大小
hidden_size = 128 # 隐藏层的大小
output_size = 10 # 输出层的大小# 初始化权重和偏置
W1 = np.random.randn(input_size, hidden_size) # 输入层到隐藏层的权重
b1 = np.zeros((1, hidden_size)) # 隐藏层的偏置
W2 = np.random.randn(hidden_size, output_size) # 隐藏层到输出层的权重
b2 = np.zeros((1, output_size)) # 输出层的偏置
# 设置超参数
epochs = 20
learning_rate = 0.1
batch_size = 64# 训练过程
for epoch in range(epochs):for i in range(0, len(train_images), batch_size):# 选择当前batch的数据X_batch = train_images[i:i+batch_size]y_batch = train_labels[i:i+batch_size]# 前向传播z1 = np.dot(X_batch, W1) + b1a1 = sigmoid(z1)z2 = np.dot(a1, W2) + b2a2 = sigmoid(z2)# 计算损失的梯度output_error = a2 - y_batch # 损失函数的梯度output_delta = output_error * sigmoid_derivative(a2)hidden_error = output_delta.dot(W2.T)hidden_delta = hidden_error * sigmoid_derivative(a1)# 更新权重和偏置W2 -= learning_rate * a1.T.dot(output_delta)b2 -= learning_rate * np.sum(output_delta, axis=0, keepdims=True)W1 -= learning_rate * X_batch.T.dot(hidden_delta)b1 -= learning_rate * np.sum(hidden_delta, axis=0, keepdims=True)# 每10轮输出一次损失if epoch % 10 == 0:loss = np.mean(np.square(a2 - y_batch))print(f"Epoch {epoch}, Loss: {loss}")
# 测试模型
z1 = np.dot(test_images, W1) + b1
a1 = sigmoid(z1)
z2 = np.dot(a1, W2) + b2
a2 = sigmoid(z2)# 计算准确率
predictions = np.argmax(a2, axis=1)
true_labels = np.argmax(test_labels, axis=1)
accuracy = np.mean(predictions == true_labels)print(f"Test Accuracy: {accuracy * 100:.2f}%")
# 可视化前5个测试图像及其预测结果
for i in range(5):plt.imshow(test_images[i].reshape(28, 28), cmap='gray')plt.title(f"Predicted: {predictions[i]}, Actual: {true_labels[i]}")plt.show()
反向中的参数变化
import numpy as np
import matplotlib.pyplot as plt
import imageio# 激活函数和其导数
def sigmoid(x):return 1 / (1 + np.exp(-x))def sigmoid_derivative(x):return x * (1 - x)# 生成一些示例数据
np.random.seed(0)
X = np.array([[0, 0],[0, 1],[1, 0],[1, 1]])
y = np.array([[0], [1], [1], [0]]) # XOR 问题# 初始化参数
input_layer_neurons = 2
hidden_layer_neurons = 2
output_neurons = 1
learning_rate = 0.5
epochs = 10000# 初始化权重
weights_input_hidden = np.random.uniform(size=(input_layer_neurons, hidden_layer_neurons))
weights_hidden_output = np.random.uniform(size=(hidden_layer_neurons, output_neurons))# 存储权重和图像
weights_history = []
losses = []
images = []# 训练过程
for epoch in range(epochs):# 前向传播hidden_layer_input = np.dot(X, weights_input_hidden)hidden_layer_output = sigmoid(hidden_layer_input)output_layer_input = np.dot(hidden_layer_output, weights_hidden_output)predicted_output = sigmoid(output_layer_input)loss = np.mean((y - predicted_output) ** 2)losses.append(loss)# 反向传播error = y - predicted_outputd_predicted_output = error * sigmoid_derivative(predicted_output)error_hidden_layer = d_predicted_output.dot(weights_hidden_output.T)d_hidden_layer = error_hidden_layer * sigmoid_derivative(hidden_layer_output)# 更新权重weights_hidden_output += hidden_layer_output.T.dot(d_predicted_output) * learning_rateweights_input_hidden += X.T.dot(d_hidden_layer) * learning_rate# 保存权重weights_history.append((weights_input_hidden.copy(), weights_hidden_output.copy()))# 每1000次迭代保存一次图像if epoch % 1000 == 0:plt.figure(figsize=(8, 6))plt.subplot(1, 2, 1)plt.title('Weights Input-Hidden')plt.imshow(weights_input_hidden, cmap='viridis', aspect='auto')plt.colorbar()plt.subplot(1, 2, 2)plt.title('Weights Hidden-Output')plt.imshow(weights_hidden_output, cmap='viridis', aspect='auto')plt.colorbar()# 保存图像plt.savefig(f'weights_epoch_{epoch}.png')plt.close()if epoch % 1000 == 0:plt.figure(figsize=(8, 6))plt.plot(losses, label='Loss')plt.title('Loss over epochs')plt.xlabel('Epochs')plt.ylabel('Loss')plt.xlim(0, epochs)plt.ylim(0, np.max(losses))plt.grid()plt.legend()# 保存图像plt.savefig(f'loss_epoch_{epoch}.png')plt.close()
# 创建 GIF
with imageio.get_writer('weights_update.gif', mode='I', duration=0.5) as writer:for epoch in range(0, epochs, 1000):image = imageio.imread(f'weights_epoch_{epoch}.png')writer.append_data(image)
# 创建 GIF
with imageio.get_writer('training_loss.gif', mode='I', duration=0.5) as writer:for epoch in range(0, epochs, 1000):image = imageio.imread(f'loss_epoch_{epoch}.png')writer.append_data(image)
# 清理生成的图像文件
import os
for epoch in range(0, epochs, 1000):os.remove(f'weights_epoch_{epoch}.png')os.remove(f'loss_epoch_{epoch}.png')print("GIF 已生成:training_loss.gif")
print("GIF 已生成:weights_update.gif")
总结
反向传播算法是神经网络训练中的核心技术,它通过计算损失函数相对于每个权重和偏置的梯度,利用梯度下降算法优化网络的参数。理解了反向传播的基本过程,可以进一步扩展到更复杂的网络结构,如卷积神经网络(CNN)和循环神经网络(RNN)。
相关文章:

python3+TensorFlow 2.x(四)反向传播
目录 反向传播算法 反向传播算法基本步骤: 反向中的参数变化 总结 反向传播算法 反向传播算法(Backpropagation)是训练人工神经网络时使用的一个重要算法,它是通过计算梯度并优化神经网络的权重来最小化误差。反向传播算法的核…...
Flutter 使用 flutter_inappwebview 加载 App 本地 HTML 文件
在 Flutter 开发中,加载本地 HTML 文件是一个常见的需求,尤其是在需要展示离线内容或自定义页面时。flutter_inappwebview 是一个功能强大的插件,支持加载本地文件和网络资源。本文将详细介绍如何使用 flutter_inappwebview 加载 App 本地 HT…...

Word常见问题:嵌入图片无法显示完整
场景:在Word中,嵌入式图片显示不全,一部分图片在文字下方。如: 问题原因:因段落行距导致 方法一 快捷方式 选中图片,通过"ctrl1"快捷调整为1倍行距 方法二 通过工具栏调整 选中图片࿰…...
为AI聊天工具添加一个知识系统 之68 详细设计 之9 三种中台和时间度量 之1
本文要点 要点 在维度0上 被分离出来 的业务中台 需求、技术中台要求、和数据中台请求 (分别在时间层/空间层/时空层上 对应一个不同种类槽的容器,分别表示业务特征Feature[3]/技术方面Aspect[3]/数据流Fluent[3]) 在维度1~3的运动过程中 从…...

On to OpenGL and 3D computer graphics
2. On to OpenGL and 3D computer graphics 声明:该代码来自:Computer Graphics Through OpenGL From Theory to Experiments,仅用作学习参考 2.1 First Program Square.cpp完整代码 /// // square.cpp // // OpenGL program to draw a squ…...
从曾国藩的经历看如何打破成长中的瓶颈
《曾国藩传》是一部充满智慧与人生哲理的传记,而曾国藩本人更是一个从“最笨”到“最智慧”的奇人。看他的成长与蜕变,不仅能感受到他如何超越自己的局限,也能从中获得关于人性、社会和历史的重要启示。曾国藩的一生让人深思,正是…...
JavaWeb学习-SpringBotWeb开发入门(HTTP协议)
(一)SpringBotWeb开发步骤 (1)创建springboot工程,并勾选开发相关依赖 (2)定义HelloController类,添加方法hello,并添加注解 (3)运行测试 (二)HTTP入门概述 创建请求页面 package com.itheima.demo3; /*请求处理类,加上注解标识为请求处理类*/import org.spr…...
数据库用户管理
数据库用户管理 1.创建用户 MySQL在安装是,会默认创建一个名位root的用户,该用户拥有超级权限,可以控制整个MySQL服务器。 在对MySQL的日常管理和操作中,通常创建一些具有适当权限的用户,尽可能的不用或少用root登录…...

BGP边界网关协议(Border Gateway Protocol)路由聚合详解
一、路由聚合 1、意义 在大规模的网络中,BGP路由表十分庞大,给设备造成了很大的负担,同时使发生路由振荡的几率也大大增加,影响网络的稳定性。 路由聚合是将多条路由合并的机制,它通过只向对等体发送聚合后的路由而…...
ASP.NET Core WebAPI的异步及返回值
目录 Action方法的异步 Action方法参数 捕捉URL占位符 捕捉QueryString的值 JSON报文体 其他方式 Action方法的异步 Action方法既可以同步也可以异步。异步Action方法的名字一般不需要以Async结尾。Web API中Action方法的返回值如果是普通数据类型,那么返回值…...

「 机器人 」仿生扑翼飞行器中的“被动旋转机制”概述
前言 在仿生扑翼飞行器的机翼设计中,模仿昆虫翼的被动旋转机制是一项关键技术。其核心思想在于:机翼旋转角度(攻角)并非完全通过主动伺服来控制,而是利用空气动力和惯性力的作用,自然地实现被动调节。以下对这种设计的背景、原理与优势进行详细说明。 1. 背景:昆虫的被动…...
「 机器人 」扑翼飞行器的数据驱动建模核心方法
前言 数据驱动建模可充分利用扑翼飞行器的已有运行数据,改进动力学模型与控制策略,并对未建模动态做出更精确的预测。在复杂的非线性飞行环境中,该方法能有效弥补传统解析建模的不足,具有较高的研究与应用价值。以下针对主要研究方向和实现步骤进行整理与阐述。 1. 数据驱动…...

个人网站搭建
搭建 LNMP环境搭建: LNMP环境指:Linux Nginx MySQL/MariaDB PHP,在debian上安装整体需要300MB的磁盘空间。MariaDB 是 MySQL 的一个分支,由 MySQL 的原开发者维护,通常在性能和优化上有所改进。由于其轻量化和与M…...

飞书项目流程入门指导手册
飞书项目流程入门指导手册 参考资料准备工作新建空间国际化配置新建工作项字段管理新建字段对接标识授权角色 流程管理基础说明流程节点配置流程节点的布局配置页面上布局按钮布局配置 流程节点驳回流程图展示自动化字段修改 局限性 参考资料 飞书官方参考文档:飞书…...

xss靶场
xss-labs下载地址:GitHub - do0dl3/xss-labs: xss 跨站漏洞平台 xss常见触发标签:XSS跨站脚本攻击实例与防御策略-CSDN博客 level-1 首先查看网页的源代码发现get传参的name的值test插入了html里头,还回显了payload的长度。 <!DOCTYPE …...

XML实体注入漏洞攻与防
JAVA中的XXE攻防 回显型 无回显型 cve-2014-3574...

switch组件的功能与用法
文章目录 1 概念介绍2 使用方法3 示例代码 我们在上一章回中介绍了PageView这个Widget,本章回中将介绍Switch Widget.闲话休提,让我们一起Talk Flutter吧。 1 概念介绍 我们在这里介绍的Switch是指左右滑动的开关,常用来表示某项设置是打开还是关闭。Fl…...

cursor重构谷粒商城05——docker容器化技术快速入门【番外篇】
前言:这个系列将使用最前沿的cursor作为辅助编程工具,来快速开发一些基础的编程项目。目的是为了在真实项目中,帮助初级程序员快速进阶,以最快的速度,效率,快速进阶到中高阶程序员。 本项目将基于谷粒商城…...
高等数学学习笔记 ☞ 微分方程
1. 微分方程的基本概念 1. 微分方程的基本概念: (1)微分方程:含有未知函数及其导数或微分的方程。 举例说明微分方程:;。 (2)微分方程的阶:指微分方程中未知函数的导数…...
【探索 Kali Linux】渗透测试与网络安全的终极操作系统
探索 Kali Linux:渗透测试与网络安全的终极操作系统 在网络安全领域,Kali Linux 无疑是最受欢迎的操作系统之一。无论是专业的渗透测试人员、安全研究人员,还是对网络安全感兴趣的初学者,Kali Linux 都提供了强大的工具和灵活的环…...

label-studio的使用教程(导入本地路径)
文章目录 1. 准备环境2. 脚本启动2.1 Windows2.2 Linux 3. 安装label-studio机器学习后端3.1 pip安装(推荐)3.2 GitHub仓库安装 4. 后端配置4.1 yolo环境4.2 引入后端模型4.3 修改脚本4.4 启动后端 5. 标注工程5.1 创建工程5.2 配置图片路径5.3 配置工程类型标签5.4 配置模型5.…...

Vue2 第一节_Vue2上手_插值表达式{{}}_访问数据和修改数据_Vue开发者工具
文章目录 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染2. 插值表达式{{}}3. 访问数据和修改数据4. vue响应式5. Vue开发者工具--方便调试 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染 准备容器引包创建Vue实例 new Vue()指定配置项 ->渲染数据 准备一个容器,例如: …...
高防服务器能够抵御哪些网络攻击呢?
高防服务器作为一种有着高度防御能力的服务器,可以帮助网站应对分布式拒绝服务攻击,有效识别和清理一些恶意的网络流量,为用户提供安全且稳定的网络环境,那么,高防服务器一般都可以抵御哪些网络攻击呢?下面…...
基于Java Swing的电子通讯录设计与实现:附系统托盘功能代码详解
JAVASQL电子通讯录带系统托盘 一、系统概述 本电子通讯录系统采用Java Swing开发桌面应用,结合SQLite数据库实现联系人管理功能,并集成系统托盘功能提升用户体验。系统支持联系人的增删改查、分组管理、搜索过滤等功能,同时可以最小化到系统…...

RSS 2025|从说明书学习复杂机器人操作任务:NUS邵林团队提出全新机器人装配技能学习框架Manual2Skill
视觉语言模型(Vision-Language Models, VLMs),为真实环境中的机器人操作任务提供了极具潜力的解决方案。 尽管 VLMs 取得了显著进展,机器人仍难以胜任复杂的长时程任务(如家具装配),主要受限于人…...

Proxmox Mail Gateway安装指南:从零开始配置高效邮件过滤系统
💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「storms…...
【Kafka】Kafka从入门到实战:构建高吞吐量分布式消息系统
Kafka从入门到实战:构建高吞吐量分布式消息系统 一、Kafka概述 Apache Kafka是一个分布式流处理平台,最初由LinkedIn开发,后成为Apache顶级项目。它被设计用于高吞吐量、低延迟的消息处理,能够处理来自多个生产者的海量数据,并将这些数据实时传递给消费者。 Kafka核心特…...

uni-app学习笔记三十五--扩展组件的安装和使用
由于内置组件不能满足日常开发需要,uniapp官方也提供了众多的扩展组件供我们使用。由于不是内置组件,需要安装才能使用。 一、安装扩展插件 安装方法: 1.访问uniapp官方文档组件部分:组件使用的入门教程 | uni-app官网 点击左侧…...

云原生安全实战:API网关Envoy的鉴权与限流详解
🔥「炎码工坊」技术弹药已装填! 点击关注 → 解锁工业级干货【工具实测|项目避坑|源码燃烧指南】 一、基础概念 1. API网关 作为微服务架构的统一入口,负责路由转发、安全控制、流量管理等核心功能。 2. Envoy 由Lyft开源的高性能云原生…...

篇章一 论坛系统——前置知识
目录 1.软件开发 1.1 软件的生命周期 1.2 面向对象 1.3 CS、BS架构 1.CS架构编辑 2.BS架构 1.4 软件需求 1.需求分类 2.需求获取 1.5 需求分析 1. 工作内容 1.6 面向对象分析 1.OOA的任务 2.统一建模语言UML 3. 用例模型 3.1 用例图的元素 3.2 建立用例模型 …...