当前位置: 首页 > news >正文

Pytorch-SGD算法解析

关注B站可以观看更多实战教学视频:肆十二-的个人空间-肆十二-个人主页-哔哩哔哩视频 (bilibili.com)

SGD,即随机梯度下降(Stochastic Gradient Descent),是机器学习中用于优化目标函数的迭代方法,特别是在处理大数据集和在线学习场景中。与传统的批量梯度下降(Batch Gradient Descent)不同,SGD在每一步中仅使用一个样本来计算梯度并更新模型参数,这使得它在处理大规模数据集时更加高效。

SGD算法的基本步骤

  1. 初始化参数:选择初始参数值,可以是随机的或者基于一些先验知识。
  2. 随机选择样本:从数据集中随机选择一个样本。
  3. 计算梯度:计算损失函数关于当前参数的梯度。
  4. 更新参数:沿着负梯度方向更新参数。
  5. 重复:重复步骤2-4,直到满足停止条件(如达到预设的迭代次数或损失函数的改变小于某个阈值)。

SGD的Python代码示例:

python实现

假设我们要使用SGD来优化一个简单的线性回归模型。

import numpy as np  # 目标函数(损失函数)和其梯度  
def loss_function(w, b, x, y):  return np.sum((y - (w * x + b)) ** 2) / len(x)  def gradient_function(w, b, x, y):  dw = -2 * np.sum((y - (w * x + b)) * x) / len(x)  db = -2 * np.sum(y - (w * x + b)) / len(x)  return dw, db  # SGD算法  
def sgd(x, y, learning_rate=0.01, epochs=1000):  # 初始化参数  w = np.random.rand()  b = np.random.rand()  # 存储每次迭代的损失值,用于可视化  losses = []  for i in range(epochs):  # 随机选择一个样本(在这个示例中,我们没有实际进行随机选择,而是使用了整个数据集。在大数据集上,你应该随机选择一个样本或小批量样本。)  # 注意:为了简化示例,这里我们实际上使用的是批量梯度下降。在真正的SGD中,你应该在这里随机选择一个样本。  # 计算梯度  dw, db = gradient_function(w, b, x, y)  # 更新参数  w = w - learning_rate * dw  b = b - learning_rate * db  # 记录损失值  loss = loss_function(w, b, x, y)  losses.append(loss)  # 每隔一段时间打印损失值(可选)  if i % 100 == 0:  print(f"Epoch {i}, Loss: {loss}")  return w, b, losses  # 示例数据(你可以替换为自己的数据)  
x = np.array([1, 2, 3, 4, 5])  
y = np.array([2, 4, 6, 8, 10])  # 运行SGD算法  
w, b, losses = sgd(x, y)  
print(f"Optimized parameters: w = {w}, b = {b}")

解析

  • 在上面的代码中,我们首先定义了损失函数和它的梯度。对于线性回归,损失函数通常是均方误差。
  • sgd函数实现了SGD算法。它接受输入数据x和标签y,以及学习率和迭代次数作为参数。
  • 在每次迭代中,我们计算损失函数关于参数wb的梯度,并使用这些梯度来更新参数。
  • 我们还记录了每次迭代的损失值,以便稍后可视化算法的收敛情况。
  • 最后,我们打印出优化后的参数值。在实际应用中,你可能还需要使用这些参数来对新数据进行预测。

在PyTorch中,SGD(随机梯度下降)是一种基本的优化器,用于调整模型的参数以最小化损失函数。下面是torch.optim.SGD的参数解析和一个简单的用例。

SGD的Pytorch代码示例:

参数解析

torch.optim.SGD的主要参数如下:

  1. params (iterable):待优化的参数,或者是定义了参数的模型的迭代器。
  2. lr (float):学习率。这是更新参数的步长大小。较小的值会导致更新更精细,而较大的值可能会导致训练过程不稳定。这是SGD优化器的一个关键参数。
  3. momentum (float, optional):动量因子 (default: 0)。该参数加速了SGD在相关方向上的收敛,并抑制了震荡。
  4. dampening (float, optional):动量的抑制因子 (default: 0)。增加此值可以减少动量的影响。在实际应用中,这个参数的使用较少。
  5. weight_decay (float, optional):权重衰减 (L2 penalty) (default: 0)。通过向损失函数添加与权重向量平方成比例的惩罚项,来防止过拟合。
  6. nesterov (bool, optional):是否使用Nesterov动量 (default: False)。Nesterov动量是标准动量方法的一个变种,它在计算梯度时使用了未来的近似位置。

用例

下面是一个使用SGD优化器的简单例子:

import torch  
import torch.nn as nn  
import torch.optim as optim  # 定义一个简单的模型  
model = nn.Sequential(  nn.Linear(10, 5),  nn.ReLU(),  nn.Linear(5, 2),  
)  # 定义损失函数  
criterion = nn.CrossEntropyLoss()  # 定义优化器  
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.001)  # 假设有输入数据和目标  
input_data = torch.randn(1, 10)  
target = torch.tensor([1])  # 训练循环(这里只展示了一次迭代)  
for epoch in range(1):  # 通常会有多个 epochs  # 前向传播  output = model(input_data)  # 计算损失  loss = criterion(output, target)  # 反向传播  optimizer.zero_grad()  # 清除之前的梯度  loss.backward()  # 计算当前梯度  # 更新参数  optimizer.step()  # 应用梯度更新  # 打印损失  print(f'Epoch {epoch+1}, Loss: {loss.item()}')

在这个例子中,我们创建了一个简单的两层神经网络模型,并使用SGD优化器来更新模型的参数。在训练循环中,我们执行了前向传播来计算模型的输出,然后计算了损失,通过调用loss.backward()执行了反向传播来计算梯度,最后通过调用optimizer.step()更新了模型的参数。在每次迭代开始时,我们使用optimizer.zero_grad()来清除之前累积的梯度,这是非常重要的步骤,因为PyTorch默认会累积梯度。

相关文章:

Pytorch-SGD算法解析

关注B站可以观看更多实战教学视频:肆十二-的个人空间-肆十二-个人主页-哔哩哔哩视频 (bilibili.com) SGD,即随机梯度下降(Stochastic Gradient Descent),是机器学习中用于优化目标函数的迭代方法,特别是在处…...

物联网土壤传感器简介

物联网土壤传感器简介 物联网土壤传感器的工作原理基于多种物理、化学和生物原理,通过感应器等组成部件将土壤中的特征数据转化为电信号,从而进行采集、处理和输出。这些传感器主要包括土壤湿度传感器、土壤温度传感器、土壤酸碱度传感器和土壤颗粒物传…...

MySQL索引面试题(高频)

文章目录 前言什么时候需要(不需要))使用索引?有哪些优化索引的方法前缀索引优化索引覆盖优化索引失效场景 总结 前言 今天来讲一讲 MySQL 索引的高频面试题。主要是针对前一篇文章 MySQL索引入门(一文搞定)进行查漏补…...

SouthLeetCode-打卡24年02月第2周

SouthLeetCode-打卡24年02月第2周 // Date : 2024/02/05 ~ 2024/02/11 039.有效的字母异位词 (1) 题目描述 039#LeetCode.242.简单题目链接#Monday2024/02/05 给定两个字符串 *s* 和 *t* ,编写一个函数来判断 *t* 是否是 *s* 的字母异位词。 **注意&#xff1…...

Rust CallBack的几种写法

模拟常用的几种函数调用CallBack的写法。测试调用都放在函数t6_call_back_task中。我正在学习Rust&#xff0c;有不对或者欠缺的地方&#xff0c;欢迎交流指正 type Callback std::sync::Arc<dyn Fn() Send Sync>; type CallbackReturnVal std::sync::Arc<dyn Fn…...

Redis突现拒绝连接问题处理总结

一、问题回顾 项目突然报异常 [INFO] 2024-02-20 10:09:43.116 i.l.core.protocol.ConnectionWatchdog [171]: Reconnecting, last destination was 192.168.0.231:6379 [WARN] 2024-02-20 10:09:43.120 i.l.core.protocol.ConnectionWatchdog [151]: Cannot reconnect…...

css中选择器的优先级

CSS 的优先级是由选择器的特指度&#xff08;Specificity&#xff09;和重要性&#xff08;Importance&#xff09;决定的&#xff0c;以下是优先级规则&#xff1a; 特指度&#xff1a; ID 选择器 (#id): 每个ID选择器计为100。 类选择器 (.class)、属性选择器 ([attr]) 和伪…...

python3字符串内建方法split()心得

python3字符串内建方法split()心得 概念 用指定分隔符&#xff08;默认是任何空白字符&#xff09;将字符串拆分成列表。 语法 string.split(separator.max) 参数1.split(参数2&#xff0c;参数3) 参数1&#xff1a;string 字符串&#xff0c;需要被拆分的字符串。 参数2&a…...

html的列表标签

列表标签 列表在html里面经常会用到的&#xff0c;主要使用来布局的&#xff0c;使其整齐好看. 无序列表 无序列表[重要]&#xff1a; ul &#xff0c;li 示例代码1&#xff1a; 对应的效果&#xff1a; 无序列表的属性 属性值描述typedisc&#xff0c;square&#xff0c;…...

【Pytorch深度学习开发实践学习】B站刘二大人课程笔记整理lecture04反向传播

lecture04反向传播 课程网址 Pytorch深度学习实践 部分课件内容&#xff1a; import torchx_data [1.0,2.0,3.0] y_data [2.0,4.0,6.0] w torch.tensor([1.0]) w.requires_grad Truedef forward(x):return x*wdef loss(x,y):y_pred forward(x)return (y_pred-y)**2…...

PyTorch使用Tricks:学习率衰减 !!

文章目录 前言 1、指数衰减 2、固定步长衰减 3、多步长衰减 4、余弦退火衰减 5、自适应学习率衰减 6、自定义函数实现学习率调整&#xff1a;不同层不同的学习率 前言 在训练神经网络时&#xff0c;如果学习率过大&#xff0c;优化算法可能会在最优解附近震荡而无法收敛&#x…...

10MARL深度强化学习 Value Decomposition in Common-Reward Games

文章目录 前言1、价值分解的研究现状2、Individual-Global-Max Property3、Linear and Monotonic Value Decomposition3.1线性值分解3.2 单调值分解 前言 中心化价值函数能够缓解一些多智能体强化学习当中的问题&#xff0c;如非平稳性、局部可观测、信用分配与均衡选择等问题…...

2 Nacos适配达梦数据库实现方案

1、修改源代码方式 Nacos 原生是不支持达梦数据库的,所以就要想办法让它 “支持”,因为是开源软件,我们可以从源码入手,在流行的 1.x 、2.x 或最新版本代码的基本上进行修改。 主要涉及到以下内容的修改: com/alibaba/nacos/persistence/datasource/ExternalDataS...

【Gitea】配置 Push To Create

引 在 Git 代码管理工具使用过程中&#xff0c;经常需要将一个文件夹作为仓库上传到一个未创建的代码仓库。如果 Git 服务端使用的是 Gitea&#xff0c;通常会推送失败。 PS D:\tmp\git-test> git remote add origin http://192.1.1.1:3000/root/git-test.git PS D:\tmp\g…...

关于postgresql数据库单独设置某个用户日志级别(日志审计)

前言&#xff1a; 很多时候我们想让数据库日志打印详细一点&#xff0c;但是又担心会对数据库本身产生一些不可控的影响&#xff0c;还会担心数据库产生的庞大的日志导致主机资源不太够用的影响。那么今天我们就通过讲解给单个用户设置 log_statement来解决以上这些问题。 注…...

阿里云ECS香港服务器性能强大、cn2高速网络租用价格表

阿里云香港服务器中国香港数据中心网络线路类型BGP多线精品&#xff0c;中国电信CN2高速网络高质量、大规格BGP带宽&#xff0c;运营商精品公网直连中国内地&#xff0c;时延更低&#xff0c;优化海外回中国内地流量的公网线路&#xff0c;可以提高国际业务访问质量。阿里云服务…...

实战打靶集锦-025-HackInOS

文章目录 1. 主机发现2. 端口扫描3. 服务枚举4. 服务探查5. 提权5.1 枚举系统信息5.2 探索一下passwd5.3 枚举可执行文件5.4 查看capabilities位5.5 目录探索5.6 枚举定时任务5.7 Linpeas提权 靶机地址&#xff1a;https://download.vulnhub.com/hackinos/HackInOS.ova 1. 主机…...

list.stream().forEach()和list.forEach()的区别

list.stream().forEach() 和 list.forEach() 在 Java 中都是用于遍历集合元素的方法&#xff0c;但它们在使用场景和功能上有所不同&#xff1a; list.forEach()&#xff1a; 是从 Java 8 开始引入到 java.util.List 接口的标准方法。直接对列表进行迭代&#xff0c;它采用内部…...

JS基础之JSON对象

JS基础之JSON对象 目录 JS基础之JSON对象对象转JSON字符串JSON转JS对象 对象转JSON字符串 JSON.stringify(value,replacer,space) value:要转换的JS对象 replacer:(可选)用于过滤和转换结果的函数或数组 space:(可选)指定缩进量 // 创建JS对象 let date {name:"张三…...

嵌入式学习之Linux入门篇——使用VMware创建Unbuntu虚拟机

目录 主机硬件要求 VMware 安装 安装Unbuntu 18.04.6 LTS 新建虚拟机 进入Unbuntu安装环节 主机硬件要求 内存最少16G 硬盘最好分出一个单独的盘&#xff0c;而且最少预留200G&#xff0c;可以使用移动固态操作系统win7/10/11 VMware 安装 版本&#xff1a;VMware Works…...

大模型中的token是什么?

定义 大模型的"token"是指在自然语言处理&#xff08;NLP&#xff09;任务中&#xff0c;模型所使用的输入数据的最小单元。这些token可以是单词、子词或字符等&#xff0c;具体取决于模型的设计和训练方式。 大模型的token可以是单词级别的&#xff0c;也可以是子…...

跳表是一种什么样的数据结构

跳表是有序集合的底层数据结构&#xff0c;它其实是链表的一种进化体。正常链表是一个接着一个用指针连起来的&#xff0c;但这样查找效率低只有O(n)&#xff0c;为了解决这个问题&#xff0c;提出了跳表&#xff0c;实际上就是增加了高级索引。朴素的跳表指针是单向的并且元素…...

【刷题记录】最大公因数,最小公倍数(辗转相除法、欧几里得算法)

本系列博客为个人刷题思路分享&#xff0c;有需要借鉴即可。 1.题目链接&#xff1a; 无 2.详解思路&#xff1a; 题目描述&#xff1a;输入两个正整数&#xff0c;输出其最大公因数和最小公倍数 一般方法&#xff1a;最大公因数&#xff1a;穷加法&#xff1b;最小公倍数&…...

ETL快速拉取物流信息

我国作为世界第一的物流大国&#xff0c;但是在目前的物流信息系统还存在着几大的痛点。主要包括以下几个方面&#xff1a; 数据孤岛&#xff1a;有些物流企业各个部门之间的数据标准不一致&#xff0c;难以实现数据共享和协同&#xff0c;容易导致信息孤岛。 操作繁琐&#x…...

17.1 SpringMVC框架_SpringMVC入门与数据绑定(❤❤)

17.1 SpringMVC框架_SpringMVC入门与数据绑定 1. SpringMVC入门1.1 MVC介绍1.2 环境配置1. 依赖引入2. web配置文件:DispatchServlet配置3. applicationContext.xml配置4. 开发Controller控制器(❤❤)1.3 MVC处理流程图2. Spring MVC数据绑定2.1 URL Mapping2.2 URL Mapping三个…...

Leetcode 11.盛水最多的容器

题目 给定一个长度为 n 的整数数组 height 。有 n 条垂线&#xff0c;第 i 条线的两个端点是 (i, 0) 和 (i, height[i]) 。 找出其中的两条线&#xff0c;使得它们与 x 轴共同构成的容器可以容纳最多的水。 返回容器可以储存的最大水量。 说明&#xff1a;你不能倾斜容器。…...

《Go 简易速速上手小册》第7章:包管理与模块(2024 最新版)

文章目录 7.1 使用 Go Modules 管理依赖 - 掌舵向未来7.1.1 基础知识讲解7.1.2 重点案例:Web 服务功能描述实现步骤扩展功能7.1.3 拓展案例 1:使用数据库功能描述实现步骤扩展功能7.1.4 拓展案例 2:集成 Redis 缓存功能描述实现步骤...

【论文精读】IBOT

摘要 掩码语言建模(MLM)是一种流行的语言模型预训练范式&#xff0c;在nlp领域取得了巨大的成功。然而&#xff0c;它对视觉Transformer (ViT)的潜力尚未得到充分开发。为在视觉领域延续MLM的成功&#xff0c;故而探索掩码图像建模(MIM)&#xff0c;以训练更好的视觉transforme…...

Yolo V5在实时视频流中的建筑物与彩钢房检测:性能评估与改进方法

Yolo V5在实时视频流中的建筑物与彩钢房检测&#xff1a;性能评估与改进方法 文章目录 Yolo V5在实时视频流中的建筑物与彩钢房检测&#xff1a;性能评估与改进方法概述Yolo V5模型概述建筑物与彩钢房检测的挑战实时视频流处理流程模型性能评估改进方法实验与分析结论与展望 概…...

图——最小生成树实现(Kruskal算法,prime算法)

目录 预备知识&#xff1a; 最小生成树概念&#xff1a; Kruskal算法&#xff1a; 代码实现如下&#xff1a; 测试&#xff1a; Prime算法 &#xff1a; 代码实现如下&#xff1a; 测试&#xff1a; 结语&#xff1a; 预备知识&#xff1a; 连通图&#xff1a;在无向图…...