当前位置: 首页 > news >正文

线性神经网路——线性回归随笔【深度学习】【PyTorch】【d2l】

文章目录

    • 3.1、线性回归
      • 3.1.1、PyTorch 从零实现线性回归
      • 3.1.2、简单实现线性回归

在这里插入图片描述

3.1、线性回归

线性回归是显式解,深度学习中绝大多数遇到的都是隐式解。

3.1.1、PyTorch 从零实现线性回归

%matplotlib inline
import random
import torch
#d2l库中的torch模块,并将其用别名d2l引用。d2l库是《动手学深度学习》(Dive into Deep Learning)这本书的配套库,包含了一些自定义的函数和工具,以及对PyTorch库的包装和扩展。
from d2l import torch as d2l

生成数据集及标签

def synthetic_data(w,b,num_examples):"""生成 y = Xw + b + 噪声"""X = torch.normal(0,1,(num_examples,len(w)))#创建一个大小为(num_examples, len(w))的张量X,并使用均值为0,标准差为1的正态分布对其进行初始化。这个张量代表输入特征,其中 num_examples 是样本数量,len(w) 是特征向量的长度。y = torch.matmul(X,w) + by += torch.normal(0, 0.01, y.shape)#预测值y中添加一个均值为0,标准差为0.01的正态分布噪声,以增加模型的随机性和泛化能力。return X, y.reshape((-1,1))#预测值y通过reshape方法被转换成一个列向量
true_w = torch.tensor([2,-3.4])
true_b = 4.2
features, labels = synthetic_data(true_w,true_b,1000)print('features:',features[0],'\nlabel:',labels[0])d2l.set_figsize()#设置图表尺寸
d2l.plt.scatter(features[:,1].detach().numpy(),labels.detach().numpy(),1);           

d2l.plt.scatter(,,),使用d2l库中的绘图函数来创建散点图。

这个函数接受三个参数:

  • features[:,1].detach().numpy() 是一个二维张量features的切片操作,选择了所有行的第二列数据。detach()函数用于将张量从计算图中分离,numpy()方法将张量转换为NumPy数组。这样得到的是一个NumPy数组,代表散点图中的x轴数据。

  • labels.detach().numpy() 是一个二维张量labels的分离和转换操作,得到一个NumPy数组,代表散点图中的y轴数据。

  • 1 是可选参数,用于设置散点的标记尺寸。在这里,设置为1表示每个散点的大小为1个点。

这里为什么要用detach()?

尝试去掉后结果是不变的,应对某些pytorch版本转numpy必须这样做。

def data_iter(batch_size, features, labels):num_examples = len(features)#创建一个包含0到num_examples-1的整数列表,表示样本索引。indices = list(range(num_examples))#随机打乱样本索引的顺序,样本是随机读取的,没有特定顺序。random.shuffle(indices)for i in range(0, num_examples, batch_size):# 根据当前批次的起始索引,创建一个包含当前批次样本索引的张量。min(i + batch_size, num_examples)确保最后一个批次的大小不超过剩余样本数量。batch_indices = torch.tensor(indices[i:min(i + batch_size,num_examples)])# 使用生成器返回当前批次的特征和标签。yield features[batch_indices],labels[batch_indices]batch_size = 10for X,y in data_iter(batch_size, features, labels):print(X,'\n',y)break

小插叙,synthetic_data()返回值中X敲成了小写,直接导致后面矩阵乘法形状对不上,找了半天错误。

yield 预备知识:

当一个函数包含 yield 语句时,它就变成了一个生成器函数。生成器函数用于生成一个序列的值,而不是一次性返回所有值。每次调用生成器函数时,它会暂停执行,并返回一个值。当下一次调用生成器函数时,它会从上次暂停的地方继续执行,直到遇到下一个 yield 语句或函数结束。

定义初始化模型参数

w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1,  requires_grad=True)

定义模型

def linreg(X, w, b):"""线性回归模型"""return torch.matmul(X, w) + b  #X, w进行矩阵乘法

定义损失函数

def squared_loss(y_hat,y): #(预测值,真实值)"""均方损失"""return (y_hat - y.reshape(y_hat.shape)) **2 / 2 

这就是数据是张量的好处,

M S E ( y , y ′ ) = ∑ i = 1 n ( y i − y i ′ ) 2 n MSE(y,y') = \frac{\sum_{i=1}^n(y_i-y_i')^2}{n} MSE(y,y)=ni=1n(yiyi)2

明明是含有求和操作的数学公式,在张量面前形同虚设,代码实现是这么简单。就像在写标量公式一样。

定义优化算法

def sgd(params, lr, batch_size):#一个包含待更新参数的列表,学习率,每个小批次中的样本数量)"""小批量随机梯度下降"""with torch.no_grad():for param in params:param -=lr * param.grad / batch_sizeparam.grad.zero_()

为什么执行的减法而不是加法?

梯度的负方向

优化算法是怎么跟损失函数合作来完成参数优化?

优化函数没有直接使用损失值,但通过使用损失函数和反向传播计算参数的梯度,并将这些梯度应用于参数更新,间接地优化了模型的损失。梯度下降算法利用了参数的梯度信息来更新参数,以使损失函数尽可能减小。

优化算法(例如随机梯度下降)是怎么拿到损失函数的梯度信息的?

损失函数梯度完整的说是 loss关于x,loss关于y的梯度 ,搞清楚这个概念就不难理解了【初学时,我误解成了损失值的梯度x,y的梯度是两个概念,显然后者是非常不准确的表述】,损失函数梯度就在 sgd的params中。

l = loss(net(X, w, b), y) 
l.sum().backward()#此时损失函数梯度【关于w,b的梯度】存在w.grad,b.grad中
sgd([w,b], lr, batch_size) #使用参数梯度更新参数

param.grad.zero_()在这里有什么意义?谁会干扰梯度的求解?

如果在循环的下一次迭代中不使用param.grad.zero_()来清零参数的梯度,那么参数将会保留上一次迭代计算得到的梯度值,继续沿用该梯度值来求解梯度。就是说上次for循环的param会对下次param的梯度求解产生影响,所以才要清空梯度。

训练过程

#超参数
lr =0.03 #学习率(learning rate),控制每次参数更新的步幅大小。
num_epochs = 3 #数据集的扫描次数,即要重复训练模型的次数。
net =linreg #表示模型,这里使用了一个名为linreg的线性回归模型。
loss = squared_loss#表示损失函数,这里使用了一个名为squared_loss的均方损失函数。for epoch in range(num_epochs):for X, y in data_iter(batch_size, features, labels):l = loss(net(X, w, b), y) # X 、y 的小批量损失#l形状是 (batch_size, 1),非标量l.sum().backward()sgd([w,b], lr, batch_size) #使用参数梯度更新参数with torch.no_grad():train_l = loss(net(features, w, b),labels)print('epoch ',epoch+1,'loss ',float(train_l.mean()))
epoch  1 loss  0.032808780670166016
epoch  2 loss  0.00011459046800155193
epoch  3 loss  5.012870315113105e-05

with torch.no_grad()有什么意义,毕竟没有backward()操作?

对于with torch.no_grad()块,在 PyTorch 中禁用梯度追踪和计算图的构建。在该块中执行的操作不会被记录到计算图中,因此不会生成梯度信息。其作用是告诉 PyTorch 不要跟踪计算梯度,这样可以节省计算资源。

简单说,就是计算损失值的张量运算不会记录到计算图中,因为没必要,而且不建立计算图,求损失值更快了。

代码存在的小问题

最后一批次可能不足batch_size,sgd 执行 param -=lr * param.grad / batch_size取平均是有问题的,修改后:

sgd([w,b], lr,min(batch_size, X.shape[0])) #使用参数梯度更新参数

比较真实参数与训练学到的参数评估训练成功程度

print('w的估计误差:',true_w - w.reshape(true_w.shape))
print('b的估计误差:',true_b - b)
w的估计误差: tensor([-6.1035e-05,  2.5797e-04], grad_fn=<SubBackward0>)
b的估计误差: tensor([0.0018], grad_fn=<RsubBackward1>)

3.1.2、简单实现线性回归

生成数据集

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2ltrue_w = torch.tensor([2,-3.4])
true_b = 4.2
#d2l 的人造数据集函数
features, labels = d2l.synthetic_data(true_w, true_b,1000)

读取数据集

def load_array(data_arrays, batch_size, is_train=True):"""构造一个Pytorch数据迭代器"""#PyTorch提供的一个用于封装多个张量数据的数据集对象,*data_arrays用于将数据数组解包为多个参数。#*data_arrays解包等价于 dataset = data.TensorDataset(features, labels)dataset = data.TensorDataset(*data_arrays)#PyTorch提供的一个用于批量加载数据的迭代器return data.DataLoader(dataset, batch_size, shuffle= is_train)batch_size = 10
data_iter = load_array((features,labels), batch_size)
#iter() 函数将数据迭代器转换为迭代器对象,而 next() 函数用于获取迭代器的下一个元素。
next(iter(data_iter))

解包操作(见 python 预备知识)

星号 *dataset = data.TensorDataset(*data_arrays) 中的作用是将元组或列表中的元素解包,并作为独立的参数传递给函数或构造函数。这样可以更方便地传递多个参数。

迭代器使用(见 python 预备知识)

iter() 函数的主要目的是将可迭代对象转换为迭代器对象,以便于使用 next() 函数逐个访问其中的元素。

使用框架预定好的层

from torch import nn
#线性回归就是一个简单的单层神经网络
#一个全连接层,它接受大小为 2 的输入特征,并输出大小为 1 的特征。
net = nn.Sequential(nn.Linear(2,1))

初始化模型参数

#net[0] 表示模型中的第一个层,weight权重参数,正态分布初始化
net[0].weight.data.normal_(0,0.01)
#第一层加入偏差
net[0].bias.data.fill_(0)

实例化损失函数

loss = nn.MSELoss()

实例化优化算法( SGD)

#net.parameters() 返回一个迭代器,该迭代器包含了模型中所有可训练的参数。
trainer = torch.optim.SGD(net.parameters(),lr=0.03)

训练过程

num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:l = loss(net(X), y)trainer.zero_grad()l.backward()trainer.step()l = loss(net(features),labels)print(f'epoch {epoch+1}, loss {l:f}')

关于输出格式,最后一个明显最好

print('epoch ',epoch+1,',loss ',l)       #epoch  1 ,loss  tensor(9.9119e-05, grad_fn=<MseLossBackward0>)print('epoch ',epoch+1,',loss ',float(l))# epoch  1 ,loss  9.872819646261632e-05print(f'epoch {epoch+1}, loss {l:f}')    # epoch 1, loss 0.000099

还可以自定义保留有限小数位

print(f'epoch {epoch+1}, loss {l:.4f}')# 保留4位。

相关文章:

线性神经网路——线性回归随笔【深度学习】【PyTorch】【d2l】

文章目录 3.1、线性回归3.1.1、PyTorch 从零实现线性回归3.1.2、简单实现线性回归 3.1、线性回归 线性回归是显式解&#xff0c;深度学习中绝大多数遇到的都是隐式解。 3.1.1、PyTorch 从零实现线性回归 %matplotlib inline import random import torch #d2l库中的torch模块&a…...

js实现多种按钮

你可以使用JavaScript来实现多种类型的按钮&#xff0c;以下是几个常见的示例&#xff1a; 普通按钮&#xff08;Normal Button&#xff09;&#xff1a; <button>Click me</button> 带图标的按钮&#xff08;Button with Icon&#xff09;&#xff1a; <bu…...

getopt函数(未更新完)

2023年7月28日&#xff0c;周五上午 这是我目前碰到过的比较复杂的函数之一&#xff0c; 为了彻底弄懂这个函数&#xff0c;我花了几个小时。 为了更好的说明这个函数&#xff0c;之后我可能会录制讲解视频并上传到B站&#xff0c; 如果我上传到B站&#xff0c;我会在文章添…...

SpringCloud学习路线(9)——服务异步通讯RabbitMQ

一、初见MQ &#xff08;一&#xff09;什么是MQ&#xff1f; MQ&#xff08;MessageQueue&#xff09;&#xff0c;意思是消息队列&#xff0c;也就是事件驱动架构中的Broker。 &#xff08;二&#xff09;同步调用 1、概念&#xff1a; 同步调用是指&#xff0c;某一服务…...

postcss-pxtorem适配插件动态配置rootValue(根据文件路径名称,动态改变vue.config里配置的值)

项目背景&#xff1a;一个项目里有两个分辨率的设计稿(1920和2400)&#xff0c;不能拆开来打包 参考&#xff1a; 是参考vant插件&#xff1a;移动端Vant组件库rem适配下大小异常的解决方案&#xff1a;https://github.com/youzan/vant/issues/1181 说明&#xff1a; 因为vue.c…...

代码随想录算法训练营第二十三天 | 额外题目系列

额外题目 1365. 有多少小于当前数字的数字借着本题&#xff0c;学习一下各种排序未看解答自己编写的青春版重点代码随想录的代码我的代码(当天晚上理解后自己编写) 941.有效的山脉数组未看解答自己编写的青春版重点代码随想录的代码我的代码(当天晚上理解后自己编写) 1207. 独一…...

UiAutomator

运行Espresso和UI Automator测试时要使用模拟器。国内手机的ROM大多进行过修改&#xff0c;可能加入很多限制&#xff0c;导致测试无法正常运行。 Espresso只支持一个活动内部交互行为的测试。跨越多个活动、多个应用的场景需要使用UI Automator。使用Espresso和UI Automator的…...

stm32标准库开发常用函数的使用和代码说明

文章目录 GPIO&#xff08;General Purpose Input/Output&#xff09;NVIC&#xff08;Nested Vectored Interrupt Controller&#xff09;DMA&#xff08;Direct Memory Access&#xff09;USART&#xff08;Universal Synchronous/Asynchronous Receiver/Transmitter&#xf…...

有关合泰BA45F5260中断的思考

最近看前辈写的代码&#xff0c;发现这样一段代码&#xff1a; #ifdef SUPPORT_RF_NET_FUNCTION if(UART_INT_is_L()) { TmrInsertTimer(eTmrHdlUartRxDelay,TMR_PERIOD(2000),NULL); break; } #endif 其中UART_INT_is_L&am…...

Numpy-算数函数与数学函数

⛳算数函数 如果参与运算的两个对象都是ndarray&#xff0c;并且形状相同&#xff0c;那么会对位彼此之间进 第 30 页 行&#xff08; - * /&#xff09;运算。NumPy 算术函数包含简单的加减乘除: add()&#xff0c;subtract()&#xff0c;multiply() 和divide()。 &#x1f…...

Nginx在springboot中起到的作用

面试时这样回答&#xff1a; 在Spring Boot项目中使用Nginx可以有以下用途&#xff1a; 1. 反向代理&#xff1a;Nginx可以作为反向代理服务器&#xff0c;将外部请求转发到后端的Spring Boot应用&#xff0c;并可以实现负载均衡、高可用、缓存等功能&#xff0c;提高系统的性…...

12.(开发工具篇vscode+git)vscode 不能识别npm命令

1&#xff1a;vscode 不能识别npm命令 问题描述&#xff1a; 解决方式&#xff1a; &#xff08;1&#xff09;右击VSCode图标&#xff0c;选择以管理员身份运行&#xff1b; &#xff08;2&#xff09;在终端中执行get-ExecutionPolicy&#xff0c;显示Restricted&#xff…...

如何在MacBook上彻底删除mysql

好久以前安装过&#xff0c;但是现在配置mysql一直出错&#xff0c;索性全部删掉重新配置。 一、停止MySQL服务 首先&#xff0c;请确保 MySQL 服务器已经停止运行&#xff0c;以免影响后续的删除操作。 sudo /usr/local/mysql/support-files/mysql.server stop如果你输入之…...

web攻击面试|网络渗透面试(一)

Web攻击面试大纲 常见Web攻击类型 1.1 SQL注入攻击 1.2 XSS攻击 1.3 CSRF攻击 1.4 命令注入攻击SQL注入攻击 2.1 基本概念 2.2 攻击原理 2.3 防御措施XSS攻击 3.1 基本概念 3.2 攻击原理 3.3 防御措施CSRF攻击 4.1 基本概念 4.2 攻击原理 4.3 防御措施命令注入攻击 5.1 基本概…...

VBA操作WORD(六)另存为不含宏的文档

Sub 另存为不含宏的文档()Application.DisplayAlerts False Application.ScreenUpdating FalseDim oDoc As DocumentSet oDoc Word.ActiveDocumentDim oRng As RangeSet oRng oDoc.ContentDim sPath As String默认存储路径&#xff0c;当前用户桌面&#xff0c;注释掉的是当…...

分享69个Java源码,总有一款适合您

Java源码 分享69个Java源码&#xff0c;总有一款适合您 下面是文件的名字&#xff0c;我放了一些图片&#xff0c;文章里不是所有的图主要是放不下...&#xff0c;大家下载后可以看到。 源码下载链接&#xff1a; https://pan.baidu.com/s/1ZgbJhMNwIyFyqFzHsDdL5w 提取码&a…...

《cool! autodistill帮你标注数据训练yolov8模型》学习笔记

《cool! autodistill帮你标注数据训练yolov8模型》 Summary Autodistill是一个用于自动标注数据训练边缘模型的工具。 Highlights &#x1f4a1; Autodistill由Robotflow推出&#xff0c;用于训练建立部署计算机视觉模型。&#x1f4bb; 通过使用大模型自动标注和训练小模型…...

Rust vs Go:常用语法对比(十)

题图来自 Rust vs. Golang: Which One is Better?[1] 182. Quine program Output the source of the program. 输出程序的源代码 package mainimport "fmt"func main() { fmt.Printf("%s%c%s%c\n", s, 0x60, s, 0x60)}var s package mainimport "fm…...

SliverPersistentHeader组件 实现Flutter吸顶效果

效果&#xff1a; 20230723-212152-73_Trim 代码&#xff1a; import package:flutter/cupertino.dart; import package:flutter/material.dart;class StickHeaderPage extends StatefulWidget {overrideState<StatefulWidget> createState() {// TODO: implement creat…...

Nginx性能优化配置

一、全局优化 # 工作进程数 worker_processes auto; # 建议 CPU核心数|CPU线程数# 最大支持的连接(open-file)数量&#xff1b;最大值受限于 Linux open files (ulimit -n) # 建议公式&#xff1a;worker_rlimit_nofile > worker_processes * worker_connections…...

Python爬虫实战:研究feedparser库相关技术

1. 引言 1.1 研究背景与意义 在当今信息爆炸的时代,互联网上存在着海量的信息资源。RSS(Really Simple Syndication)作为一种标准化的信息聚合技术,被广泛用于网站内容的发布和订阅。通过 RSS,用户可以方便地获取网站更新的内容,而无需频繁访问各个网站。 然而,互联网…...

【机器视觉】单目测距——运动结构恢复

ps&#xff1a;图是随便找的&#xff0c;为了凑个封面 前言 在前面对光流法进行进一步改进&#xff0c;希望将2D光流推广至3D场景流时&#xff0c;发现2D转3D过程中存在尺度歧义问题&#xff0c;需要补全摄像头拍摄图像中缺失的深度信息&#xff0c;否则解空间不收敛&#xf…...

微服务商城-商品微服务

数据表 CREATE TABLE product (id bigint(20) UNSIGNED NOT NULL AUTO_INCREMENT COMMENT 商品id,cateid smallint(6) UNSIGNED NOT NULL DEFAULT 0 COMMENT 类别Id,name varchar(100) NOT NULL DEFAULT COMMENT 商品名称,subtitle varchar(200) NOT NULL DEFAULT COMMENT 商…...

【python异步多线程】异步多线程爬虫代码示例

claude生成的python多线程、异步代码示例&#xff0c;模拟20个网页的爬取&#xff0c;每个网页假设要0.5-2秒完成。 代码 Python多线程爬虫教程 核心概念 多线程&#xff1a;允许程序同时执行多个任务&#xff0c;提高IO密集型任务&#xff08;如网络请求&#xff09;的效率…...

MySQL用户和授权

开放MySQL白名单 可以通过iptables-save命令确认对应客户端ip是否可以访问MySQL服务&#xff1a; test: # iptables-save | grep 3306 -A mp_srv_whitelist -s 172.16.14.102/32 -p tcp -m tcp --dport 3306 -j ACCEPT -A mp_srv_whitelist -s 172.16.4.16/32 -p tcp -m tcp -…...

稳定币的深度剖析与展望

一、引言 在当今数字化浪潮席卷全球的时代&#xff0c;加密货币作为一种新兴的金融现象&#xff0c;正以前所未有的速度改变着我们对传统货币和金融体系的认知。然而&#xff0c;加密货币市场的高度波动性却成为了其广泛应用和普及的一大障碍。在这样的背景下&#xff0c;稳定…...

【Android】Android 开发 ADB 常用指令

查看当前连接的设备 adb devices 连接设备 adb connect 设备IP 断开已连接的设备 adb disconnect 设备IP 安装应用 adb install 安装包的路径 卸载应用 adb uninstall 应用包名 查看已安装的应用包名 adb shell pm list packages 查看已安装的第三方应用包名 adb shell pm list…...

在 Visual Studio Code 中使用驭码 CodeRider 提升开发效率:以冒泡排序为例

目录 前言1 插件安装与配置1.1 安装驭码 CodeRider1.2 初始配置建议 2 示例代码&#xff1a;冒泡排序3 驭码 CodeRider 功能详解3.1 功能概览3.2 代码解释功能3.3 自动注释生成3.4 逻辑修改功能3.5 单元测试自动生成3.6 代码优化建议 4 驭码的实际应用建议5 常见问题与解决建议…...

leetcode_69.x的平方根

题目如下 &#xff1a; 看到题 &#xff0c;我们最原始的想法就是暴力解决: for(long long i 0;i<INT_MAX;i){if(i*ix){return i;}else if((i*i>x)&&((i-1)*(i-1)<x)){return i-1;}}我们直接开始遍历&#xff0c;我们是整数的平方根&#xff0c;所以我们分两…...

Tauri2学习笔记

教程地址&#xff1a;https://www.bilibili.com/video/BV1Ca411N7mF?spm_id_from333.788.player.switch&vd_source707ec8983cc32e6e065d5496a7f79ee6 官方指引&#xff1a;https://tauri.app/zh-cn/start/ 目前Tauri2的教程视频不多&#xff0c;我按照Tauri1的教程来学习&…...