当前位置: 首页 > news >正文

反向传播(backward propagation,BP) python实现

BP算法就是反向传播,要输入的数据经过一个前向传播会得到一个输出,但是由于权重的原因,所以其输出会和你想要的输出有差距,这个时候就需要进行反向传播,利用梯度下降,对所有的权重进行更新,这样的话在进行前向传播就会发现其输出和你想要的输出越来越接近了。

# 
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt# 生成权重以及偏执项layers_dim代表每层的神经元个数,
#比如[2,3,1]代表一个三成的网络,输入为2层,中间为3层输出为1层
def init_parameters(layers_dim):L = len(layers_dim)parameters ={}for i in range(1,L):parameters["w"+str(i)] = np.random.random([layers_dim[i],layers_dim[i-1]])parameters["b"+str(i)] = np.zeros((layers_dim[i],1))return parametersdef sigmoid(z):return 1.0/(1.0+np.exp(-z))# sigmoid的导函数
def sigmoid_prime(z):return sigmoid(z) * (1-sigmoid(z))# 前向传播,需要用到一个输入x以及所有的权重以及偏执项,都在parameters这个字典里面存储
# 最后返回会返回一个caches里面包含的 是各层的a和z,a[layers]就是最终的输出
def forward(x,parameters):a = []z = []caches = {}a.append(x)z.append(x)layers = len(parameters)//2# 前面都要用sigmoidfor i in range(1,layers):z_temp =parameters["w"+str(i)].dot(x) + parameters["b"+str(i)]z.append(z_temp)a.append(sigmoid(z_temp))# 最后一层不用sigmoidz_temp = parameters["w"+str(layers)].dot(a[layers-1]) + parameters["b"+str(layers)]z.append(z_temp)a.append(z_temp)caches["z"] = zcaches["a"] = a    return  caches,a[layers]# 反向传播,parameters里面存储的是所有的各层的权重以及偏执,caches里面存储各层的a和z
# al是经过反向传播后最后一层的输出,y代表真实值
# 返回的grades代表着误差对所有的w以及b的导数
def backward(parameters,caches,al,y):layers = len(parameters)//2grades = {}m = y.shape[1]# 假设最后一层不经历激活函数# 就是按照上面的图片中的公式写的grades["dz"+str(layers)] = al - ygrades["dw"+str(layers)] = grades["dz"+str(layers)].dot(caches["a"][layers-1].T) /mgrades["db"+str(layers)] = np.sum(grades["dz"+str(layers)],axis = 1,keepdims = True) /m# 前面全部都是sigmoid激活for i in reversed(range(1,layers)):grades["dz"+str(i)] = parameters["w"+str(i+1)].T.dot(grades["dz"+str(i+1)]) * sigmoid_prime(caches["z"][i])grades["dw"+str(i)] = grades["dz"+str(i)].dot(caches["a"][i-1].T)/mgrades["db"+str(i)] = np.sum(grades["dz"+str(i)],axis = 1,keepdims = True) /mreturn grades   # 就是把其所有的权重以及偏执都更新一下
def update_grades(parameters,grades,learning_rate):layers = len(parameters)//2for i in range(1,layers+1):parameters["w"+str(i)] -= learning_rate * grades["dw"+str(i)]parameters["b"+str(i)] -= learning_rate * grades["db"+str(i)]return parameters
# 计算误差值
def compute_loss(al,y):return np.mean(np.square(al-y))# 加载数据
def load_data():"""加载数据集"""x = np.arange(0.0,1.0,0.01)y =20* np.sin(2*np.pi*x)# 数据可视化plt.scatter(x,y)return x,y
#进行测试
x,y = load_data()
x = x.reshape(1,100)
y = y.reshape(1,100)
plt.scatter(x,y)
parameters = init_parameters([1,25,1])
al = 0
for i in range(4000):caches,al = forward(x, parameters)grades = backward(parameters, caches, al, y)parameters = update_grades(parameters, grades, learning_rate= 0.3)if i %100 ==0:print(compute_loss(al, y))
plt.scatter(x,al)
plt.show()

运行结果:

在这里插入图片描述

相关文章:

反向传播(backward propagation,BP) python实现

BP算法就是反向传播,要输入的数据经过一个前向传播会得到一个输出,但是由于权重的原因,所以其输出会和你想要的输出有差距,这个时候就需要进行反向传播,利用梯度下降,对所有的权重进行更新,这样…...

简单算命脚本

效果展示 文件内容 main.py文件 import json import random import time# 别挂配置数据 gua_data_path "data.json"# 别卦数据 gua_data_map {} fake_delay 10# 读取别卦数据 def init_gua_data(json_path):with open(gua_data_path, r, encodingutf8) as fp:gl…...

Lua-掌握Lua语言基础1

Lua是一种轻量级的脚本语言,广泛应用于游戏开发、嵌入式系统和其他领域。下面是Lua语言基础的介绍: 数据类型:Lua支持基本的数据类型,包括nil、boolean、number、string和table。其中,table是一种关联数组,…...

python-0003-pycharm开发虚拟环境中的项目

前言 在虚拟环境中创建好了python项目,使用pycharm进行开发 打开项目 使用pycharm打开项目 设置虚拟环境的解释器 File–>Settings–>Project(项目名)–>Python Interpreter–>添加解释器–>添加已经存在的解释器–>选择虚拟环境的解释器 …...

修改 MySQL update_time 默认值的坑

由于按规范需要对 update_time 字段需要对它做默认值的设置 现在有一个原始的表是这样的 CREATE TABLE test_up (id bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT 主键id,update_time datetime default null COMMENT 操作时间,PRIMARY KEY (id) ) ENGINEInnoDB DEF…...

基于亚马逊云EC2+Docker搭建nextcloud私有化云盘

亚马逊云科技EC2云服务器(Elastic Compute Cloud)是亚马逊云科技AWS(Amazon Web Services)提供的一种云计算服务。EC2代表弹性计算云,它允许用户租用虚拟计算资源,包括CPU、内存、存储和网络带宽&#xff0…...

用try...catch进行判断

在写一些提交数据的判断上,有时候会写下面的ifelse的判断方法,少一点还好,多的话就很难受也不好看。 if(!that.driverObj.contrary){this.__utils.showToast(请先上传驾驶证副页图片);return false } if(!this.driverObj.start){this.__util…...

服务器遭遇挖矿病毒syst3md及其伪装者rcu-sched:原因、症状与解决方案

01 什么是挖矿病毒 挖矿病毒通常是恶意软件的一种,它会在受感染的系统上无授权地挖掘加密货币。关于"syst3md",是一种特定的挖矿病毒,它通过在受感染的Linux系统中执行一系列复杂操作来达到其目的。这些操作包括使用curl从网络下载…...

此机非彼机,工业计算机在工业行业的特殊地位

电子计算机,称为电脑。计算机是一种利用数字电子技术,根据一系列指令指示其自动执行任意算术或逻辑操作串行的设备。通用计算机因有能遵循被称为“程序”的一般操作集的能力而使得它们能够执行极其广泛的任务,以管理各种工厂和机器自动化工业…...

Python使用lxml解析XML格式化数据

Python使用lxml解析XML格式化数据 1. 效果图2. 源代码参考 方法一:无脑读取文件,遇到有关键词的行再去解析获取值 方法二:利用lxml等库,解析格式化数据,批量获取标签及其值 这篇博客介绍第2种办法,以菜鸟教…...

CDA-LevelⅡ【考题整理-带答案】

关于相关分析中应注意的问题,下面说法错误的是:B 如果两变量间的相关系数为0,则说明二者独立 。解释:只能说明两者不存在线性相关关系现通过参数估计得到一个一元线性回归模型为y3x4,在回归系数检验中下列说法错误的是…...

20240304 json可以包含复杂数组(数组里面套数组)

欣赏一下我的思维,它会以漫画,表格,文字。。。各种各样的形式呈现 对于问题1问题2 JSON (JavaScript Object Notation) 是一种轻量级的数据交换格式,易于人阅读和编写,同时也易于机器解析和生成。JSON本质上是一种文本…...

算法50:动态规划专练(力扣514题:自由之路-----4种写法)

题目: 力扣514 : 自由之路 . - 力扣(LeetCode) 题目的详细描述,直接打开力扣看就是了,下面说一下我对题目的理解: 事例1: 输入: ring "godding", key "gd" 输出: 4. 1. ring的第…...

重学SpringBoot3-集成Thymeleaf

更多SpringBoot3内容请关注我的专栏:《SpringBoot3》 期待您的点赞👍收藏⭐评论✍ 重学SpringBoot3-集成Thymeleaf 1. 添加Thymeleaf依赖2. 配置Thymeleaf属性(可选)3. 创建Thymeleaf模板4. 创建一个Controller5. 运行应用并访问页…...

【数据可视化】Echarts最常用图表

个人主页 : zxctscl 如有转载请先通知 文章目录 1. 前言2. 准备工作3. 柱状图3.1 绘制堆积柱状图3.2 绘制标准条形图3.3 绘制瀑布图 4. 折线图4.1 绘制堆积面积图和堆积折线图4.2 绘制阶梯图 5. 饼图5.1 绘制标准饼图5.2 绘制圆环图5.2 绘制嵌套饼图5.3 绘制南丁格尔…...

flink:通过table api把文件中读取的数据写入MySQL

当写入数据到外部数据库时,Flink 会使用 DDL 中定义的主键。如果定义了主键,则连接器将以 upsert 模式工作,否则连接器将以 append 模式工作 package cn.edu.tju.demo2;import org.apache.flink.streaming.api.environment.StreamExecutionE…...

【Java 多线程 哈希表】 HashTable, HashMap, ConcurrentHashMap 之间的区别

HashTable、HashMap和ConcurrentHashMap都是Java中用于存储键值对的集合框架的一部分,但它们之间存在一些重要的联系和区别。 联系 键值对存储:它们都用于存储键值对,并允许你根据键来检索值。基于哈希:它们内部都使用了哈希表来…...

有趣之matlab-烟花

待整合1 2 3 动态 有趣编程之11 静态 逼真 3 .m文件路径下放back1.jpg back4.jpg…背景照片 点击screen 就会有小白点升起,爆炸 function yanhuamoban()clear all;%定义全局变量global ah ;%坐标轴句柄global styleNum ;%爆炸图案样式global multiColor; %多颜色变换…...

C语言指针与数组(不适合初学者版):一篇文章带你深入了解指针与数组!

🎈个人主页:JAMES别扣了 💕在校大学生一枚。对IT有着极其浓厚的兴趣 ✨系列专栏目前为C语言初阶、后续会更新c语言的学习方法以及c题目分享. 😍希望我的文章对大家有着不一样的帮助,欢迎大家关注我,我也会回…...

springboot Mongo大数据查询优化方案

前言 因为项目需要把传感器的数据保存起来,当时设计的时是mongo来存储,后期需要从mongo DB里查询传感器的数据记录。由于传感器每秒都会像mongo数据库存500条左右的数据,1天就有4320万条数据,要想按照时间条件去查询,…...

React hook之useRef

React useRef 详解 useRef 是 React 提供的一个 Hook,用于在函数组件中创建可变的引用对象。它在 React 开发中有多种重要用途,下面我将全面详细地介绍它的特性和用法。 基本概念 1. 创建 ref const refContainer useRef(initialValue);initialValu…...

Java 8 Stream API 入门到实践详解

一、告别 for 循环&#xff01; 传统痛点&#xff1a; Java 8 之前&#xff0c;集合操作离不开冗长的 for 循环和匿名类。例如&#xff0c;过滤列表中的偶数&#xff1a; List<Integer> list Arrays.asList(1, 2, 3, 4, 5); List<Integer> evens new ArrayList…...

MFC内存泄露

1、泄露代码示例 void X::SetApplicationBtn() {CMFCRibbonApplicationButton* pBtn GetApplicationButton();// 获取 Ribbon Bar 指针// 创建自定义按钮CCustomRibbonAppButton* pCustomButton new CCustomRibbonAppButton();pCustomButton->SetImage(IDB_BITMAP_Jdp26)…...

Docker 运行 Kafka 带 SASL 认证教程

Docker 运行 Kafka 带 SASL 认证教程 Docker 运行 Kafka 带 SASL 认证教程一、说明二、环境准备三、编写 Docker Compose 和 jaas文件docker-compose.yml代码说明&#xff1a;server_jaas.conf 四、启动服务五、验证服务六、连接kafka服务七、总结 Docker 运行 Kafka 带 SASL 认…...

令牌桶 滑动窗口->限流 分布式信号量->限并发的原理 lua脚本分析介绍

文章目录 前言限流限制并发的实际理解限流令牌桶代码实现结果分析令牌桶lua的模拟实现原理总结&#xff1a; 滑动窗口代码实现结果分析lua脚本原理解析 限并发分布式信号量代码实现结果分析lua脚本实现原理 双注解去实现限流 并发结果分析&#xff1a; 实际业务去理解体会统一注…...

Mac下Android Studio扫描根目录卡死问题记录

环境信息 操作系统: macOS 15.5 (Apple M2芯片)Android Studio版本: Meerkat Feature Drop | 2024.3.2 Patch 1 (Build #AI-243.26053.27.2432.13536105, 2025年5月22日构建) 问题现象 在项目开发过程中&#xff0c;提示一个依赖外部头文件的cpp源文件需要同步&#xff0c;点…...

HarmonyOS运动开发:如何用mpchart绘制运动配速图表

##鸿蒙核心技术##运动开发##Sensor Service Kit&#xff08;传感器服务&#xff09;# 前言 在运动类应用中&#xff0c;运动数据的可视化是提升用户体验的重要环节。通过直观的图表展示运动过程中的关键数据&#xff0c;如配速、距离、卡路里消耗等&#xff0c;用户可以更清晰…...

Yolov8 目标检测蒸馏学习记录

yolov8系列模型蒸馏基本流程&#xff0c;代码下载&#xff1a;这里本人提交了一个demo:djdll/Yolov8_Distillation: Yolov8轻量化_蒸馏代码实现 在轻量化模型设计中&#xff0c;**知识蒸馏&#xff08;Knowledge Distillation&#xff09;**被广泛应用&#xff0c;作为提升模型…...

JS设计模式(4):观察者模式

JS设计模式(4):观察者模式 一、引入 在开发中&#xff0c;我们经常会遇到这样的场景&#xff1a;一个对象的状态变化需要自动通知其他对象&#xff0c;比如&#xff1a; 电商平台中&#xff0c;商品库存变化时需要通知所有订阅该商品的用户&#xff1b;新闻网站中&#xff0…...

Python Ovito统计金刚石结构数量

大家好,我是小马老师。 本文介绍python ovito方法统计金刚石结构的方法。 Ovito Identify diamond structure命令可以识别和统计金刚石结构,但是无法直接输出结构的变化情况。 本文使用python调用ovito包的方法,可以持续统计各步的金刚石结构,具体代码如下: from ovito…...