随机梯度下降的代码实现
在单变量线性回归的机器学习代码中,我们讨论了批量梯度下降代码的实现,本篇将进行随机梯度下降的代码实现,整体和批量梯度下降代码类似,仅梯度下降部分不同:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import joblib# 导入数据
path = 'ex1data1.txt'
data = pd.read_csv(path, header=None, names=['Population', 'Profit'])# 分离特征和目标变量
X = data.iloc[:, 0:1].values # Population列
y = data.iloc[:, 1].values # Profit列
m = len(y) # 样本数量# 添加一列全为1的截距项
X = np.append(np.ones((m, 1)), X, axis=1)# 批量梯度下降参数
alpha = 0.01 # 学习率
iterations = 1500 # 迭代次数# 随机梯度下降算法
def stochasticGradientDescent(X, y, theta, alpha, num_iters):m = len(y)for iter in range(num_iters):for i in range(m):# 随机选择一个数据点进行梯度计算random_index = np.random.randint(0, m)X_i = X[random_index, :].reshape(1, X.shape[1])y_i = y[random_index].reshape(1, 1)# 计算预测值和误差prediction = np.dot(X_i, theta)error = prediction - y_i# 更新参数theta = theta - (alpha * X_i.T.dot(error)).flatten()return theta# 初始化模型参数
theta = np.zeros(2)"""
随机梯度下降前的损失显示
"""
# 定义损失函数,用于显示调用前后的损失值对比
def computeCost(X, y, theta):m = len(y)predictions = X.dot(theta)square_err = (predictions - y) ** 2return np.sum(square_err) / (2 * m)
# 计算初始损失
initial_cost = computeCost(X, y, theta)
print("初始的损失值:", initial_cost)# 使用随机梯度下降进行模型拟合
theta = stochasticGradientDescent(X, y, theta, alpha, iterations)"""
随机梯度下降后的损失显示
"""
# 计算优化后的损失
final_cost = computeCost(X, y, theta)
print("优化后的损失值:", final_cost)"""
使用需要预测的数据X进行预测
"""
# 假设的人口数据
population_values = [3.5, 7.0] # 代表35,000和70,000人口# 对每个人口值进行预测
for pop in population_values:# 将人口值转换为与训练数据相同的格式(包括截距项)predict_data = np.matrix([1, pop]) # 添加截距项# 使用模型进行预测predict_profit = np.dot(predict_data, theta.T)print(f"模型预测结果 {pop} : {predict_profit[0,0]}")
"""
使用模型绘制函数
"""
# 创建预测函数
x_values = np.array(X[:, 1])
f = theta[0] * np.ones_like(x_values) + (theta[1] * x_values) # 使用广播机制# 绘制图表
fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(x_values, f, 'r', label='Prediction')
ax.scatter(data.Population, data.Profit, label='Training Data')
ax.legend(loc=2)
ax.set_xlabel('Population')
ax.set_ylabel('Profit')
ax.set_title('Predicted Profit vs. Population Size')
plt.show()"""
保存模型
"""
# 保存模型
joblib.dump(theta, 'linear_regression_model.pkl')"""
加载模型并执行预测
"""
# 加载模型
loaded_model = joblib.load('linear_regression_model.pkl')# 假设的人口数据
population_values = [3.5, 7.0] # 代表35,000和70,000人口# 使用模型进行预测
for pop in population_values:# 更新预测数据矩阵,包括当前的人口值predict_data = np.matrix([1, pop])# 进行预测predict_value = np.dot(predict_data, loaded_model.T)print(f"模型预测结果 {pop} : {predict_value[0,0]}")
实际测试下来,同迭代次数情况下随机梯度下降的收敛度远低于批量梯度下降:
初始的损失值: 32.072733877455676
优化后的损失值: 6.037742815925882 批量梯度下降为:4.47802760987997
模型预测结果 3.5 : -0.6151395665038226
模型预测结果 7.0 : 2.9916563373877203
模型预测结果 3.5 : -0.6151395665038226
模型预测结果 7.0 : 2.9916563373877203
即便是将迭代次数增加10倍也无法有效降低太多损失,15000次迭代的结果:
优化后的损失值: 5.620745223253086
个人总结:随机梯度下降估计只有针对超大规模的数据有应用意义。
注:本文为学习吴恩达版本机器学习教程的代码整理,使用的数据集为https://github.com/fengdu78/Coursera-ML-AndrewNg-Notes/blob/f2757f85b99a2b800f4c2e3e9ea967d9e17dfbd8/code/ex1-linear%20regression/ex1data1.txt
相关文章:
随机梯度下降的代码实现
在单变量线性回归的机器学习代码中,我们讨论了批量梯度下降代码的实现,本篇将进行随机梯度下降的代码实现,整体和批量梯度下降代码类似,仅梯度下降部分不同: import numpy as np import pandas as pd import matplotl…...
渐进推导中常用的一些结论
标题很帅 STAR-RIS Enhanced Joint Physical Layer Security and Covert Communications for Multi-antenna mmWave Systems文章末尾的一个推导。 lim M → ∞ ∥ Φ ( w k ⊗ Θ r ) Ω r w H g ∗ ∥ 2 2 M lim M → ∞ Tr ( g T Ω r w ( w k ⊗ Θ r ) H Φ H Φ…...
网络安全等级保护V2.0测评指标
网络安全等级保护(等保V2.0)测评指标: 1、物理和环境安全 2、网络和通信安全 3、设备和计算安全 4、应用和数据安全 5、安全策略和管理制度 6、安全管理机构和人员 7、安全建设管理 8、安全运维管理 软件全文档获取:点我获取 1、物…...
java中list的addAll用法详细实例?
List 的 addAll() 方法用于将一个集合中的所有元素添加到另一个 List 中。下面是一个详细的实例,展示了 addAll() 方法的使用: java Copy code import java.util.ArrayList; import java.util.List; public class AddAllExample { public static v…...
关于学习计算机的心得与体会
也是隔了一周没有发文了,最近一直在准备期末考试,后来想了很久,学了这么久的计算机,这当中有些收获和失去想和各位正在和我一样在学习计算机的路上的老铁分享一下,希望可以作为你们碰到困难时的良药。先叠个甲…...
LLM之RAG理论(一)| CoN:腾讯提出笔记链(CHAIN-OF-NOTE)来提高检索增强模型(RAG)的透明度
论文地址:https://arxiv.org/pdf/2311.09210.pdf 检索增强语言模型(RALM)已成为自然语言处理中一种强大的新范式。通过将大型预训练语言模型与外部知识检索相结合,RALM可以减少事实错误和幻觉,同时注入最新知识。然而&…...
Android studio:打开应用程序闪退的问题2.0
目录 找到问题分析问题解决办法 找到问题 老生常谈,可能这东西真的很常见吧,在之前那篇文章中 linkhttp://t.csdnimg.cn/UJQNb 已经谈到了关于打开Androidstuidio开发的软件后明明没有报错却无法运行(具体表现为应用程序闪退的问题ÿ…...
Spring IoC如何存取Bean对象
小王学习录 IoC(Inversion of Control)1. 什么是IoC2. 什么是Spring IoC3. 什么是DI4. Spring IoC的作用 存储Bean对象1. 创建Bean2. 将Bean注册到Spring中. 取Bean对象.1. 获取Spring上下文信息使用ApplicationContext和BeanFactory的区别 2. 获取指定Bean对象 IoC(Inversion …...
【开源】基于Vue.js的实验室耗材管理系统
文末获取源码,项目编号: S 081 。 \color{red}{文末获取源码,项目编号:S081。} 文末获取源码,项目编号:S081。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 耗材档案模块2.2 耗材入库模块2.3 耗…...
Datawhale聪明办法学Python(task2Getting Started)
一、课程基本结构 课程开源地址:课程简介 - 聪明办法学 Python 第二版 章节结构: Chapter 0 安装 InstallationChapter 1 启航 Getting StartedChapter 2 数据类型和操作 Data Types and OperatorsChapter 3 变量与函数 Variables and FunctionsChapte…...
量化交易怎么操作?量化软件怎么选择比较好?(散户福利,建议收藏)
一:量化的具体操作步骤是什么呢?1. 数据获取:索取和收集金融市场数据。 2. 策略制定:制定数量交易策略,这包括制定投资目标、建立交易规则和风险控制机制等,这个过程需要不断优化和更新。 3. 编写算法&am…...
什么是 AWS IAM?如何使用 IAM 数据库身份验证连接到 Amazon RDS(上)
驾驭云服务的安全环境可能很复杂,但 AWS IAM 为安全访问管理提供了强大的框架。在本文中,我们将探讨什么是 AWS Identity and Access Management (IAM) 以及它如何增强安全性。我们还将提供有关使用 IAM 连接到 Amazon Relational Database Service (RDS…...
Python从入门到精通七:Python函数进阶
函数多返回值 学习目标: 知道函数如何返回多个返回值 问: 如果一个函数如些两个return (如下所示),程序如何执行? 答:只执行了第一个return,原因是因为return可以退出当前函数,导致return下方的代码不执…...
uniapp踩坑之项目:使用过滤器将时间格式化为特定格式
利用filters过滤器对数据直接进行格式化,注意:与method、onLoad、data同层级 <template><div><!-- orderInfo.time的数据为:2023-12-12 12:10:23 --><p>{{ orderInfo.time | formatDate }}</p> <!-- 2023-1…...
webpack学习-2.管理资源
webpack学习-2.管理资源 1.这章要干嘛2.加载css注意顺序! 3.总结 1.这章要干嘛 管理资源,什么意思呢?管理什么资源?项目中经常会 导入各种各样的css文件,图片文件,字体文件,数据文件等等&#…...
658. 找到 K 个最接近的元素
658. 找到 K 个最接近的元素 Java代码:滑窗 class Solution {public List<Integer> findClosestElements(int[] arr, int k, int x) {List<Integer> list new ArrayList<>();for (int i 0; i < arr.length; i) {arr[i] arr[i] - x;}for(i…...
十二、MapReduce概述
1、MapReduce (1)采用框架 MapReduce是“分散——>汇总”模式的分布式计算框架,可供开发人员进行相应计算 (2)编程接口: ~Map ~Reduce 其中,Map功能接口提供了“分散”的功能ÿ…...
shell条件测试
目录 1.1.用途 1.2.基本语法 1.2.1.格式: 1.2.2.例 1.3 文件测试 1.4.整数测试 1.4.1.作用 1.4.2.操作符 1.4.3.示例: 1.5.逻辑操作符 1.5.1.符号 1.5.2.例: 1.6.命令分隔符 1.1.用途 为了能够正确处理Shell程序运行过程中遇到的各种情况&am…...
python在线读取传奇列表,并解析为需要的JSON格式
python在线读取传奇列表,并解析为需要的JSON格式,以下为传奇中使用的TXT列表格式, [Server] ; 使用“/”字符分开颜色,也可以不使用颜色,支持以前的旧格式,只有标题和服务器标题支持颜色 ; 标题/颜色代码(0-255)|服务器标题/颜色代码(0-255)|服务器名称|服务器IP|服务器端…...
【docker 】 安装docker(centOS7)
官网 docker官网 github源码 官网 在CentOS上安装Docker引擎 官网 在Debian上安装Docker引擎 官网 在 Fedora上安装Docker引擎 官网 在ubuntu上安装Docker引擎 官网 在RHEL (s390x)上安装Docker引擎 官网 在SLES上安装Docker引擎 最完善的资料都在官网。 卸载旧版本 …...
Python开发者三步完成Taotoken接入并调用多模型
🚀 告别海外账号与网络限制!稳定直连全球优质大模型,限时半价接入中。 👉 点击领取海量免费额度 Python开发者三步完成Taotoken接入并调用多模型 对于希望便捷使用多种大语言模型的Python开发者而言,通过一个统一的AP…...
别再为VMware里Kali上不了网发愁了!三种网络模式(桥接/NAT/仅主机)保姆级配置与排错指南
VMware中Kali Linux网络配置全攻略:从原理到实战排错 当你第一次在VMware中启动Kali Linux准备大展身手时,却发现连最基本的网络连接都无法建立——这种挫败感我深有体会。作为网络安全学习和渗透测试的必备工具,Kali在虚拟机中的网络配置往往…...
告别繁琐配置!用EB和S32DS快速搭建AutoSar MCAL基础工程(附完整文件结构解析)
从零构建AutoSar MCAL工程:EB与S32DS高效协作实战指南 当第一次打开AutoSar MCAL的官方示例工程时,多数工程师都会被密密麻麻的文件夹和配置文件淹没。Base、Platform、ECUC、MemIf等模块交织在一起,而EB生成的generate文件夹里又充斥着大量看…...
别再问客服了!手把手教你用VNC在AutoDL GPU服务器上跑起你的第一个GUI程序
云端GPU服务器VNC实战:从零部署GUI开发环境全指南 租用云GPU服务器进行深度学习训练已成为算法工程师的常态,但当代码涉及图形界面时,许多开发者会在cv2.imshow()或PyQt窗口弹出的环节卡壳。本文将基于AutoDL平台,详解如何通过Tur…...
保姆级教程:把Windows系统装进固态U盘,用云固件打造随身移动办公神器
随身Windows系统:用固态U盘打造移动办公终极解决方案 咖啡馆的午后阳光斜照在键盘上,你从包里掏出一个名片大小的设备,插入陌生电脑的USB接口。30秒后,熟悉的桌面环境、未写完的文档、收藏夹里的书签全部跃然屏上——这不是科幻场…...
RT-Thread开发者大会技术解析:从RTOS内核到AIoT平台实战指南
1. 项目概述:一场国产嵌入式技术的年度盛会 2021年的RT-Thread开发者大会,对于当时国内嵌入式软件圈的从业者来说,绝对是一个绕不开的关键节点。那一年,整个行业正处在一个微妙的转折期:一方面,芯片供应链…...
Fast-GitHub架构解析:基于Manifest V3的浏览器扩展网络加速方案
Fast-GitHub架构解析:基于Manifest V3的浏览器扩展网络加速方案 【免费下载链接】Fast-GitHub 国内Github下载很慢,用上了这个插件后,下载速度嗖嗖嗖的~! 项目地址: https://gitcode.com/gh_mirrors/fa/Fast-GitHub 技术架…...
数据分析篇---U型关系与与阈值效应
在数据科学、经济学和医学研究中,“U型关系”和“阈值效应”是两种非常经典且重要的非线性模式。它们描述的是变量之间并非简单的“越多越好”的直线关系,而是存在转折点。可以把线性关系想象成匀速开车,而U型和阈值效应则像是开车时遇到的上…...
从游戏到科研:手把手教你设计并运行一个n-back工作记忆测试
从游戏到科研:手把手教你设计并运行一个n-back工作记忆测试 工作记忆是人类认知功能的核心组成部分,它直接影响着我们的学习、推理和问题解决能力。在心理学和认知科学领域,n-back任务已经成为评估工作记忆容量的黄金标准之一。本文将带你从零…...
EasyWatermark代码架构详解:MVVM模式与依赖注入实践
EasyWatermark代码架构详解:MVVM模式与依赖注入实践 【免费下载链接】EasyWatermark 🔒 🖼 Securely, easily add a watermark to your sensitive photos. 安全、简单地为你的敏感照片添加水印,防止被人泄露、利用 项目地址: ht…...
