反向传播(backward propagation,BP) python实现
BP算法就是反向传播,要输入的数据经过一个前向传播会得到一个输出,但是由于权重的原因,所以其输出会和你想要的输出有差距,这个时候就需要进行反向传播,利用梯度下降,对所有的权重进行更新,这样的话在进行前向传播就会发现其输出和你想要的输出越来越接近了。
#
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt# 生成权重以及偏执项layers_dim代表每层的神经元个数,
#比如[2,3,1]代表一个三成的网络,输入为2层,中间为3层输出为1层
def init_parameters(layers_dim):L = len(layers_dim)parameters ={}for i in range(1,L):parameters["w"+str(i)] = np.random.random([layers_dim[i],layers_dim[i-1]])parameters["b"+str(i)] = np.zeros((layers_dim[i],1))return parametersdef sigmoid(z):return 1.0/(1.0+np.exp(-z))# sigmoid的导函数
def sigmoid_prime(z):return sigmoid(z) * (1-sigmoid(z))# 前向传播,需要用到一个输入x以及所有的权重以及偏执项,都在parameters这个字典里面存储
# 最后返回会返回一个caches里面包含的 是各层的a和z,a[layers]就是最终的输出
def forward(x,parameters):a = []z = []caches = {}a.append(x)z.append(x)layers = len(parameters)//2# 前面都要用sigmoidfor i in range(1,layers):z_temp =parameters["w"+str(i)].dot(x) + parameters["b"+str(i)]z.append(z_temp)a.append(sigmoid(z_temp))# 最后一层不用sigmoidz_temp = parameters["w"+str(layers)].dot(a[layers-1]) + parameters["b"+str(layers)]z.append(z_temp)a.append(z_temp)caches["z"] = zcaches["a"] = a return caches,a[layers]# 反向传播,parameters里面存储的是所有的各层的权重以及偏执,caches里面存储各层的a和z
# al是经过反向传播后最后一层的输出,y代表真实值
# 返回的grades代表着误差对所有的w以及b的导数
def backward(parameters,caches,al,y):layers = len(parameters)//2grades = {}m = y.shape[1]# 假设最后一层不经历激活函数# 就是按照上面的图片中的公式写的grades["dz"+str(layers)] = al - ygrades["dw"+str(layers)] = grades["dz"+str(layers)].dot(caches["a"][layers-1].T) /mgrades["db"+str(layers)] = np.sum(grades["dz"+str(layers)],axis = 1,keepdims = True) /m# 前面全部都是sigmoid激活for i in reversed(range(1,layers)):grades["dz"+str(i)] = parameters["w"+str(i+1)].T.dot(grades["dz"+str(i+1)]) * sigmoid_prime(caches["z"][i])grades["dw"+str(i)] = grades["dz"+str(i)].dot(caches["a"][i-1].T)/mgrades["db"+str(i)] = np.sum(grades["dz"+str(i)],axis = 1,keepdims = True) /mreturn grades # 就是把其所有的权重以及偏执都更新一下
def update_grades(parameters,grades,learning_rate):layers = len(parameters)//2for i in range(1,layers+1):parameters["w"+str(i)] -= learning_rate * grades["dw"+str(i)]parameters["b"+str(i)] -= learning_rate * grades["db"+str(i)]return parameters
# 计算误差值
def compute_loss(al,y):return np.mean(np.square(al-y))# 加载数据
def load_data():"""加载数据集"""x = np.arange(0.0,1.0,0.01)y =20* np.sin(2*np.pi*x)# 数据可视化plt.scatter(x,y)return x,y
#进行测试
x,y = load_data()
x = x.reshape(1,100)
y = y.reshape(1,100)
plt.scatter(x,y)
parameters = init_parameters([1,25,1])
al = 0
for i in range(4000):caches,al = forward(x, parameters)grades = backward(parameters, caches, al, y)parameters = update_grades(parameters, grades, learning_rate= 0.3)if i %100 ==0:print(compute_loss(al, y))
plt.scatter(x,al)
plt.show()
运行结果:

相关文章:
反向传播(backward propagation,BP) python实现
BP算法就是反向传播,要输入的数据经过一个前向传播会得到一个输出,但是由于权重的原因,所以其输出会和你想要的输出有差距,这个时候就需要进行反向传播,利用梯度下降,对所有的权重进行更新,这样…...
简单算命脚本
效果展示 文件内容 main.py文件 import json import random import time# 别挂配置数据 gua_data_path "data.json"# 别卦数据 gua_data_map {} fake_delay 10# 读取别卦数据 def init_gua_data(json_path):with open(gua_data_path, r, encodingutf8) as fp:gl…...
Lua-掌握Lua语言基础1
Lua是一种轻量级的脚本语言,广泛应用于游戏开发、嵌入式系统和其他领域。下面是Lua语言基础的介绍: 数据类型:Lua支持基本的数据类型,包括nil、boolean、number、string和table。其中,table是一种关联数组,…...
python-0003-pycharm开发虚拟环境中的项目
前言 在虚拟环境中创建好了python项目,使用pycharm进行开发 打开项目 使用pycharm打开项目 设置虚拟环境的解释器 File–>Settings–>Project(项目名)–>Python Interpreter–>添加解释器–>添加已经存在的解释器–>选择虚拟环境的解释器 …...
修改 MySQL update_time 默认值的坑
由于按规范需要对 update_time 字段需要对它做默认值的设置 现在有一个原始的表是这样的 CREATE TABLE test_up (id bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT 主键id,update_time datetime default null COMMENT 操作时间,PRIMARY KEY (id) ) ENGINEInnoDB DEF…...
基于亚马逊云EC2+Docker搭建nextcloud私有化云盘
亚马逊云科技EC2云服务器(Elastic Compute Cloud)是亚马逊云科技AWS(Amazon Web Services)提供的一种云计算服务。EC2代表弹性计算云,它允许用户租用虚拟计算资源,包括CPU、内存、存储和网络带宽࿰…...
用try...catch进行判断
在写一些提交数据的判断上,有时候会写下面的ifelse的判断方法,少一点还好,多的话就很难受也不好看。 if(!that.driverObj.contrary){this.__utils.showToast(请先上传驾驶证副页图片);return false } if(!this.driverObj.start){this.__util…...
服务器遭遇挖矿病毒syst3md及其伪装者rcu-sched:原因、症状与解决方案
01 什么是挖矿病毒 挖矿病毒通常是恶意软件的一种,它会在受感染的系统上无授权地挖掘加密货币。关于"syst3md",是一种特定的挖矿病毒,它通过在受感染的Linux系统中执行一系列复杂操作来达到其目的。这些操作包括使用curl从网络下载…...
此机非彼机,工业计算机在工业行业的特殊地位
电子计算机,称为电脑。计算机是一种利用数字电子技术,根据一系列指令指示其自动执行任意算术或逻辑操作串行的设备。通用计算机因有能遵循被称为“程序”的一般操作集的能力而使得它们能够执行极其广泛的任务,以管理各种工厂和机器自动化工业…...
Python使用lxml解析XML格式化数据
Python使用lxml解析XML格式化数据 1. 效果图2. 源代码参考 方法一:无脑读取文件,遇到有关键词的行再去解析获取值 方法二:利用lxml等库,解析格式化数据,批量获取标签及其值 这篇博客介绍第2种办法,以菜鸟教…...
CDA-LevelⅡ【考题整理-带答案】
关于相关分析中应注意的问题,下面说法错误的是:B 如果两变量间的相关系数为0,则说明二者独立 。解释:只能说明两者不存在线性相关关系现通过参数估计得到一个一元线性回归模型为y3x4,在回归系数检验中下列说法错误的是…...
20240304 json可以包含复杂数组(数组里面套数组)
欣赏一下我的思维,它会以漫画,表格,文字。。。各种各样的形式呈现 对于问题1问题2 JSON (JavaScript Object Notation) 是一种轻量级的数据交换格式,易于人阅读和编写,同时也易于机器解析和生成。JSON本质上是一种文本…...
算法50:动态规划专练(力扣514题:自由之路-----4种写法)
题目: 力扣514 : 自由之路 . - 力扣(LeetCode) 题目的详细描述,直接打开力扣看就是了,下面说一下我对题目的理解: 事例1: 输入: ring "godding", key "gd" 输出: 4. 1. ring的第…...
重学SpringBoot3-集成Thymeleaf
更多SpringBoot3内容请关注我的专栏:《SpringBoot3》 期待您的点赞👍收藏⭐评论✍ 重学SpringBoot3-集成Thymeleaf 1. 添加Thymeleaf依赖2. 配置Thymeleaf属性(可选)3. 创建Thymeleaf模板4. 创建一个Controller5. 运行应用并访问页…...
【数据可视化】Echarts最常用图表
个人主页 : zxctscl 如有转载请先通知 文章目录 1. 前言2. 准备工作3. 柱状图3.1 绘制堆积柱状图3.2 绘制标准条形图3.3 绘制瀑布图 4. 折线图4.1 绘制堆积面积图和堆积折线图4.2 绘制阶梯图 5. 饼图5.1 绘制标准饼图5.2 绘制圆环图5.2 绘制嵌套饼图5.3 绘制南丁格尔…...
flink:通过table api把文件中读取的数据写入MySQL
当写入数据到外部数据库时,Flink 会使用 DDL 中定义的主键。如果定义了主键,则连接器将以 upsert 模式工作,否则连接器将以 append 模式工作 package cn.edu.tju.demo2;import org.apache.flink.streaming.api.environment.StreamExecutionE…...
【Java 多线程 哈希表】 HashTable, HashMap, ConcurrentHashMap 之间的区别
HashTable、HashMap和ConcurrentHashMap都是Java中用于存储键值对的集合框架的一部分,但它们之间存在一些重要的联系和区别。 联系 键值对存储:它们都用于存储键值对,并允许你根据键来检索值。基于哈希:它们内部都使用了哈希表来…...
有趣之matlab-烟花
待整合1 2 3 动态 有趣编程之11 静态 逼真 3 .m文件路径下放back1.jpg back4.jpg…背景照片 点击screen 就会有小白点升起,爆炸 function yanhuamoban()clear all;%定义全局变量global ah ;%坐标轴句柄global styleNum ;%爆炸图案样式global multiColor; %多颜色变换…...
C语言指针与数组(不适合初学者版):一篇文章带你深入了解指针与数组!
🎈个人主页:JAMES别扣了 💕在校大学生一枚。对IT有着极其浓厚的兴趣 ✨系列专栏目前为C语言初阶、后续会更新c语言的学习方法以及c题目分享. 😍希望我的文章对大家有着不一样的帮助,欢迎大家关注我,我也会回…...
springboot Mongo大数据查询优化方案
前言 因为项目需要把传感器的数据保存起来,当时设计的时是mongo来存储,后期需要从mongo DB里查询传感器的数据记录。由于传感器每秒都会像mongo数据库存500条左右的数据,1天就有4320万条数据,要想按照时间条件去查询,…...
应用升级/灾备测试时使用guarantee 闪回点迅速回退
1.场景 应用要升级,当升级失败时,数据库回退到升级前. 要测试系统,测试完成后,数据库要回退到测试前。 相对于RMAN恢复需要很长时间, 数据库闪回只需要几分钟。 2.技术实现 数据库设置 2个db_recovery参数 创建guarantee闪回点,不需要开启数据库闪回。…...
论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(二)
HoST框架核心实现方法详解 - 论文深度解读(第二部分) 《Learning Humanoid Standing-up Control across Diverse Postures》 系列文章: 论文深度解读 + 算法与代码分析(二) 作者机构: 上海AI Lab, 上海交通大学, 香港大学, 浙江大学, 香港中文大学 论文主题: 人形机器人…...
DockerHub与私有镜像仓库在容器化中的应用与管理
哈喽,大家好,我是左手python! Docker Hub的应用与管理 Docker Hub的基本概念与使用方法 Docker Hub是Docker官方提供的一个公共镜像仓库,用户可以在其中找到各种操作系统、软件和应用的镜像。开发者可以通过Docker Hub轻松获取所…...
【磁盘】每天掌握一个Linux命令 - iostat
目录 【磁盘】每天掌握一个Linux命令 - iostat工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景 注意事项 【磁盘】每天掌握一个Linux命令 - iostat 工具概述 iostat(I/O Statistics)是Linux系统下用于监视系统输入输出设备和CPU使…...
基础测试工具使用经验
背景 vtune,perf, nsight system等基础测试工具,都是用过的,但是没有记录,都逐渐忘了。所以写这篇博客总结记录一下,只要以后发现新的用法,就记得来编辑补充一下 perf 比较基础的用法: 先改这…...
DBAPI如何优雅的获取单条数据
API如何优雅的获取单条数据 案例一 对于查询类API,查询的是单条数据,比如根据主键ID查询用户信息,sql如下: select id, name, age from user where id #{id}API默认返回的数据格式是多条的,如下: {&qu…...
前端开发面试题总结-JavaScript篇(一)
文章目录 JavaScript高频问答一、作用域与闭包1.什么是闭包(Closure)?闭包有什么应用场景和潜在问题?2.解释 JavaScript 的作用域链(Scope Chain) 二、原型与继承3.原型链是什么?如何实现继承&a…...
CMake 从 GitHub 下载第三方库并使用
有时我们希望直接使用 GitHub 上的开源库,而不想手动下载、编译和安装。 可以利用 CMake 提供的 FetchContent 模块来实现自动下载、构建和链接第三方库。 FetchContent 命令官方文档✅ 示例代码 我们将以 fmt 这个流行的格式化库为例,演示如何: 使用 FetchContent 从 GitH…...
C++ Visual Studio 2017厂商给的源码没有.sln文件 易兆微芯片下载工具加开机动画下载。
1.先用Visual Studio 2017打开Yichip YC31xx loader.vcxproj,再用Visual Studio 2022打开。再保侟就有.sln文件了。 易兆微芯片下载工具加开机动画下载 ExtraDownloadFile1Info.\logo.bin|0|0|10D2000|0 MFC应用兼容CMD 在BOOL CYichipYC31xxloaderDlg::OnIni…...
OPenCV CUDA模块图像处理-----对图像执行 均值漂移滤波(Mean Shift Filtering)函数meanShiftFiltering()
操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 在 GPU 上对图像执行 均值漂移滤波(Mean Shift Filtering),用于图像分割或平滑处理。 该函数将输入图像中的…...
