Pytorch-SGD算法解析
关注B站可以观看更多实战教学视频:肆十二-的个人空间-肆十二-个人主页-哔哩哔哩视频 (bilibili.com)
SGD,即随机梯度下降(Stochastic Gradient Descent),是机器学习中用于优化目标函数的迭代方法,特别是在处理大数据集和在线学习场景中。与传统的批量梯度下降(Batch Gradient Descent)不同,SGD在每一步中仅使用一个样本来计算梯度并更新模型参数,这使得它在处理大规模数据集时更加高效。
SGD算法的基本步骤
- 初始化参数:选择初始参数值,可以是随机的或者基于一些先验知识。
- 随机选择样本:从数据集中随机选择一个样本。
- 计算梯度:计算损失函数关于当前参数的梯度。
- 更新参数:沿着负梯度方向更新参数。
- 重复:重复步骤2-4,直到满足停止条件(如达到预设的迭代次数或损失函数的改变小于某个阈值)。
SGD的Python代码示例:
python实现
假设我们要使用SGD来优化一个简单的线性回归模型。
import numpy as np # 目标函数(损失函数)和其梯度
def loss_function(w, b, x, y): return np.sum((y - (w * x + b)) ** 2) / len(x) def gradient_function(w, b, x, y): dw = -2 * np.sum((y - (w * x + b)) * x) / len(x) db = -2 * np.sum(y - (w * x + b)) / len(x) return dw, db # SGD算法
def sgd(x, y, learning_rate=0.01, epochs=1000): # 初始化参数 w = np.random.rand() b = np.random.rand() # 存储每次迭代的损失值,用于可视化 losses = [] for i in range(epochs): # 随机选择一个样本(在这个示例中,我们没有实际进行随机选择,而是使用了整个数据集。在大数据集上,你应该随机选择一个样本或小批量样本。) # 注意:为了简化示例,这里我们实际上使用的是批量梯度下降。在真正的SGD中,你应该在这里随机选择一个样本。 # 计算梯度 dw, db = gradient_function(w, b, x, y) # 更新参数 w = w - learning_rate * dw b = b - learning_rate * db # 记录损失值 loss = loss_function(w, b, x, y) losses.append(loss) # 每隔一段时间打印损失值(可选) if i % 100 == 0: print(f"Epoch {i}, Loss: {loss}") return w, b, losses # 示例数据(你可以替换为自己的数据)
x = np.array([1, 2, 3, 4, 5])
y = np.array([2, 4, 6, 8, 10]) # 运行SGD算法
w, b, losses = sgd(x, y)
print(f"Optimized parameters: w = {w}, b = {b}")
解析
- 在上面的代码中,我们首先定义了损失函数和它的梯度。对于线性回归,损失函数通常是均方误差。
sgd函数实现了SGD算法。它接受输入数据x和标签y,以及学习率和迭代次数作为参数。- 在每次迭代中,我们计算损失函数关于参数
w和b的梯度,并使用这些梯度来更新参数。 - 我们还记录了每次迭代的损失值,以便稍后可视化算法的收敛情况。
- 最后,我们打印出优化后的参数值。在实际应用中,你可能还需要使用这些参数来对新数据进行预测。
在PyTorch中,SGD(随机梯度下降)是一种基本的优化器,用于调整模型的参数以最小化损失函数。下面是torch.optim.SGD的参数解析和一个简单的用例。
SGD的Pytorch代码示例:
参数解析
torch.optim.SGD的主要参数如下:
- params (iterable):待优化的参数,或者是定义了参数的模型的迭代器。
- lr (float):学习率。这是更新参数的步长大小。较小的值会导致更新更精细,而较大的值可能会导致训练过程不稳定。这是SGD优化器的一个关键参数。
- momentum (float, optional):动量因子 (default: 0)。该参数加速了SGD在相关方向上的收敛,并抑制了震荡。
- dampening (float, optional):动量的抑制因子 (default: 0)。增加此值可以减少动量的影响。在实际应用中,这个参数的使用较少。
- weight_decay (float, optional):权重衰减 (L2 penalty) (default: 0)。通过向损失函数添加与权重向量平方成比例的惩罚项,来防止过拟合。
- nesterov (bool, optional):是否使用Nesterov动量 (default: False)。Nesterov动量是标准动量方法的一个变种,它在计算梯度时使用了未来的近似位置。
用例
下面是一个使用SGD优化器的简单例子:
import torch
import torch.nn as nn
import torch.optim as optim # 定义一个简单的模型
model = nn.Sequential( nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2),
) # 定义损失函数
criterion = nn.CrossEntropyLoss() # 定义优化器
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.001) # 假设有输入数据和目标
input_data = torch.randn(1, 10)
target = torch.tensor([1]) # 训练循环(这里只展示了一次迭代)
for epoch in range(1): # 通常会有多个 epochs # 前向传播 output = model(input_data) # 计算损失 loss = criterion(output, target) # 反向传播 optimizer.zero_grad() # 清除之前的梯度 loss.backward() # 计算当前梯度 # 更新参数 optimizer.step() # 应用梯度更新 # 打印损失 print(f'Epoch {epoch+1}, Loss: {loss.item()}')
在这个例子中,我们创建了一个简单的两层神经网络模型,并使用SGD优化器来更新模型的参数。在训练循环中,我们执行了前向传播来计算模型的输出,然后计算了损失,通过调用loss.backward()执行了反向传播来计算梯度,最后通过调用optimizer.step()更新了模型的参数。在每次迭代开始时,我们使用optimizer.zero_grad()来清除之前累积的梯度,这是非常重要的步骤,因为PyTorch默认会累积梯度。
相关文章:
Pytorch-SGD算法解析
关注B站可以观看更多实战教学视频:肆十二-的个人空间-肆十二-个人主页-哔哩哔哩视频 (bilibili.com) SGD,即随机梯度下降(Stochastic Gradient Descent),是机器学习中用于优化目标函数的迭代方法,特别是在处…...
物联网土壤传感器简介
物联网土壤传感器简介 物联网土壤传感器的工作原理基于多种物理、化学和生物原理,通过感应器等组成部件将土壤中的特征数据转化为电信号,从而进行采集、处理和输出。这些传感器主要包括土壤湿度传感器、土壤温度传感器、土壤酸碱度传感器和土壤颗粒物传…...
MySQL索引面试题(高频)
文章目录 前言什么时候需要(不需要))使用索引?有哪些优化索引的方法前缀索引优化索引覆盖优化索引失效场景 总结 前言 今天来讲一讲 MySQL 索引的高频面试题。主要是针对前一篇文章 MySQL索引入门(一文搞定)进行查漏补…...
SouthLeetCode-打卡24年02月第2周
SouthLeetCode-打卡24年02月第2周 // Date : 2024/02/05 ~ 2024/02/11 039.有效的字母异位词 (1) 题目描述 039#LeetCode.242.简单题目链接#Monday2024/02/05 给定两个字符串 *s* 和 *t* ,编写一个函数来判断 *t* 是否是 *s* 的字母异位词。 **注意࿱…...
Rust CallBack的几种写法
模拟常用的几种函数调用CallBack的写法。测试调用都放在函数t6_call_back_task中。我正在学习Rust,有不对或者欠缺的地方,欢迎交流指正 type Callback std::sync::Arc<dyn Fn() Send Sync>; type CallbackReturnVal std::sync::Arc<dyn Fn…...
Redis突现拒绝连接问题处理总结
一、问题回顾 项目突然报异常 [INFO] 2024-02-20 10:09:43.116 i.l.core.protocol.ConnectionWatchdog [171]: Reconnecting, last destination was 192.168.0.231:6379 [WARN] 2024-02-20 10:09:43.120 i.l.core.protocol.ConnectionWatchdog [151]: Cannot reconnect…...
css中选择器的优先级
CSS 的优先级是由选择器的特指度(Specificity)和重要性(Importance)决定的,以下是优先级规则: 特指度: ID 选择器 (#id): 每个ID选择器计为100。 类选择器 (.class)、属性选择器 ([attr]) 和伪…...
python3字符串内建方法split()心得
python3字符串内建方法split()心得 概念 用指定分隔符(默认是任何空白字符)将字符串拆分成列表。 语法 string.split(separator.max) 参数1.split(参数2,参数3) 参数1:string 字符串,需要被拆分的字符串。 参数2&a…...
html的列表标签
列表标签 列表在html里面经常会用到的,主要使用来布局的,使其整齐好看. 无序列表 无序列表[重要]: ul ,li 示例代码1: 对应的效果: 无序列表的属性 属性值描述typedisc,square,…...
【Pytorch深度学习开发实践学习】B站刘二大人课程笔记整理lecture04反向传播
lecture04反向传播 课程网址 Pytorch深度学习实践 部分课件内容: import torchx_data [1.0,2.0,3.0] y_data [2.0,4.0,6.0] w torch.tensor([1.0]) w.requires_grad Truedef forward(x):return x*wdef loss(x,y):y_pred forward(x)return (y_pred-y)**2…...
PyTorch使用Tricks:学习率衰减 !!
文章目录 前言 1、指数衰减 2、固定步长衰减 3、多步长衰减 4、余弦退火衰减 5、自适应学习率衰减 6、自定义函数实现学习率调整:不同层不同的学习率 前言 在训练神经网络时,如果学习率过大,优化算法可能会在最优解附近震荡而无法收敛&#x…...
10MARL深度强化学习 Value Decomposition in Common-Reward Games
文章目录 前言1、价值分解的研究现状2、Individual-Global-Max Property3、Linear and Monotonic Value Decomposition3.1线性值分解3.2 单调值分解 前言 中心化价值函数能够缓解一些多智能体强化学习当中的问题,如非平稳性、局部可观测、信用分配与均衡选择等问题…...
2 Nacos适配达梦数据库实现方案
1、修改源代码方式 Nacos 原生是不支持达梦数据库的,所以就要想办法让它 “支持”,因为是开源软件,我们可以从源码入手,在流行的 1.x 、2.x 或最新版本代码的基本上进行修改。 主要涉及到以下内容的修改: com/alibaba/nacos/persistence/datasource/ExternalDataS...
【Gitea】配置 Push To Create
引 在 Git 代码管理工具使用过程中,经常需要将一个文件夹作为仓库上传到一个未创建的代码仓库。如果 Git 服务端使用的是 Gitea,通常会推送失败。 PS D:\tmp\git-test> git remote add origin http://192.1.1.1:3000/root/git-test.git PS D:\tmp\g…...
关于postgresql数据库单独设置某个用户日志级别(日志审计)
前言: 很多时候我们想让数据库日志打印详细一点,但是又担心会对数据库本身产生一些不可控的影响,还会担心数据库产生的庞大的日志导致主机资源不太够用的影响。那么今天我们就通过讲解给单个用户设置 log_statement来解决以上这些问题。 注…...
阿里云ECS香港服务器性能强大、cn2高速网络租用价格表
阿里云香港服务器中国香港数据中心网络线路类型BGP多线精品,中国电信CN2高速网络高质量、大规格BGP带宽,运营商精品公网直连中国内地,时延更低,优化海外回中国内地流量的公网线路,可以提高国际业务访问质量。阿里云服务…...
实战打靶集锦-025-HackInOS
文章目录 1. 主机发现2. 端口扫描3. 服务枚举4. 服务探查5. 提权5.1 枚举系统信息5.2 探索一下passwd5.3 枚举可执行文件5.4 查看capabilities位5.5 目录探索5.6 枚举定时任务5.7 Linpeas提权 靶机地址:https://download.vulnhub.com/hackinos/HackInOS.ova 1. 主机…...
list.stream().forEach()和list.forEach()的区别
list.stream().forEach() 和 list.forEach() 在 Java 中都是用于遍历集合元素的方法,但它们在使用场景和功能上有所不同: list.forEach(): 是从 Java 8 开始引入到 java.util.List 接口的标准方法。直接对列表进行迭代,它采用内部…...
JS基础之JSON对象
JS基础之JSON对象 目录 JS基础之JSON对象对象转JSON字符串JSON转JS对象 对象转JSON字符串 JSON.stringify(value,replacer,space) value:要转换的JS对象 replacer:(可选)用于过滤和转换结果的函数或数组 space:(可选)指定缩进量 // 创建JS对象 let date {name:"张三…...
嵌入式学习之Linux入门篇——使用VMware创建Unbuntu虚拟机
目录 主机硬件要求 VMware 安装 安装Unbuntu 18.04.6 LTS 新建虚拟机 进入Unbuntu安装环节 主机硬件要求 内存最少16G 硬盘最好分出一个单独的盘,而且最少预留200G,可以使用移动固态操作系统win7/10/11 VMware 安装 版本:VMware Works…...
观成科技:隐蔽隧道工具Ligolo-ng加密流量分析
1.工具介绍 Ligolo-ng是一款由go编写的高效隧道工具,该工具基于TUN接口实现其功能,利用反向TCP/TLS连接建立一条隐蔽的通信信道,支持使用Let’s Encrypt自动生成证书。Ligolo-ng的通信隐蔽性体现在其支持多种连接方式,适应复杂网…...
测试微信模版消息推送
进入“开发接口管理”--“公众平台测试账号”,无需申请公众账号、可在测试账号中体验并测试微信公众平台所有高级接口。 获取access_token: 自定义模版消息: 关注测试号:扫二维码关注测试号。 发送模版消息: import requests da…...
Docker 离线安装指南
参考文章 1、确认操作系统类型及内核版本 Docker依赖于Linux内核的一些特性,不同版本的Docker对内核版本有不同要求。例如,Docker 17.06及之后的版本通常需要Linux内核3.10及以上版本,Docker17.09及更高版本对应Linux内核4.9.x及更高版本。…...
UDP(Echoserver)
网络命令 Ping 命令 检测网络是否连通 使用方法: ping -c 次数 网址ping -c 3 www.baidu.comnetstat 命令 netstat 是一个用来查看网络状态的重要工具. 语法:netstat [选项] 功能:查看网络状态 常用选项: n 拒绝显示别名&#…...
大数据零基础学习day1之环境准备和大数据初步理解
学习大数据会使用到多台Linux服务器。 一、环境准备 1、VMware 基于VMware构建Linux虚拟机 是大数据从业者或者IT从业者的必备技能之一也是成本低廉的方案 所以VMware虚拟机方案是必须要学习的。 (1)设置网关 打开VMware虚拟机,点击编辑…...
家政维修平台实战20:权限设计
目录 1 获取工人信息2 搭建工人入口3 权限判断总结 目前我们已经搭建好了基础的用户体系,主要是分成几个表,用户表我们是记录用户的基础信息,包括手机、昵称、头像。而工人和员工各有各的表。那么就有一个问题,不同的角色…...
oracle与MySQL数据库之间数据同步的技术要点
Oracle与MySQL数据库之间的数据同步是一个涉及多个技术要点的复杂任务。由于Oracle和MySQL的架构差异,它们的数据同步要求既要保持数据的准确性和一致性,又要处理好性能问题。以下是一些主要的技术要点: 数据结构差异 数据类型差异ÿ…...
屋顶变身“发电站” ,中天合创屋面分布式光伏发电项目顺利并网!
5月28日,中天合创屋面分布式光伏发电项目顺利并网发电,该项目位于内蒙古自治区鄂尔多斯市乌审旗,项目利用中天合创聚乙烯、聚丙烯仓库屋面作为场地建设光伏电站,总装机容量为9.96MWp。 项目投运后,每年可节约标煤3670…...
Cinnamon修改面板小工具图标
Cinnamon开始菜单-CSDN博客 设置模块都是做好的,比GNOME简单得多! 在 applet.js 里增加 const Settings imports.ui.settings;this.settings new Settings.AppletSettings(this, HTYMenusonichy, instance_id); this.settings.bind(menu-icon, menu…...
JavaScript基础-API 和 Web API
在学习JavaScript的过程中,理解API(应用程序接口)和Web API的概念及其应用是非常重要的。这些工具极大地扩展了JavaScript的功能,使得开发者能够创建出功能丰富、交互性强的Web应用程序。本文将深入探讨JavaScript中的API与Web AP…...
