当前位置: 首页 > news >正文

【深度学习】pytorch——线性回归

笔记为自我总结整理的学习笔记,若有错误欢迎指出哟~

深度学习专栏链接:
http://t.csdnimg.cn/dscW7

pytorch——线性回归

  • 线性回归简介
  • 公式说明
  • 完整代码
  • 代码解释

线性回归简介

线性回归是一种用于建立特征和目标变量之间线性关系的统计学习方法。它假设特征和目标变量之间存在一个线性的关系,并试图通过拟合最佳的线性函数来预测目标变量。

线性回归模型的一般形式可以表示为:

y = w 0 + w 1 x 1 + w 2 x 2 + … + w n x n y = w_0 + w_1x_1 + w_2x_2 + \ldots + w_nx_n y=w0+w1x1+w2x2++wnxn

其中, y y y 是目标变量(或因变量), x 1 , x 2 , … , x n x_1, x_2, \ldots, x_n x1,x2,,xn 是特征变量(或自变量), w 0 , w 1 , w 2 , … , w n w_0, w_1, w_2, \ldots, w_n w0,w1,w2,,wn 是模型的参数,分别对应截距和各个特征的权重。

线性回归模型的训练过程就是寻找最优的参数 w 0 , w 1 , w 2 , … , w n w_0, w_1, w_2, \ldots, w_n w0,w1,w2,,wn 来使得模型的预测值与实际值之间的差异最小化。

公式说明

以下是代码涉及到的数学公式

  1. 线性回归模型

线性回归模型用于建立特征 x x x 和目标变量 y y y 之间的线性关系。在本代码中,线性回归模型被表示为:

y = w x + b y = wx + b y=wx+b

其中, w w w 是权重(即斜率), b b b 是偏置(即截距), x x x 是输入特征, y y y 是预测值。

  1. 损失函数

损失函数用于衡量模型预测值与实际标签之间的差异。在本代码中,使用的损失函数是均方误差(Mean Squared Error,MSE):

l o s s = 1 2 n ∑ i = 1 n ( y p r e d ( i ) − y ( i ) ) 2 loss = \frac{1}{2n} \sum_{i=1}^{n} (y_{pred}^{(i)} - y^{(i)})^2 loss=2n1i=1n(ypred(i)y(i))2

其中, y p r e d ( i ) y_{pred}^{(i)} ypred(i) 是模型的第 i i i 个样本的预测值, y ( i ) y^{(i)} y(i) 是实际标签, n n n 是样本数量。

  1. 其他运算

代码中还涉及到了矩阵乘法、矩阵转置、元素级别的操作等。例如, x . m m ( w ) x.mm(w) x.mm(w) 表示将输入特征 x x x 与权重 w w w 进行矩阵乘法; x T . m m ( d y _ p r e d ) x^T.mm(dy\_pred) xT.mm(dy_pred) 表示将输入特征 x x x 的转置与梯度 d y _ p r e d dy\_pred dy_pred 进行矩阵乘法。

完整代码

import torch as t
%matplotlib inline
from matplotlib import pyplot as plt
from IPython import displaydevice = t.device('cpu') #如果你想用gpu,改成t.device('cuda:0')# 设置随机数种子,保证在不同电脑上运行时下面的输出一致
t.manual_seed(1000) def get_fake_data(batch_size=8):''' 产生随机数据:y=x*2+3,加上了一些噪声'''x = t.rand(batch_size, 1, device=device) * 5y = x * 2 + 3 +  t.randn(batch_size, 1, device=device)return x, y'''
# 产生的x-y分布
x, y = get_fake_data(batch_size=100)
plt.scatter(x.squeeze().cpu().numpy(), y.squeeze().cpu().numpy())
'''# 随机初始化参数
w = t.rand(1, 1).to(device)
b = t.zeros(1, 1).to(device)lr =0.02 # 学习率for ii in range(500):x, y = get_fake_data(batch_size=4)# forward:计算lossy_pred = x.mm(w) + b.expand_as(y) loss = 0.5 * (y_pred - y) ** 2 # 均方误差loss = loss.mean()# backward:手动计算梯度dloss = 1dy_pred = dloss * (y_pred - y)dw = x.t().mm(dy_pred)db = dy_pred.sum()# 更新参数w.sub_(lr * dw)b.sub_(lr * db)if ii%50 ==0:# 画图display.clear_output(wait=True)x = t.arange(0, 6).view(-1, 1)y = x.float().mm(w) + b.expand_as(x)plt.plot(x.cpu().numpy(), y.cpu().numpy(),color='b') # predictedx2, y2 = get_fake_data(batch_size=100) plt.scatter(x2.numpy(), y2.numpy(),color='r') # true dataplt.xlim(0, 5)plt.ylim(0, 15)plt.show()plt.pause(0.5)print('w: ', w.item(), 'b: ', b.item())

输出结果为:
在这里插入图片描述
w: 1.9709817171096802 b: 3.1699466705322266

代码解释

  1. 导入需要的库:
import torch as t
%matplotlib inline
from matplotlib import pyplot as plt
from IPython import display

导入PyTorch库以及绘图相关的库,%matplotlib inline是Jupyter Notebook中的魔法命令,用于在Notebook中显示绘图。

  1. 设置随机数种子:
t.manual_seed(1000)

这行代码设置随机数种子,保证每次运行结果的随机数生成过程一致。

  1. 定义生成随机数据的函数:
def get_fake_data(batch_size=8):''' 产生随机数据:y=x*2+3,加上了一些噪声'''x = t.rand(batch_size, 1, device=device) * 5y = x * 2 + 3 +  t.randn(batch_size, 1, device=device)return x, y

该函数用于产生随机的输入特征x和对应的标签y,其中y满足线性关系y = x * 2 + 3,并添加了一些随机噪声。

  1. 初始化模型参数:
w = t.rand(1, 1).to(device)
b = t.zeros(1, 1).to(device)

这里使用随机数初始化模型参数wb,并指定在CPU上进行计算。

  1. 设置学习率:
lr = 0.02

学习率lr控制每次参数更新的步长。

  1. 进行模型训练:
for ii in range(500):# 生成随机数据x, y = get_fake_data(batch_size=4)# forward:计算损失y_pred = x.mm(w) + b.expand_as(y)loss = 0.5 * (y_pred - y) ** 2loss = loss.mean()# backward:手动计算梯度dloss = 1dy_pred = dloss * (y_pred - y)dw = x.t().mm(dy_pred)db = dy_pred.sum()# 更新参数w.sub_(lr * dw)b.sub_(lr * db)

这里使用一个循环进行模型的训练,每次迭代都包含以下步骤:

  • 生成随机数据;
  • 前向传播:计算预测值y_pred和损失函数loss
  • 反向传播:手动计算梯度dwdb
  • 更新参数:根据梯度和学习率更新参数wb
  1. 可视化模型训练过程:
if ii % 50 == 0:display.clear_output(wait=True)x = t.arange(0, 6).view(-1, 1)y = x.float().mm(w) + b.expand_as(x)plt.plot(x.cpu().numpy(), y.cpu().numpy(), color='b') # predicted linex2, y2 = get_fake_data(batch_size=100)plt.scatter(x2.numpy(), y2.numpy(), color='r') # true dataplt.xlim(0, 5)plt.ylim(0, 15)plt.show()plt.pause(0.5)

这部分代码用于可视化模型训练的过程,每50次迭代将当前参数下的预测结果以蓝色线条的形式绘制出来,并将随机生成的100个样本以红色散点图显示出来。

  1. 输出最终训练得到的参数:
print('w: ', w.item(), 'b: ', b.item())

输出训练得到的参数wb的值。

相关文章:

【深度学习】pytorch——线性回归

笔记为自我总结整理的学习笔记,若有错误欢迎指出哟~ 深度学习专栏链接: http://t.csdnimg.cn/dscW7 pytorch——线性回归 线性回归简介公式说明完整代码代码解释 线性回归简介 线性回归是一种用于建立特征和目标变量之间线性关系的统计学习方法。它假设…...

golang工程——中间件redis,单节点集群部署

单节点redis集群部署 部署redis 6.2.7版本 没资源,就用一台机子部 解压安装包 tar zxf redis-6.2.7.tar.gzcd redis-6.2.7编译安装 mkdir -p /var/local/redis-6.2.7/{data,conf,logs,pid}data:数据目录 conf:配置文件目录 logs&#xf…...

Lua基础

table 基本原理: table是一种特殊的容器,可以向数组一样按照索引存取,也能按照键值对存取。 local mytable {1,2,3} --相当于数组 local mytable {[1]1,[2]2,[3]3} --和上面等价 local mytable {1,2,3,[3] 4} --隐式赋值会覆盖掉显式赋…...

微信小程序之开发工具介绍

一、微信小程序开发工具下载 微信小程序开发工具下载可以参考这篇博客《微信小程序开发者工具下载-CSDN博客》 二、开发工具组成部分 如下图所示,开发者工具主要由菜单栏、工具栏、模拟器、编辑器和调试器 5 个部分组成。。 1、菜单栏 菜单栏中主要包括项目、文…...

【AUTOSAR】【以太网】DoIp

AUTOSAR专栏——总目录_嵌入式知行合一的博客-CSDN博客文章浏览阅读217次。本文主要汇总该专栏文章,以方便各位读者阅读。https://xianfan.blog.csdn.net/article/details/132072415 目录 一、概述 二、功能描述 2.1 Do...

游戏中UI的性能优化手段

UI方面有许多性能优化的技术或手段,以下是其中一些常见的例子: 惰性加载:对于长列表、大图等需要加载大量数据和资源的组件,可以采用惰性加载的方式,即在用户需要时再进行加载。这样可以减少初始加载时间和内存占用&am…...

Idea快速生成测试类

例如写写完一个功能类,需要对里面方法进行测试 在当前页面 按住CTRLSHFITT 选择你要生成的测试方法 点击OK,就会在test目录下在你对应包下生成对应测试类...

Java文件操作详解

CONTENTS 1. 文件和目录路径1.1 获取Path的片段1.2 获取Path信息1.3 添加或删除路径片段 2. 文件系统3. 查找文件4. 读写文件 1. 文件和目录路径 Path 对象代表的是一个文件或目录的路径,它是在不同的操作系统和文件系统之上的抽象。它的目的是,在构建路…...

二叉树系列主题Code

Python实现二叉树遍历 # 定义二叉树节点类 class TreeNode: def __init__(self, val0, leftNone, rightNone): self.val val self.left left self.right right # 前序遍历(非递归) def preorderTraversal(root): if not root: return [] …...

Leetcode 673. 最长递增子序列的个数 C++

673最长递增子序列的个数 给定一个未排序的整数数组 nums , 返回最长递增子序列的个数 。 注意 这个数列必须是 严格 递增的。 示例 1: 输入: [1,3,5,4,7] 输出: 2 解释: 有两个最长递增子序列,分别是 [1, 3, 4, 7] 和[1, 3, 5, 7]。 示例 2: 输入: …...

html用css grid实现自适应四宫格放视频

想同时播放四个本地视频: 四宫格;自式应,即放缩浏览器时,四宫格也跟着放缩;尽量填满页面(F11 浏览器全屏时可以填满整个屏幕)。 在 html 中放视频用 video 标签,参考 [1]&#xff1…...

【机器学习可解释性】5.SHAP值的高级使用

机器学习可解释性 1.模型洞察的价值2.特征重要性排列3.部分依赖图4.SHAP 值5.SHAP值的高级使用 正文 汇总SHAP值以获得更详细的模型解释 总体回顾 我们从学习排列重要性和部分依赖图开始,以显示学习后的模型的内容。 然后我们学习了SHAP值来分解单个预测的组成部…...

CentOS开机自动运行jar程序实现

前面已经有一篇文章介绍jar包如何在CentOS上运行,《在linux上运行jar程序操作记录》 后来发现系统重启后不能自动运行,导致每次都要手动打开,这篇介绍如何自动开机启动运行jar程序。 一、找到JDK程序执行位置 [rootlocalhost /]# which jav…...

matlab双目标定中基线物理长度获取

在MATLAB进行双目摄像机标定时,通常会获得相机的内参,其中包括像素单位的焦距(focal length)以及物理单位的基线长度(baseline)。对于应用中的深度估计和测量,基线长度的物理单位非常重要,因为它直接影响到深度信息的准确性。有时候,您可能只能获取像素单位的焦距和棋…...

自己动手实现一个深度学习算法——二、神经网络的实现

文章目录 1. 神经网络概述1)表示2)激活函数3)sigmoid函数4)阶跃函数的实现5)sigmoid函数的实现6)sigmoid函数和阶跃函数的比较7)非线性函数8)ReLU函数 2.三层神经网络的实现1)结构2&…...

gRPC源码剖析-Builder模式

一、Builder模式 1、定义 将一个复杂对象的构建与表示分离,使得同样的构建过程可以创建不同的的表示。 2、适用场景 当创建复杂对象的算法应独立于该对象的组成部分以及它们的装配方式时。 当构造过程必须允许被构造的对象有不同的表示时。 说人话&#xff1a…...

ARM传输数据以及移位操作

3.2.2 数据传送指令 LDR/STR指令用来在寄存器和内存之间输送数据。如果我们想要在寄存器之间传送数据,则可以使用MOV指令。MOV指令的格式如下。 MOV {cond} {s}Rd, oprand2 MOV {cond} {s}Rd, oprand2 其中,{cond}为条件指令可选项,{s}用来表…...

06、如何将对象数组里 obj 的 key 值变成动态的(即:每一个对象对应的 key 值都不同)

1、数据情况: 其一、从后端拿到的数据为: let arr [1,3,6,10,11,23,24] 其二、目标数据为: [{vlan_1: 1, value: 1}, {vlan_3: 3, value: 1}, {vlan_6: 6, value: 1}, {vlan_10: 10, value: 1}, {vlan_11: 11, value: 1}, {vlan_23: 23, v…...

ngx_http_request_s

/* 罗剑锋老师的注释参考: https://github.com/chronolaw/annotated_nginx/blob/master/nginx/src/http/ngx_http_request.h */struct ngx_http_request_s {uint32_t signature; /* "HTTP" */ngx_connection_t …...

Docker 学习路线 2:底层技术

了解驱动Docker的核心技术将让您更深入地了解Docker的工作原理,并有助于您更有效地使用该平台。 Linux容器(LXC) Linux容器(LXC)是Docker的基础。 LXC是一种轻量级的虚拟化解决方案,允许多个隔离的Linux系…...

利用最小二乘法找圆心和半径

#include <iostream> #include <vector> #include <cmath> #include <Eigen/Dense> // 需安装Eigen库用于矩阵运算 // 定义点结构 struct Point { double x, y; Point(double x_, double y_) : x(x_), y(y_) {} }; // 最小二乘法求圆心和半径 …...

微软PowerBI考试 PL300-选择 Power BI 模型框架【附练习数据】

微软PowerBI考试 PL300-选择 Power BI 模型框架 20 多年来&#xff0c;Microsoft 持续对企业商业智能 (BI) 进行大量投资。 Azure Analysis Services (AAS) 和 SQL Server Analysis Services (SSAS) 基于无数企业使用的成熟的 BI 数据建模技术。 同样的技术也是 Power BI 数据…...

前端导出带有合并单元格的列表

// 导出async function exportExcel(fileName "共识调整.xlsx") {// 所有数据const exportData await getAllMainData();// 表头内容let fitstTitleList [];const secondTitleList [];allColumns.value.forEach(column > {if (!column.children) {fitstTitleL…...

Qt Http Server模块功能及架构

Qt Http Server 是 Qt 6.0 中引入的一个新模块&#xff0c;它提供了一个轻量级的 HTTP 服务器实现&#xff0c;主要用于构建基于 HTTP 的应用程序和服务。 功能介绍&#xff1a; 主要功能 HTTP服务器功能&#xff1a; 支持 HTTP/1.1 协议 简单的请求/响应处理模型 支持 GET…...

Springcloud:Eureka 高可用集群搭建实战(服务注册与发现的底层原理与避坑指南)

引言&#xff1a;为什么 Eureka 依然是存量系统的核心&#xff1f; 尽管 Nacos 等新注册中心崛起&#xff0c;但金融、电力等保守行业仍有大量系统运行在 Eureka 上。理解其高可用设计与自我保护机制&#xff0c;是保障分布式系统稳定的必修课。本文将手把手带你搭建生产级 Eur…...

tree 树组件大数据卡顿问题优化

问题背景 项目中有用到树组件用来做文件目录&#xff0c;但是由于这个树组件的节点越来越多&#xff0c;导致页面在滚动这个树组件的时候浏览器就很容易卡死。这种问题基本上都是因为dom节点太多&#xff0c;导致的浏览器卡顿&#xff0c;这里很明显就需要用到虚拟列表的技术&…...

Maven 概述、安装、配置、仓库、私服详解

目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...

DingDing机器人群消息推送

文章目录 1 新建机器人2 API文档说明3 代码编写 1 新建机器人 点击群设置 下滑到群管理的机器人&#xff0c;点击进入 添加机器人 选择自定义Webhook服务 点击添加 设置安全设置&#xff0c;详见说明文档 成功后&#xff0c;记录Webhook 2 API文档说明 点击设置说明 查看自…...

Golang——6、指针和结构体

指针和结构体 1、指针1.1、指针地址和指针类型1.2、指针取值1.3、new和make 2、结构体2.1、type关键字的使用2.2、结构体的定义和初始化2.3、结构体方法和接收者2.4、给任意类型添加方法2.5、结构体的匿名字段2.6、嵌套结构体2.7、嵌套匿名结构体2.8、结构体的继承 3、结构体与…...

在树莓派上添加音频输入设备的几种方法

在树莓派上添加音频输入设备可以通过以下步骤完成&#xff0c;具体方法取决于设备类型&#xff08;如USB麦克风、3.5mm接口麦克风或HDMI音频输入&#xff09;。以下是详细指南&#xff1a; 1. 连接音频输入设备 USB麦克风/声卡&#xff1a;直接插入树莓派的USB接口。3.5mm麦克…...