【初学人工智能原理】【4】梯度下降和反向传播:能改(下)
前言
本文教程均来自b站【小白也能听懂的人工智能原理】,感兴趣的可自行到b站观看。
本文【原文】章节来自课程的对白,由于缺少图片可能无法理解,故放到了最后,建议直接看代码(代码放到了前面)。
代码实现
任务
- 在引入b后绘制代价函数界面,看看到底是不是一个碗
- 在w和b两个方向上分别求导,得到这个曲面某点的梯度进行梯度下降,拟合数据
绘制三维的方差代价函数
dataset.py
import numpy as npdef get_beans(counts):xs = np.random.rand(counts)xs = np.sort(xs)ys = np.array([(0.7*x+(0.5-np.random.rand())/5+0.5) for x in xs])return xs,ys
cost_function_w.py
import dataset
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
m = 100
xs, ys = dataset.get_beans(m)# 配置图像
plt.title("Size-Toxicity Function'", fontsize=12) # 设置图像名称
plt.xlabel("Bean size") # 设置横坐标的名字
plt.ylabel("Toxicity") # 设置纵坐标的名字
plt.xlim(0, 1)
plt.ylim(0, 1.5)
plt.scatter(xs, ys)w = 0.1
b = 0.1
y_pre = w * xs * bplt.plot(xs, y_pre)
plt.show()fig=plt.figure()
ax=Axes3D(fig)
ax.set_zlim(0,2) # 限制垂直方向坐标轴取值范围ws=np.arange(-1,2,0.1)
bs=np.arange(-2,2,0.01)for b in bs:es=[]for w in ws:y_pre=w*xs+be=np.sum((ys-y_pre)**2)*(1/m)es.append(e)# plt.plot(ws,es)ax.plot(ws,es,b,zdir='y') # 3D绘图plt.show()
关于w和b的梯度下降算法
import dataset
import numpy as np
from matplotlib import pyplot as plt## Create a dataset
n = 100
xs, ys = dataset.get_beans(n)# 配置图像
plt.title("Size-Toxicity Function",fontsize=12)
plt.xlabel("Bean Size")
plt.ylabel("Toxicity")
plt.scatter(xs,ys)w=0.1
b=0.1
y_pre=w*xs+b
plt.plot(xs,y_pre)
plt.show()# 随机梯度下降法
def gsd(w=0.1,b=0.1):# 在全部样本上做50次梯度下降for _ in range(50):for i in range(100):x = xs[i]y = ys[i]# a=x^2# b=-2*x*y# c=y^2# 斜率k=2aw+bdw = 2*w**2*w+2*x*b-2*x*y # e对w求导db = 2 * b + 2 * x * w - 2 * y # e对b求导alpha = 0.1w = w - alpha * dw # w根据梯度下降的方向走,如w此时的k<0,则w处于抛物线左端,应该往右边走,相反则往左边走b=b-alpha*db# 绘制动态变化的曲线plt.clf() # 清空窗口plt.scatter(xs, ys)y_pre = w * xs+b# 限制x轴和y轴的范围,使之不自动调整,避免图像抖动plt.xlim(0, 1)plt.ylim(0, 1.2)plt.plot(xs, y_pre)plt.pause(0.01) # 暂停0.01s,因为不暂停的话会无法显示# 随机梯度下降
gsd()
原文
通过前面的学习,我们已然了解到现在神经网络精髓之一的梯度下降算法,但是如果仔细观察我们设计的预测函数,你就会发现这是一个非常危险和不完善的模型。比如在另外一片海域里,豆豆的大小和毒性的关系是这样的,有些太小的豆豆是不存在的,我们发现不论怎样去调整w都无法得到理想的预测函数。当然更加糟糕的情况是豆豆越大,毒性越低,原因很简单,我们的预测函数y等于w乘以x很明显是一个必须经过原点的直线。
换句话说,这个预测函数直线的自由度被限制住了,只能旋转而不能移动。因为大家很清楚,一个直线完整的函数应该是y等于wx加b。
之前我们为了遵循如无必要物增新知的理念,一直在刻意的避免这个截距参数b。直到现在我们终于避无可避是时候增加新的知识。截距参数b的作用大家很清楚,可以让直线在平面内自由的平移,而斜率w可以让直线自由的旋转,当我们把直线的平移的自由度还给它之后,这两者的结合才能让直线在整个平面内这真正的自由起来。
我们来看一下加入截距参数b后发生的改变。首先我们带入b重新推演一次预测和梯度下降的过程,当然为了简单起见,我们还是先看单个豆豆样本的情况,这是预测函数。豆豆的大小是x0,毒性是y0,那么预测就是w乘以x0+b那么根据方差代价函数得到方差代价是这样的,你会发现没有b的时候或者说b=0的时候,代价函数就是我们前几节课中的样子,这其实是b=0的一种特殊情况。
那现在既然有了b接下来我们就要看这个b取不同值的时候会对代价函数造成什么样子的影响?
这里我们需要把代价函数的图像从二维变成三维,给b留出一个维度,b=0就是我们之前讨论的抛物线,b=0.1,e和w的关系还是一个标准的开口向上的抛物线,因为b的改变只影响这个抛物线的系数,换句话说改变的只是抛物线的具体样子,而不会让它变成其他形状。同样的道理,b=0.2也是,b=0.3也是等等等,我们好像已经看出一些眉目了,这好像是一个曲面。没错,这里我们的b取值间隔是0.1,描绘出来的效果似乎还是不太明显。当我们把b的取值间隔弄小一点,再小一点,直到无限的小下去,这时候你就会发现果真是一个三维空间中的曲面,那我们该如何去看待这个曲面?
在有些教程和书籍中,很多时候为了让他看着明显把它画成了一个鼓鼓的碗状。其实对于线性回归问题中,这种豆豆数据形成的代价函数实际上并没有那么的鼓,而是一个扁扁的碗,扁的几乎看不出来是个碗,但是当我们把这个曲面的等高线画出来就可以看出来,这确实也是一个碗。很明显这个碗状曲面的最低点肯定是问题的关键所在。我们回忆一下,在没有b出现的时候,曲线的最低点代表着w取值造成的预测误差最小,那这个曲面最低点意味着什么?
首先我们想一想这个最低点是怎么形成的呢?
没错,我们每次取不同的w和b都会导致误差,e不相同,这个局面也就是我们带入b后得到的代价函数的图像,而它的最低点也就意味着这里的w和b的取值会让预测误差最小。而如果我们能得到这个最低点的w和b的值放回到预测函数中,那么此时此刻恰如彼时必克预测也就是最好的。现在我们的目标就很明确了,如何在这个曲面上取得最低点处的w和b的值?在没有b出现的美好时刻,也就是说在b=0处,我们沿着w的方向切上一刀,我们知道这将形成一个关于e和w的开口向上的抛物线,然后不断的通过梯度下降算法去调整w最后到达最低点。
但是你会发现此刻曲线的最低点却并不是这个曲面的最低点,换句话说b=0的取值并不是最好的,那么关于b套路其实还是一样的,我们在这一点上,如果沿着b的方向给曲线来上,一刀会怎样呢?你会发现切口形成的曲线似乎也是一个开口向上的抛物线,如果是这样的话那就很nice了,我们在这个抛物线上也向最低点挪动即可,但果真如此吗?我们之前已经分析出来e和w的关系是一个抛物线,现在我们不妨再看一下e和b的关系,这是方差代价函数,要研究b那我们围绕b重新整理一下这个式子,是这样的。
当w确定的时候,也就是我们沿着b的方向切下一刀,比如当前这个点的w值为w cut,这时候代价函数是这样的,你看当w取固定值的时候,也就是把w看作一个确定值的时候,e和b的关系又是一个标准的开口向上的一元二次函数。所以面对现在这个误差代价函数曲面,我们还可以换个角度去理解它的形成方式。除了可以像一开始那样认为是e关于w的一元二次函数曲线在b取不同值的时候形成的以外,也可以认为是e关于b的一元二次函数曲线,在每次w取不同值的时候形成的。
现在我们在b上要做的事情和在w上一模一样,不断的去调整b仍然向这个曲线的最低点挪动,而具体的方法也是一样的,根据斜率进行下降。
我们完整的来看一下这个过程,假设一开始我们的w等于0.1,b也等于0.1,对应的e是这么多,在曲面的这个位置,我们画一个球来显示这个点,正所谓横看成岭侧成峰,我们横看此处看见的是b确定的时候e和w形成的一个曲线,根据此处的斜率调整w大小是斜率乘以学习率阿尔法,方向是根据斜率的正负确定的。我们侧看此处看见的是w确定的时候,e和b形成的一个曲线,根据此处的斜率调整,b大小是斜率乘以学习率阿尔法方向根据斜率振幅确定,把这两个方向上的调整运动合成一个合成的调整运动,这样我们就完成了一次调整,到达下一个点之后,我们继续横看调整w侧看调整b当我们反复进行这个过程的时候,也就逐渐的向这个曲面的最低点挪动了。
所以说这里同时有w和b的代价函数曲面和只有w的代价函数曲线相比,这个下降的过程本质上是一样的,换汤不换药,或者说只是从w一味药换成了w和b两位味药。
但是有一点,我们的代价函数已经是一个曲面了,那这个下降的过程,如果我们再说是斜率下降就有点不太合适了,毕竟一个曲面上的某点的斜率是个什么东西,是关于w的还是关于b的呢?要回答这个问题,需要发散一下思维,换一个角度来看这个下降的过程。我们在代价函数的w和b两个方向上分别求得斜率或者说倒数。
对于有两个自变量的代价函数,我们先偏向w求导数,再偏向b求导数,为了区分只有一个自变量的情况,我们把在某一个变量上的导数也称之为偏导数。如果我们把对w和对b的偏导数看作向量,把这两个向量合在一起,形成一个新的和向量,沿着这个和向量进行了下降,是这个曲面在该点下降最快的方式,这个和向量在数学里称之为梯度,到此为止。你也就理解了为什么我们说梯度是比斜率更加广泛的一个概念,它是把各个方向上的偏导数当做向量,合起来形成一个总向量,代表了这个点下降最快的方向,当然在二维曲线中因为没有其他方向,梯度和斜率也可以认为是一回事,而为了让这个下降算法的名字更具有广泛性,所以我们一般称之为梯度下降,而不是斜率下降。
子曰学而时习之,我们已经完整的讲述了梯度下降的过程,那么现在就来回顾总结一下目前为止我们所学到的东西。我们从环境中观察到了一个问题,豆豆的毒性和它的大小有关系,那现在想要准确的去预测这个关系到底是什么。按照McCulloch-Pitts神经元模型,我们使用一个一元一次线性函数去模拟神经元的树突和轴突的行为,这就是预测函数模型,而把我们统计观测而来的数据送入预测函数进行预测的过程就称之为前向传播。因为计算从前往后数据通过预测函数完成一次前向传播,就会得到一个预测值,预测值和统计观测而来的真实值之间存在着误差。
我们选择平方误差作为评估的手段,你会发现这个误差和预测函数中的参数又会形成一种函数关系,我们把这个函数称之为代价函数,因为采用方差去评估预测误差,所以也称之为方差代价函数,描述了预测函数的参数取不同值的时候预测的不同的误差代价,而用这个代价函数去修正预测函数参数的过程也称之为反向传播。
因为计算从后往前,而这个反向传播参数修正的方法,我们使用梯度下降算法噢对了,我的老伙计。在没有截距b的二维代价函数中叫他斜率下降也未尝不可,而在调整的过程中用来调和下降幅度的l法称之为学习率,他的选择影响了调整的速度太大了容易反复横跳,过大的时候甚至不会收敛,而是发散太小了又容易磨磨唧唧,他是设计者根据经验选择出来的,而不断的经历前向传播和反向传播,最后到达代价函数最低点的过程,我们称之为训练或者学习。
这就是所谓的机器学习中的神经网络,但把一个神经元称之为网络似乎不太恰当,因为没有哪一个网络只有一个节点,但以后我们不断的添加神经元,并把它们连接起来,共同工作的时候,也就能称之为神经网络。
而我们所说的前向传播和反向传播,其实也是在多层神经网络出现后才引入了概念,对于单个神经元如此称呼似乎有点别扭,但这些概念在单个神经元上已经初具雏形,面对网络那只是不断的重复而已,我们会在后面学习多层神经网络时候详细说明相传播更一般的行为,这就是目前人工智能机器学习领域独领风骚的连接主义在干的事情。至于为什么这样多个神经元组合成神经网络后就能达到智能的效果,别着急,我们会在接下来的课程中慢慢到来。
相关文章:

【初学人工智能原理】【4】梯度下降和反向传播:能改(下)
前言 本文教程均来自b站【小白也能听懂的人工智能原理】,感兴趣的可自行到b站观看。 本文【原文】章节来自课程的对白,由于缺少图片可能无法理解,故放到了最后,建议直接看代码(代码放到了前面)。 代码实…...
微信小程序路由传参
微信小程序路由传参 在微信小程序中,可以通过路由传参将数据传递给目标页面。以下是一种常见的方式: 在源页面中,使用 wx.navigateTo 或 wx.redirectTo 方法跳转到目标页面,并通过 URL 参数传递数据。示例: wx.navi…...

深入篇【C++】类与对象:再谈构造函数之初始化列表与explicit关键字
深入篇【C】类与对象:再谈构造函数之初始化列表与explicit关键字 Ⅰ.再谈构造函数①.构造函数体赋值②.初始化列表赋值【<特性分析>】1.至多性2.特殊成员必在性3.必走性:定义位置4.一致性5.不足性 Ⅱ.explicit关键字①.隐式类型转化②.作用 Ⅰ.再谈…...
广东棒球发展建设·棒球1号位
一、概述 棒球是一项源于美国的运动,自20世纪初开始传入中国,近年来在广东省的发展也逐渐受到关注。本文将就广东棒球的发展现状及未来发展方向进行分析。 二、发展现状 目前广东省内棒球赛事主要有以下几种: 1. 业余棒球联赛:…...

浅谈PMO对组织战略的支持︱美团骑行事业部项目管理中心负责人边国华
美团骑行事业部项目管理中心负责人边国华先生受邀为由PMO评论主办的2023第十二届中国PMO大会演讲嘉宾,演讲议题:浅谈PMO对组织战略的支持。大会将于6月17-18日在北京举办,更多内容请浏览会议日程 议题内容简要: 战略是组织运行的…...

互联网医院资质代办|互联网医院牌照的申请流程
随着互联网技术的不断发展,互联网医疗已经逐渐成为人们关注的热点话题。而互联网医院作为互联网医疗的一种重要形式,也越来越受到社会各界的关注。若想开展互联网医院业务,则需要具备互联网医院牌照。那么互联网医院牌照的申请流程和需要的资…...
网络:DPDK复习相关知识点_2
1.RTC运行至完成时模式,单核单模块 2.pipeline模式,多核多模块,每个模块都是一个处理引擎,但会有缓存一致性问题 3.Mbuff数据包内存操作对象,相当于是数据包的一个索引,对网络的处理都集中在这个Buff上 …...
阿里云大学考试Java中级题目及解析-java中级
阿里云大学考试Java中级题目及解析 1.servlet释放资源的方法是? A.int()方法 B.service()方法 C.close() 方法 D.destroy()方法 D servlet释放资源的方法是destroy() 2.order by与 group by的区别? A.order by用于排序,group by用于排序…...

【星戈瑞】Sulfo-CY3-COOH磺化/水溶性Cyanine3羧酸1121756-11-3
Sulfo-CY3 COOH是一种荧光染料,其分子结构中含有COOH官能团,最大吸收波长为550纳米左右,可以通过分光光度计等设备进行检测。Sulfo-CY3 COOH是一种带有羧基的荧光染料,可以与含有氨基的生物分子通过偶联反应形成共价键,…...
Java NIO和IO的主要区别
当学习了Java NIO和IO的API后,一个问题马上涌入脑海: 我应该何时使用IO,何时使用NIO呢?在本文中,我会尽量清晰地解析Java NIO和IO的差异、它们的使用场景,以及它们如何影响您的代码设计。 下表总结了Java N…...
SQL查询语句
DQL语句--排序查询 # 格式: select * from 表名 order by 要排序的列1 [asc/desc], 要排序的列2 [asc/desc]; # 解释: # 1. 无论SQL语句简单或者是复杂, order by语句一般都放最后, 注意: 如果有limit(分页), 则它(limit)在最后. # 2. asc表示升序, desc表示降序, 其中, 默…...

四象限法进程调度
周二收到一篇推送 一次云上网络毫秒级的优化与实践,很有意义的实践和探索,建议阅读,文章不长,没有冗长的源码分析,结论很清晰。 谈谈我的看法。 多少有种感觉,Linux 越来越像个响应系统而不是服务器。 虚…...

蓝桥杯拿到一等奖,并分享经验
昨天和群里的小伙伴在群里聊,有的小伙伴竟然说蓝桥杯一等奖没有含量,我也是醉了! 就像去年看了一个号主写的:研究生遍地都是! 放眼全国14亿人口,别说研究生了,本科生占比有多少? “蓝桥杯是我人生中得到…...
vue3。 Cannot use JSX unless the ‘–jsx’ flag is provided. ts(17004)
react用tsx或者jsx很常见,也有配套的配置 那如果是vue呢? 默认是没问题的,可是我用了jsdoc,并开启了checkjs,然后vscode就爆红了 谷歌,百度,一个晚上 查到的答案: 推荐我新增tsco…...
HVV面试题目总结
蓝队 如何识别安全设备中的无效告警? 常见的端口有哪些? 这些端口对应的服务是什么? 针对这些服务,红队攻击方式有哪些? 常用的威胁情报平台有哪些? 有没有做过关于情报输出的工作? 木马驻留系统的方式有哪些? 当收到钓鱼邮件的时候,说说处置思路…...

Access denied for user ‘root‘@‘localhost‘ (using password:YES) 解决方案
文章目录 问题描述解决方案: 问题描述 Access denied for user ‘root’‘localhost’:拒绝用户’root’localhost’的访问。 出现这个报错语句的一般原因是输入了错误的密码,也有可能是是root帐户默认不开放远程访问权限。 相关的解决方法是重新设置…...

为什么C++这么复杂还不被淘汰?
C是一门广泛使用的编程语言,主要用于系统和应用程序的开发。尽管C具有一些复杂的语法和概念,但它仍然是编程界的重量级选手,在编程语言排行榜中一直位居前列。为什么C这么复杂还不被淘汰呢? C有以下优势 1、C具有高性能 C是一门编…...

内存泄漏的原因,内存泄漏如何避免?内存泄漏如何定位?
1. 内存溢出 内存溢出 OOM (out of memory),是指程序在申请内存时,没有足够的内存空间供其使用,出现out of memory;比如申请了一个int,但给它存了long才能存下的数,那就是内存溢出。 2. 内存泄…...

关于全志T113开发板接7寸LCD屏幕显示异常问题的解决方案
在入手全志T113之后,第一时间移植好了之前6ull平台的rootfs。但是在测试QT的过程中发现屏幕最右侧有一部分显示不正常,经过初步推测应该是RGB行场同步时序有问题。本以为在设备树里面稍作修改之后就能OK,但是居然前前后后一共花了至少三个星期…...

SpringMVC第四阶段:Controller中如何接收请求参数
Controller中如何接收请求参数 1、原生API参数类型 1.1、HttpServletRequest类 只需要在Controller的目标方法中, 直接写上HttpServletRequest对象即可获取 原生API的 request对象实例。 RequestMapping(value "/p1") public String param1(HttpServletRequest …...
变量 varablie 声明- Rust 变量 let mut 声明与 C/C++ 变量声明对比分析
一、变量声明设计:let 与 mut 的哲学解析 Rust 采用 let 声明变量并通过 mut 显式标记可变性,这种设计体现了语言的核心哲学。以下是深度解析: 1.1 设计理念剖析 安全优先原则:默认不可变强制开发者明确声明意图 let x 5; …...

Debian系统简介
目录 Debian系统介绍 Debian版本介绍 Debian软件源介绍 软件包管理工具dpkg dpkg核心指令详解 安装软件包 卸载软件包 查询软件包状态 验证软件包完整性 手动处理依赖关系 dpkg vs apt Debian系统介绍 Debian 和 Ubuntu 都是基于 Debian内核 的 Linux 发行版ÿ…...
Golang dig框架与GraphQL的完美结合
将 Go 的 Dig 依赖注入框架与 GraphQL 结合使用,可以显著提升应用程序的可维护性、可测试性以及灵活性。 Dig 是一个强大的依赖注入容器,能够帮助开发者更好地管理复杂的依赖关系,而 GraphQL 则是一种用于 API 的查询语言,能够提…...
Java - Mysql数据类型对应
Mysql数据类型java数据类型备注整型INT/INTEGERint / java.lang.Integer–BIGINTlong/java.lang.Long–––浮点型FLOATfloat/java.lang.FloatDOUBLEdouble/java.lang.Double–DECIMAL/NUMERICjava.math.BigDecimal字符串型CHARjava.lang.String固定长度字符串VARCHARjava.lang…...
解决本地部署 SmolVLM2 大语言模型运行 flash-attn 报错
出现的问题 安装 flash-attn 会一直卡在 build 那一步或者运行报错 解决办法 是因为你安装的 flash-attn 版本没有对应上,所以报错,到 https://github.com/Dao-AILab/flash-attention/releases 下载对应版本,cu、torch、cp 的版本一定要对…...

全志A40i android7.1 调试信息打印串口由uart0改为uart3
一,概述 1. 目的 将调试信息打印串口由uart0改为uart3。 2. 版本信息 Uboot版本:2014.07; Kernel版本:Linux-3.10; 二,Uboot 1. sys_config.fex改动 使能uart3(TX:PH00 RX:PH01),并让boo…...

python执行测试用例,allure报乱码且未成功生成报告
allure执行测试用例时显示乱码:‘allure’ �����ڲ����ⲿ���Ҳ���ǿ�&am…...

深入浅出深度学习基础:从感知机到全连接神经网络的核心原理与应用
文章目录 前言一、感知机 (Perceptron)1.1 基础介绍1.1.1 感知机是什么?1.1.2 感知机的工作原理 1.2 感知机的简单应用:基本逻辑门1.2.1 逻辑与 (Logic AND)1.2.2 逻辑或 (Logic OR)1.2.3 逻辑与非 (Logic NAND) 1.3 感知机的实现1.3.1 简单实现 (基于阈…...

Web后端基础(基础知识)
BS架构:Browser/Server,浏览器/服务器架构模式。客户端只需要浏览器,应用程序的逻辑和数据都存储在服务端。 优点:维护方便缺点:体验一般 CS架构:Client/Server,客户端/服务器架构模式。需要单独…...

tauri项目,如何在rust端读取电脑环境变量
如果想在前端通过调用来获取环境变量的值,可以通过标准的依赖: std::env::var(name).ok() 想在前端通过调用来获取,可以写一个command函数: #[tauri::command] pub fn get_env_var(name: String) -> Result<String, Stri…...