当前位置: 首页 > news >正文

深入探讨梯度下降:优化机器学习的关键步骤(二)

文章目录

  • 🍀引言
  • 🍀eta参数的调节
  • 🍀sklearn中的梯度下降

🍀引言

承接上篇,这篇主要有两个重点,一个是eta参数的调解;一个是在sklearn中实现梯度下降

在梯度下降算法中,学习率(通常用符号η表示,也称为步长或学习速率)的选择非常重要,因为它直接影响了算法的性能和收敛速度。学习率控制了每次迭代中模型参数更新的幅度。以下是学习率(η)的重要性:

  • 收敛速度:学习率决定了模型在每次迭代中移动多远。如果学习率过大,模型可能会在参数空间中来回摇摆,导致不稳定的收敛或甚至发散。如果学习率过小,模型将收敛得很慢,需要更多的迭代次数才能达到最优解。因此,选择合适的学习率可以加速收敛速度。

  • 稳定性:过大的学习率可能会导致梯度下降算法不稳定,甚至无法收敛。过小的学习率可以使算法更加稳定,但可能需要更多的迭代次数才能达到最优解。因此,合适的学习率可以在稳定性和收敛速度之间取得平衡。

  • 避免局部最小值:选择不同的学习率可能会导致模型陷入不同的局部最小值。通过尝试不同的学习率,您可以更有可能找到全局最小值,而不是被困在局部最小值中。

  • 调优:学习率通常需要调优。您可以尝试不同的学习率值,并监视损失函数的收敛情况。通常,您可以使用学习率衰减策略,逐渐降低学习率以改善收敛性能。

  • 批量大小:学习率的选择也与批量大小有关。通常,小批量梯度下降(Mini-batch Gradient Descent)使用比大批量梯度下降更大的学习率,因为小批量可以提供更稳定的梯度估计。

总之,学习率是梯度下降算法中的关键超参数之一,它需要仔细选择和调整,以在训练过程中实现最佳性能和收敛性。不同的问题和数据集可能需要不同的学习率,因此在实践中,通常需要进行实验和调优来找到最佳的学习率值。


🍀eta参数的调节

在上代码前我们需要知道,如果eta的值过小会造成什么样的结果

在这里插入图片描述
反之如果过大呢

在这里插入图片描述
可见,eta过大过小都会影响效率,所以一个合适的eta对于寻找最优有着至关重要的作用


在上篇的学习中我们已经初步完成的代码,这篇我们将其封装一下
首先需要定义两个函数,一个用来返回thera的历史列表,一个则将其绘制出来

def gradient_descent(eta,initial_theta,epsilon = 1e-8):theta = initial_thetatheta_history = [initial_theta]def dj(theta): return 2*(theta-2.5) #  传入theta,求theta点对应的导数def j(theta):return (theta-2.5)**2-1  #  传入theta,获得目标函数的对应值while True:gradient = dj(theta)last_theta = thetatheta = theta-gradient*eta theta_history.append(theta)if np.abs(j(theta)-j(last_theta))<epsilon:breakreturn theta_historydef plot_gradient(theta_history):plt.plot(plt_x,plt_y)plt.plot(theta_history,[(i-2.5)**2-1 for i in theta_history],color='r',marker='+')plt.show()

其实就是上篇代码的整合罢了
之后我们需要进行简单的调参了,这里我们分别采用0.10.010.9,这三个参数进行调节

eta = 0.1
theta =0.0
plot_gradient(gradient_descent(eta,theta))
len(theta_history)

运行结果如下
在这里插入图片描述

eta = 0.01
theta =0.0
plot_gradient(gradient_descent(eta,theta))
len(theta_history)

运行结果如下
在这里插入图片描述

eta = 0.9
theta =0.0
plot_gradient(gradient_descent(eta,theta))
len(theta_history)

运行结果如下
在这里插入图片描述
这三张图与之前的提示很像吧,可见调参的重要性
如果我们将eta改为1.0呢,那么会发生什么

eta = 1.0
theta =0.0
plot_gradient(gradient_descent(eta,theta))
len(theta_history)

运行结果如下
在这里插入图片描述
那改为1.1呢

eta = 1.1
theta =0.0
plot_gradient(gradient_descent(eta,theta))
len(theta_history)

运行结果如下
在这里插入图片描述
我们从图可以清楚的看到,当eta为1.1的时候是嗷嗷增大的,这种情况我们需要采用异常处理来限制一下,避免报错,处理的方式是限制循环的最大值,且可以在expect中设置inf(正无穷)

def gradient_descent(eta,initial_theta,n_iters=1e3,epsilon = 1e-8):theta = initial_thetatheta_history = [initial_theta]i_iter = 1def dj(theta):  try:return 2*(theta-2.5) #  传入theta,求theta点对应的导数except:return float('inf')def j(theta):return (theta-2.5)**2-1  #  传入theta,获得目标函数的对应值while i_iter<=n_iters:gradient = dj(theta)last_theta = thetatheta = theta-gradient*eta theta_history.append(theta)if np.abs(j(theta)-j(last_theta))<epsilon:breaki_iter+=1return theta_historydef plot_gradient(theta_history):plt.plot(plt_x,plt_y)plt.plot(theta_history,[(i-2.5)**2-1 for i in theta_history],color='r',marker='+')plt.show()

注意:inf表示正无穷大


🍀sklearn中的梯度下降

这里我们还是以波士顿房价为例子
首先导入需要的库

from sklearn.datasets import load_boston
from sklearn.linear_model import SGDRegressor

之后取一部分的数据

boston = load_boston()
X = boston.data
y = boston.target
X = X[y<50]
y = y[y<50]

然后进行数据归一化

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test=train_test_split(X,y)
std = StandardScaler()
std.fit(X_train)
X_train_std=std.transform(X_train)
X_test_std=std.transform(X_test)
sgd_reg = SGDRegressor()
sgd_reg.fit(X_train_std,y_train)

最后取得score

sgd_reg.score(X_test_std,y_test)

运行结果如下
在这里插入图片描述


请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

相关文章:

深入探讨梯度下降:优化机器学习的关键步骤(二)

文章目录 &#x1f340;引言&#x1f340;eta参数的调节&#x1f340;sklearn中的梯度下降 &#x1f340;引言 承接上篇&#xff0c;这篇主要有两个重点&#xff0c;一个是eta参数的调解&#xff1b;一个是在sklearn中实现梯度下降 在梯度下降算法中&#xff0c;学习率&#xf…...

高频算法面试题

合并两个有序数组 const merge (nums1, nums2) > {let p1 0;let p2 0;const result [];let cur;while (p1 < nums1.length || p2 < nums2.length) {if (p1 nums1.length) {cur nums2[p2];} else if (p2 nums2.length) {cur nums1[p1];} else if (nums1[p1] &…...

Hive-启动与操作(2)

&#x1f947;&#x1f947;【大数据学习记录篇】-持续更新中~&#x1f947;&#x1f947; 个人主页&#xff1a;beixi 本文章收录于专栏&#xff08;点击传送&#xff09;&#xff1a;【大数据学习】 &#x1f493;&#x1f493;持续更新中&#xff0c;感谢各位前辈朋友们支持…...

css transition 指南

css transition 指南 在本文中&#xff0c;我们将深入了解 CSS transition&#xff0c;以及如何使用它们来创建丰富、精美的动画。 基本原理 我们创建动画时通常需要一些动画相关的 CSS。 下面是一个按钮在悬停时移动但没有动画的示例&#xff1a; <button class"…...

LeetCode 面试题 02.05. 链表求和

文章目录 一、题目二、C# 题解 一、题目 给定两个用链表表示的整数&#xff0c;每个节点包含一个数位。 这些数位是反向存放的&#xff0c;也就是个位排在链表首部。 编写函数对这两个整数求和&#xff0c;并用链表形式返回结果。 点击此处跳转题目。 示例&#xff1a; 输入&a…...

一米脸书营销软件

功能优势 JOIN ADVANTAGE HOME PAGE MARKETING 公共主页营销 可同时对多个账户公共主页评论&#xff0c;点赞等 可批量邀请多个好友对Facebook公共主页进行评论点赞等&#xff0c;也可批量登录小号对自己公共主页进行点赞。 GROUP MARKETING 小组营销 可批量针对不同账户进行…...

vue 根据数值判断颜色

1.首先style样式给两种颜色 用:class 三元运算符判断出一种颜色 第一步&#xff1a;在style里边设置两种颜色 .green{color: green; } .orange{color: orangered; }在取数据的标签 里边 判断一种颜色 :class"item.quote.current >0 ?orange: green"<van-gri…...

Hugging Face 实战系列 总目录

PyTorch 深度学习 开发环境搭建 全教程 Transformer:《Attention is all you need》 Hugging Face简介 1、Hugging Face实战-系列教程1&#xff1a;Tokenizer分词器&#xff08;Transformer工具包/自然语言处理&#xff09; Hungging Face实战-系列教程1&#xff1a;Tokenize…...

国标视频云服务EasyGBS国标视频平台迁移服务器后无法启动的问题解决方法

国标视频云服务EasyGBS支持设备/平台通过国标GB28181协议注册接入&#xff0c;并能实现视频的实时监控直播、录像、检索与回看、语音对讲、云存储、告警、平台级联等功能。平台部署简单、可拓展性强&#xff0c;支持将接入的视频流进行全终端、全平台分发&#xff0c;分发的视频…...

HTML <th> 标签

实例 普通的 HTML 表格,包含两行两列: <table border="1"><tr><th>Company</th><th>Address</th></tr><tr><td>Apple, Inc.</td><td>1 Infinite Loop Cupertino, CA 95014</td></tr…...

HTTP/1.1协议中的响应报文

2023年8月30日&#xff0c;周三下午 目录 概述响应报文示例详述 概述 HTTP/1.1协议的响应报文由以下几个部分组成&#xff1a; 状态行&#xff08;Status Line&#xff09;响应头部&#xff08;Response Headers&#xff09;空行&#xff08;Blank Line&#xff09;响应体&a…...

TDengine函数大全-选择函数

以下内容来自 TDengine 官方文档 及 GitHub 内容 。 以下所有示例基于 TDengine 3.1.0.3 TDengine函数大全 1.数学函数 2.字符串函数 3.转换函数 4.时间和日期函数 5.聚合函数 6.选择函数 7.时序数据库特有函数 8.系统函数 选择函数 TDengine函数大全BOTTOMFIRSTINTERPLASTLAS…...

非关系型数据库Redis的安装

一、关系型数据库与非关系型数据库的区别&#xff1a;---------面试高频率问题 1、首先了解一下 什么是关系型数据库&#xff1f; 关系型数据库最典型的数据结构是表&#xff0c;由二维表及其之间的联系所组成的一个数据组织。 优点&#xff1a; 易于维护&#xff1a;都是使用…...

oracle 创建数据库

查询表空间的命令 select t1.name,t2.name from v$tablespace t1,v$datafile t2 where t1.ts# t2.ts#; CREATE TABLESPACE ORM_342_BETA DATAFILE /app/oracle/oradata/sysware/ORM_342_BETA.DBF size 800M --存储地址 初始大小800M autoextend on nex…...

wxWidgets从空项目开始Hello World

前文回顾 接上篇&#xff0c;已经是在CodeBlocks20.03配置了wxWidgets3.0.5&#xff0c;并且能够通过项目创建导航创建一个新的工程&#xff0c;并且成功运行。 那么上一个是通过CodeBlocks的模板创建的&#xff0c;一进去就已经是2个头文件2个cpp文件&#xff0c;总是感觉缺…...

【Apollo学习笔记】——规划模块TASK之SPEED_DECIDER

文章目录 前言SPEED_DECIDER功能简介SPEED_DECIDER相关配置SPEED_DECIDER流程MakeObjectDecisionGetSTLocationCheck类函数CheckKeepClearCrossableCheckStopForPedestrianCheckIsFollowCheckKeepClearBlocked Create类函数 前言 在Apollo星火计划学习笔记——Apollo路径规划算…...

【操作系统】一文快速入门,很适合JAVA后端看

作者简介&#xff1a; 目录 1.概述 2.CPU管理 3.内存管理 4.IO管理 1.概述 操作系统可以看作一个计算机的管理系统&#xff0c;对计算机的硬件资源提供了一套完整的管理解决方案。计算机的硬件组成有五大模块&#xff1a;运算器、控制器、存储器、输入设备、输出设备。操作…...

C++ Primer阅读笔记--allocator类的使用

1--allocator类的使用背景 new 在分配内存时具有一定的局限性&#xff0c;其将内存分配和对象构造组合在一起&#xff1b;当分配一大块内存时&#xff0c;一般希望可以在内存上按需构造对象&#xff0c;这时需要将内存分配和对象构造分离&#xff0c;而定义在头文件 memory 的 …...

【C++历险记】面向对象|菱形继承及菱形虚拟继承

个人主页&#xff1a;兜里有颗棉花糖&#x1f4aa; 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 兜里有颗棉花糖 原创 收录于专栏【C之路】&#x1f48c; 本专栏旨在记录C的学习路线&#xff0c;望对大家有所帮助&#x1f647;‍ 希望我们一起努力、成长&…...

【Locomotor运动模块】攀爬

文章目录 一、攀爬主体“伪身体”1、“伪身体”的设置2、“伪身体”和“真实身体”&#xff0c;为什么同步移动3、“伪身体”和“真实身体”&#xff0c;碰到墙时不同步的原因①现象②原因③解决 二、攀爬1、需要的组件&#xff1a;“伪身体”、Climbing、Climbable及Interacto…...

练习(含atoi的模拟实现,自定义类型等练习)

一、结构体大小的计算及位段 &#xff08;结构体大小计算及位段 详解请看&#xff1a;自定义类型&#xff1a;结构体进阶-CSDN博客&#xff09; 1.在32位系统环境&#xff0c;编译选项为4字节对齐&#xff0c;那么sizeof(A)和sizeof(B)是多少&#xff1f; #pragma pack(4)st…...

ESP32读取DHT11温湿度数据

芯片&#xff1a;ESP32 环境&#xff1a;Arduino 一、安装DHT11传感器库 红框的库&#xff0c;别安装错了 二、代码 注意&#xff0c;DATA口要连接在D15上 #include "DHT.h" // 包含DHT库#define DHTPIN 15 // 定义DHT11数据引脚连接到ESP32的GPIO15 #define D…...

什么是库存周转?如何用进销存系统提高库存周转率?

你可能听说过这样一句话&#xff1a; “利润不是赚出来的&#xff0c;是管出来的。” 尤其是在制造业、批发零售、电商这类“货堆成山”的行业&#xff0c;很多企业看着销售不错&#xff0c;账上却没钱、利润也不见了&#xff0c;一翻库存才发现&#xff1a; 一堆卖不动的旧货…...

06 Deep learning神经网络编程基础 激活函数 --吴恩达

深度学习激活函数详解 一、核心作用 引入非线性:使神经网络可学习复杂模式控制输出范围:如Sigmoid将输出限制在(0,1)梯度传递:影响反向传播的稳定性二、常见类型及数学表达 Sigmoid σ ( x ) = 1 1 +...

Hive 存储格式深度解析:从 TextFile 到 ORC,如何选对数据存储方案?

在大数据处理领域&#xff0c;Hive 作为 Hadoop 生态中重要的数据仓库工具&#xff0c;其存储格式的选择直接影响数据存储成本、查询效率和计算资源消耗。面对 TextFile、SequenceFile、Parquet、RCFile、ORC 等多种存储格式&#xff0c;很多开发者常常陷入选择困境。本文将从底…...

安宝特方案丨船舶智造的“AR+AI+作业标准化管理解决方案”(装配)

船舶制造装配管理现状&#xff1a;装配工作依赖人工经验&#xff0c;装配工人凭借长期实践积累的操作技巧完成零部件组装。企业通常制定了装配作业指导书&#xff0c;但在实际执行中&#xff0c;工人对指导书的理解和遵循程度参差不齐。 船舶装配过程中的挑战与需求 挑战 (1…...

Fabric V2.5 通用溯源系统——增加图片上传与下载功能

fabric-trace项目在发布一年后,部署量已突破1000次,为支持更多场景,现新增支持图片信息上链,本文对图片上传、下载功能代码进行梳理,包含智能合约、后端、前端部分。 一、智能合约修改 为了增加图片信息上链溯源,需要对底层数据结构进行修改,在此对智能合约中的农产品数…...

C++:多态机制详解

目录 一. 多态的概念 1.静态多态&#xff08;编译时多态&#xff09; 二.动态多态的定义及实现 1.多态的构成条件 2.虚函数 3.虚函数的重写/覆盖 4.虚函数重写的一些其他问题 1&#xff09;.协变 2&#xff09;.析构函数的重写 5.override 和 final关键字 1&#…...

Linux 中如何提取压缩文件 ?

Linux 是一种流行的开源操作系统&#xff0c;它提供了许多工具来管理、压缩和解压缩文件。压缩文件有助于节省存储空间&#xff0c;使数据传输更快。本指南将向您展示如何在 Linux 中提取不同类型的压缩文件。 1. Unpacking ZIP Files ZIP 文件是非常常见的&#xff0c;要在 …...

RabbitMQ入门4.1.0版本(基于java、SpringBoot操作)

RabbitMQ 一、RabbitMQ概述 RabbitMQ RabbitMQ最初由LShift和CohesiveFT于2007年开发&#xff0c;后来由Pivotal Software Inc.&#xff08;现为VMware子公司&#xff09;接管。RabbitMQ 是一个开源的消息代理和队列服务器&#xff0c;用 Erlang 语言编写。广泛应用于各种分布…...