当前位置: 首页 > news >正文

线性神经网路——线性回归随笔【深度学习】【PyTorch】【d2l】

文章目录

    • 3.1、线性回归
      • 3.1.1、PyTorch 从零实现线性回归
      • 3.1.2、简单实现线性回归

在这里插入图片描述

3.1、线性回归

线性回归是显式解,深度学习中绝大多数遇到的都是隐式解。

3.1.1、PyTorch 从零实现线性回归

%matplotlib inline
import random
import torch
#d2l库中的torch模块,并将其用别名d2l引用。d2l库是《动手学深度学习》(Dive into Deep Learning)这本书的配套库,包含了一些自定义的函数和工具,以及对PyTorch库的包装和扩展。
from d2l import torch as d2l

生成数据集及标签

def synthetic_data(w,b,num_examples):"""生成 y = Xw + b + 噪声"""X = torch.normal(0,1,(num_examples,len(w)))#创建一个大小为(num_examples, len(w))的张量X,并使用均值为0,标准差为1的正态分布对其进行初始化。这个张量代表输入特征,其中 num_examples 是样本数量,len(w) 是特征向量的长度。y = torch.matmul(X,w) + by += torch.normal(0, 0.01, y.shape)#预测值y中添加一个均值为0,标准差为0.01的正态分布噪声,以增加模型的随机性和泛化能力。return X, y.reshape((-1,1))#预测值y通过reshape方法被转换成一个列向量
true_w = torch.tensor([2,-3.4])
true_b = 4.2
features, labels = synthetic_data(true_w,true_b,1000)print('features:',features[0],'\nlabel:',labels[0])d2l.set_figsize()#设置图表尺寸
d2l.plt.scatter(features[:,1].detach().numpy(),labels.detach().numpy(),1);           

d2l.plt.scatter(,,),使用d2l库中的绘图函数来创建散点图。

这个函数接受三个参数:

  • features[:,1].detach().numpy() 是一个二维张量features的切片操作,选择了所有行的第二列数据。detach()函数用于将张量从计算图中分离,numpy()方法将张量转换为NumPy数组。这样得到的是一个NumPy数组,代表散点图中的x轴数据。

  • labels.detach().numpy() 是一个二维张量labels的分离和转换操作,得到一个NumPy数组,代表散点图中的y轴数据。

  • 1 是可选参数,用于设置散点的标记尺寸。在这里,设置为1表示每个散点的大小为1个点。

这里为什么要用detach()?

尝试去掉后结果是不变的,应对某些pytorch版本转numpy必须这样做。

def data_iter(batch_size, features, labels):num_examples = len(features)#创建一个包含0到num_examples-1的整数列表,表示样本索引。indices = list(range(num_examples))#随机打乱样本索引的顺序,样本是随机读取的,没有特定顺序。random.shuffle(indices)for i in range(0, num_examples, batch_size):# 根据当前批次的起始索引,创建一个包含当前批次样本索引的张量。min(i + batch_size, num_examples)确保最后一个批次的大小不超过剩余样本数量。batch_indices = torch.tensor(indices[i:min(i + batch_size,num_examples)])# 使用生成器返回当前批次的特征和标签。yield features[batch_indices],labels[batch_indices]batch_size = 10for X,y in data_iter(batch_size, features, labels):print(X,'\n',y)break

小插叙,synthetic_data()返回值中X敲成了小写,直接导致后面矩阵乘法形状对不上,找了半天错误。

yield 预备知识:

当一个函数包含 yield 语句时,它就变成了一个生成器函数。生成器函数用于生成一个序列的值,而不是一次性返回所有值。每次调用生成器函数时,它会暂停执行,并返回一个值。当下一次调用生成器函数时,它会从上次暂停的地方继续执行,直到遇到下一个 yield 语句或函数结束。

定义初始化模型参数

w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1,  requires_grad=True)

定义模型

def linreg(X, w, b):"""线性回归模型"""return torch.matmul(X, w) + b  #X, w进行矩阵乘法

定义损失函数

def squared_loss(y_hat,y): #(预测值,真实值)"""均方损失"""return (y_hat - y.reshape(y_hat.shape)) **2 / 2 

这就是数据是张量的好处,

M S E ( y , y ′ ) = ∑ i = 1 n ( y i − y i ′ ) 2 n MSE(y,y') = \frac{\sum_{i=1}^n(y_i-y_i')^2}{n} MSE(y,y)=ni=1n(yiyi)2

明明是含有求和操作的数学公式,在张量面前形同虚设,代码实现是这么简单。就像在写标量公式一样。

定义优化算法

def sgd(params, lr, batch_size):#一个包含待更新参数的列表,学习率,每个小批次中的样本数量)"""小批量随机梯度下降"""with torch.no_grad():for param in params:param -=lr * param.grad / batch_sizeparam.grad.zero_()

为什么执行的减法而不是加法?

梯度的负方向

优化算法是怎么跟损失函数合作来完成参数优化?

优化函数没有直接使用损失值,但通过使用损失函数和反向传播计算参数的梯度,并将这些梯度应用于参数更新,间接地优化了模型的损失。梯度下降算法利用了参数的梯度信息来更新参数,以使损失函数尽可能减小。

优化算法(例如随机梯度下降)是怎么拿到损失函数的梯度信息的?

损失函数梯度完整的说是 loss关于x,loss关于y的梯度 ,搞清楚这个概念就不难理解了【初学时,我误解成了损失值的梯度x,y的梯度是两个概念,显然后者是非常不准确的表述】,损失函数梯度就在 sgd的params中。

l = loss(net(X, w, b), y) 
l.sum().backward()#此时损失函数梯度【关于w,b的梯度】存在w.grad,b.grad中
sgd([w,b], lr, batch_size) #使用参数梯度更新参数

param.grad.zero_()在这里有什么意义?谁会干扰梯度的求解?

如果在循环的下一次迭代中不使用param.grad.zero_()来清零参数的梯度,那么参数将会保留上一次迭代计算得到的梯度值,继续沿用该梯度值来求解梯度。就是说上次for循环的param会对下次param的梯度求解产生影响,所以才要清空梯度。

训练过程

#超参数
lr =0.03 #学习率(learning rate),控制每次参数更新的步幅大小。
num_epochs = 3 #数据集的扫描次数,即要重复训练模型的次数。
net =linreg #表示模型,这里使用了一个名为linreg的线性回归模型。
loss = squared_loss#表示损失函数,这里使用了一个名为squared_loss的均方损失函数。for epoch in range(num_epochs):for X, y in data_iter(batch_size, features, labels):l = loss(net(X, w, b), y) # X 、y 的小批量损失#l形状是 (batch_size, 1),非标量l.sum().backward()sgd([w,b], lr, batch_size) #使用参数梯度更新参数with torch.no_grad():train_l = loss(net(features, w, b),labels)print('epoch ',epoch+1,'loss ',float(train_l.mean()))
epoch  1 loss  0.032808780670166016
epoch  2 loss  0.00011459046800155193
epoch  3 loss  5.012870315113105e-05

with torch.no_grad()有什么意义,毕竟没有backward()操作?

对于with torch.no_grad()块,在 PyTorch 中禁用梯度追踪和计算图的构建。在该块中执行的操作不会被记录到计算图中,因此不会生成梯度信息。其作用是告诉 PyTorch 不要跟踪计算梯度,这样可以节省计算资源。

简单说,就是计算损失值的张量运算不会记录到计算图中,因为没必要,而且不建立计算图,求损失值更快了。

代码存在的小问题

最后一批次可能不足batch_size,sgd 执行 param -=lr * param.grad / batch_size取平均是有问题的,修改后:

sgd([w,b], lr,min(batch_size, X.shape[0])) #使用参数梯度更新参数

比较真实参数与训练学到的参数评估训练成功程度

print('w的估计误差:',true_w - w.reshape(true_w.shape))
print('b的估计误差:',true_b - b)
w的估计误差: tensor([-6.1035e-05,  2.5797e-04], grad_fn=<SubBackward0>)
b的估计误差: tensor([0.0018], grad_fn=<RsubBackward1>)

3.1.2、简单实现线性回归

生成数据集

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2ltrue_w = torch.tensor([2,-3.4])
true_b = 4.2
#d2l 的人造数据集函数
features, labels = d2l.synthetic_data(true_w, true_b,1000)

读取数据集

def load_array(data_arrays, batch_size, is_train=True):"""构造一个Pytorch数据迭代器"""#PyTorch提供的一个用于封装多个张量数据的数据集对象,*data_arrays用于将数据数组解包为多个参数。#*data_arrays解包等价于 dataset = data.TensorDataset(features, labels)dataset = data.TensorDataset(*data_arrays)#PyTorch提供的一个用于批量加载数据的迭代器return data.DataLoader(dataset, batch_size, shuffle= is_train)batch_size = 10
data_iter = load_array((features,labels), batch_size)
#iter() 函数将数据迭代器转换为迭代器对象,而 next() 函数用于获取迭代器的下一个元素。
next(iter(data_iter))

解包操作(见 python 预备知识)

星号 *dataset = data.TensorDataset(*data_arrays) 中的作用是将元组或列表中的元素解包,并作为独立的参数传递给函数或构造函数。这样可以更方便地传递多个参数。

迭代器使用(见 python 预备知识)

iter() 函数的主要目的是将可迭代对象转换为迭代器对象,以便于使用 next() 函数逐个访问其中的元素。

使用框架预定好的层

from torch import nn
#线性回归就是一个简单的单层神经网络
#一个全连接层,它接受大小为 2 的输入特征,并输出大小为 1 的特征。
net = nn.Sequential(nn.Linear(2,1))

初始化模型参数

#net[0] 表示模型中的第一个层,weight权重参数,正态分布初始化
net[0].weight.data.normal_(0,0.01)
#第一层加入偏差
net[0].bias.data.fill_(0)

实例化损失函数

loss = nn.MSELoss()

实例化优化算法( SGD)

#net.parameters() 返回一个迭代器,该迭代器包含了模型中所有可训练的参数。
trainer = torch.optim.SGD(net.parameters(),lr=0.03)

训练过程

num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:l = loss(net(X), y)trainer.zero_grad()l.backward()trainer.step()l = loss(net(features),labels)print(f'epoch {epoch+1}, loss {l:f}')

关于输出格式,最后一个明显最好

print('epoch ',epoch+1,',loss ',l)       #epoch  1 ,loss  tensor(9.9119e-05, grad_fn=<MseLossBackward0>)print('epoch ',epoch+1,',loss ',float(l))# epoch  1 ,loss  9.872819646261632e-05print(f'epoch {epoch+1}, loss {l:f}')    # epoch 1, loss 0.000099

还可以自定义保留有限小数位

print(f'epoch {epoch+1}, loss {l:.4f}')# 保留4位。

相关文章:

线性神经网路——线性回归随笔【深度学习】【PyTorch】【d2l】

文章目录 3.1、线性回归3.1.1、PyTorch 从零实现线性回归3.1.2、简单实现线性回归 3.1、线性回归 线性回归是显式解&#xff0c;深度学习中绝大多数遇到的都是隐式解。 3.1.1、PyTorch 从零实现线性回归 %matplotlib inline import random import torch #d2l库中的torch模块&a…...

js实现多种按钮

你可以使用JavaScript来实现多种类型的按钮&#xff0c;以下是几个常见的示例&#xff1a; 普通按钮&#xff08;Normal Button&#xff09;&#xff1a; <button>Click me</button> 带图标的按钮&#xff08;Button with Icon&#xff09;&#xff1a; <bu…...

getopt函数(未更新完)

2023年7月28日&#xff0c;周五上午 这是我目前碰到过的比较复杂的函数之一&#xff0c; 为了彻底弄懂这个函数&#xff0c;我花了几个小时。 为了更好的说明这个函数&#xff0c;之后我可能会录制讲解视频并上传到B站&#xff0c; 如果我上传到B站&#xff0c;我会在文章添…...

SpringCloud学习路线(9)——服务异步通讯RabbitMQ

一、初见MQ &#xff08;一&#xff09;什么是MQ&#xff1f; MQ&#xff08;MessageQueue&#xff09;&#xff0c;意思是消息队列&#xff0c;也就是事件驱动架构中的Broker。 &#xff08;二&#xff09;同步调用 1、概念&#xff1a; 同步调用是指&#xff0c;某一服务…...

postcss-pxtorem适配插件动态配置rootValue(根据文件路径名称,动态改变vue.config里配置的值)

项目背景&#xff1a;一个项目里有两个分辨率的设计稿(1920和2400)&#xff0c;不能拆开来打包 参考&#xff1a; 是参考vant插件&#xff1a;移动端Vant组件库rem适配下大小异常的解决方案&#xff1a;https://github.com/youzan/vant/issues/1181 说明&#xff1a; 因为vue.c…...

代码随想录算法训练营第二十三天 | 额外题目系列

额外题目 1365. 有多少小于当前数字的数字借着本题&#xff0c;学习一下各种排序未看解答自己编写的青春版重点代码随想录的代码我的代码(当天晚上理解后自己编写) 941.有效的山脉数组未看解答自己编写的青春版重点代码随想录的代码我的代码(当天晚上理解后自己编写) 1207. 独一…...

UiAutomator

运行Espresso和UI Automator测试时要使用模拟器。国内手机的ROM大多进行过修改&#xff0c;可能加入很多限制&#xff0c;导致测试无法正常运行。 Espresso只支持一个活动内部交互行为的测试。跨越多个活动、多个应用的场景需要使用UI Automator。使用Espresso和UI Automator的…...

stm32标准库开发常用函数的使用和代码说明

文章目录 GPIO&#xff08;General Purpose Input/Output&#xff09;NVIC&#xff08;Nested Vectored Interrupt Controller&#xff09;DMA&#xff08;Direct Memory Access&#xff09;USART&#xff08;Universal Synchronous/Asynchronous Receiver/Transmitter&#xf…...

有关合泰BA45F5260中断的思考

最近看前辈写的代码&#xff0c;发现这样一段代码&#xff1a; #ifdef SUPPORT_RF_NET_FUNCTION if(UART_INT_is_L()) { TmrInsertTimer(eTmrHdlUartRxDelay,TMR_PERIOD(2000),NULL); break; } #endif 其中UART_INT_is_L&am…...

Numpy-算数函数与数学函数

⛳算数函数 如果参与运算的两个对象都是ndarray&#xff0c;并且形状相同&#xff0c;那么会对位彼此之间进 第 30 页 行&#xff08; - * /&#xff09;运算。NumPy 算术函数包含简单的加减乘除: add()&#xff0c;subtract()&#xff0c;multiply() 和divide()。 &#x1f…...

Nginx在springboot中起到的作用

面试时这样回答&#xff1a; 在Spring Boot项目中使用Nginx可以有以下用途&#xff1a; 1. 反向代理&#xff1a;Nginx可以作为反向代理服务器&#xff0c;将外部请求转发到后端的Spring Boot应用&#xff0c;并可以实现负载均衡、高可用、缓存等功能&#xff0c;提高系统的性…...

12.(开发工具篇vscode+git)vscode 不能识别npm命令

1&#xff1a;vscode 不能识别npm命令 问题描述&#xff1a; 解决方式&#xff1a; &#xff08;1&#xff09;右击VSCode图标&#xff0c;选择以管理员身份运行&#xff1b; &#xff08;2&#xff09;在终端中执行get-ExecutionPolicy&#xff0c;显示Restricted&#xff…...

如何在MacBook上彻底删除mysql

好久以前安装过&#xff0c;但是现在配置mysql一直出错&#xff0c;索性全部删掉重新配置。 一、停止MySQL服务 首先&#xff0c;请确保 MySQL 服务器已经停止运行&#xff0c;以免影响后续的删除操作。 sudo /usr/local/mysql/support-files/mysql.server stop如果你输入之…...

web攻击面试|网络渗透面试(一)

Web攻击面试大纲 常见Web攻击类型 1.1 SQL注入攻击 1.2 XSS攻击 1.3 CSRF攻击 1.4 命令注入攻击SQL注入攻击 2.1 基本概念 2.2 攻击原理 2.3 防御措施XSS攻击 3.1 基本概念 3.2 攻击原理 3.3 防御措施CSRF攻击 4.1 基本概念 4.2 攻击原理 4.3 防御措施命令注入攻击 5.1 基本概…...

VBA操作WORD(六)另存为不含宏的文档

Sub 另存为不含宏的文档()Application.DisplayAlerts False Application.ScreenUpdating FalseDim oDoc As DocumentSet oDoc Word.ActiveDocumentDim oRng As RangeSet oRng oDoc.ContentDim sPath As String默认存储路径&#xff0c;当前用户桌面&#xff0c;注释掉的是当…...

分享69个Java源码,总有一款适合您

Java源码 分享69个Java源码&#xff0c;总有一款适合您 下面是文件的名字&#xff0c;我放了一些图片&#xff0c;文章里不是所有的图主要是放不下...&#xff0c;大家下载后可以看到。 源码下载链接&#xff1a; https://pan.baidu.com/s/1ZgbJhMNwIyFyqFzHsDdL5w 提取码&a…...

《cool! autodistill帮你标注数据训练yolov8模型》学习笔记

《cool! autodistill帮你标注数据训练yolov8模型》 Summary Autodistill是一个用于自动标注数据训练边缘模型的工具。 Highlights &#x1f4a1; Autodistill由Robotflow推出&#xff0c;用于训练建立部署计算机视觉模型。&#x1f4bb; 通过使用大模型自动标注和训练小模型…...

Rust vs Go:常用语法对比(十)

题图来自 Rust vs. Golang: Which One is Better?[1] 182. Quine program Output the source of the program. 输出程序的源代码 package mainimport "fmt"func main() { fmt.Printf("%s%c%s%c\n", s, 0x60, s, 0x60)}var s package mainimport "fm…...

SliverPersistentHeader组件 实现Flutter吸顶效果

效果&#xff1a; 20230723-212152-73_Trim 代码&#xff1a; import package:flutter/cupertino.dart; import package:flutter/material.dart;class StickHeaderPage extends StatefulWidget {overrideState<StatefulWidget> createState() {// TODO: implement creat…...

Nginx性能优化配置

一、全局优化 # 工作进程数 worker_processes auto; # 建议 CPU核心数|CPU线程数# 最大支持的连接(open-file)数量&#xff1b;最大值受限于 Linux open files (ulimit -n) # 建议公式&#xff1a;worker_rlimit_nofile > worker_processes * worker_connections…...

OpenLayers 可视化之热力图

注&#xff1a;当前使用的是 ol 5.3.0 版本&#xff0c;天地图使用的key请到天地图官网申请&#xff0c;并替换为自己的key 热力图&#xff08;Heatmap&#xff09;又叫热点图&#xff0c;是一种通过特殊高亮显示事物密度分布、变化趋势的数据可视化技术。采用颜色的深浅来显示…...

工业安全零事故的智能守护者:一体化AI智能安防平台

前言&#xff1a; 通过AI视觉技术&#xff0c;为船厂提供全面的安全监控解决方案&#xff0c;涵盖交通违规检测、起重机轨道安全、非法入侵检测、盗窃防范、安全规范执行监控等多个方面&#xff0c;能够实现对应负责人反馈机制&#xff0c;并最终实现数据的统计报表。提升船厂…...

MFC内存泄露

1、泄露代码示例 void X::SetApplicationBtn() {CMFCRibbonApplicationButton* pBtn GetApplicationButton();// 获取 Ribbon Bar 指针// 创建自定义按钮CCustomRibbonAppButton* pCustomButton new CCustomRibbonAppButton();pCustomButton->SetImage(IDB_BITMAP_Jdp26)…...

TRS收益互换:跨境资本流动的金融创新工具与系统化解决方案

一、TRS收益互换的本质与业务逻辑 &#xff08;一&#xff09;概念解析 TRS&#xff08;Total Return Swap&#xff09;收益互换是一种金融衍生工具&#xff0c;指交易双方约定在未来一定期限内&#xff0c;基于特定资产或指数的表现进行现金流交换的协议。其核心特征包括&am…...

mysql已经安装,但是通过rpm -q 没有找mysql相关的已安装包

文章目录 现象&#xff1a;mysql已经安装&#xff0c;但是通过rpm -q 没有找mysql相关的已安装包遇到 rpm 命令找不到已经安装的 MySQL 包时&#xff0c;可能是因为以下几个原因&#xff1a;1.MySQL 不是通过 RPM 包安装的2.RPM 数据库损坏3.使用了不同的包名或路径4.使用其他包…...

tree 树组件大数据卡顿问题优化

问题背景 项目中有用到树组件用来做文件目录&#xff0c;但是由于这个树组件的节点越来越多&#xff0c;导致页面在滚动这个树组件的时候浏览器就很容易卡死。这种问题基本上都是因为dom节点太多&#xff0c;导致的浏览器卡顿&#xff0c;这里很明显就需要用到虚拟列表的技术&…...

Fabric V2.5 通用溯源系统——增加图片上传与下载功能

fabric-trace项目在发布一年后,部署量已突破1000次,为支持更多场景,现新增支持图片信息上链,本文对图片上传、下载功能代码进行梳理,包含智能合约、后端、前端部分。 一、智能合约修改 为了增加图片信息上链溯源,需要对底层数据结构进行修改,在此对智能合约中的农产品数…...

【电力电子】基于STM32F103C8T6单片机双极性SPWM逆变(硬件篇)

本项目是基于 STM32F103C8T6 微控制器的 SPWM(正弦脉宽调制)电源模块,能够生成可调频率和幅值的正弦波交流电源输出。该项目适用于逆变器、UPS电源、变频器等应用场景。 供电电源 输入电压采集 上图为本设计的电源电路,图中 D1 为二极管, 其目的是防止正负极电源反接, …...

【笔记】WSL 中 Rust 安装与测试完整记录

#工作记录 WSL 中 Rust 安装与测试完整记录 1. 运行环境 系统&#xff1a;Ubuntu 24.04 LTS (WSL2)架构&#xff1a;x86_64 (GNU/Linux)Rust 版本&#xff1a;rustc 1.87.0 (2025-05-09)Cargo 版本&#xff1a;cargo 1.87.0 (2025-05-06) 2. 安装 Rust 2.1 使用 Rust 官方安…...

windows系统MySQL安装文档

概览&#xff1a;本文讨论了MySQL的安装、使用过程中涉及的解压、配置、初始化、注册服务、启动、修改密码、登录、退出以及卸载等相关内容&#xff0c;为学习者提供全面的操作指导。关键要点包括&#xff1a; 解压 &#xff1a;下载完成后解压压缩包&#xff0c;得到MySQL 8.…...