当前位置: 首页 > news >正文

线性神经网路——线性回归随笔【深度学习】【PyTorch】【d2l】

文章目录

    • 3.1、线性回归
      • 3.1.1、PyTorch 从零实现线性回归
      • 3.1.2、简单实现线性回归

在这里插入图片描述

3.1、线性回归

线性回归是显式解,深度学习中绝大多数遇到的都是隐式解。

3.1.1、PyTorch 从零实现线性回归

%matplotlib inline
import random
import torch
#d2l库中的torch模块,并将其用别名d2l引用。d2l库是《动手学深度学习》(Dive into Deep Learning)这本书的配套库,包含了一些自定义的函数和工具,以及对PyTorch库的包装和扩展。
from d2l import torch as d2l

生成数据集及标签

def synthetic_data(w,b,num_examples):"""生成 y = Xw + b + 噪声"""X = torch.normal(0,1,(num_examples,len(w)))#创建一个大小为(num_examples, len(w))的张量X,并使用均值为0,标准差为1的正态分布对其进行初始化。这个张量代表输入特征,其中 num_examples 是样本数量,len(w) 是特征向量的长度。y = torch.matmul(X,w) + by += torch.normal(0, 0.01, y.shape)#预测值y中添加一个均值为0,标准差为0.01的正态分布噪声,以增加模型的随机性和泛化能力。return X, y.reshape((-1,1))#预测值y通过reshape方法被转换成一个列向量
true_w = torch.tensor([2,-3.4])
true_b = 4.2
features, labels = synthetic_data(true_w,true_b,1000)print('features:',features[0],'\nlabel:',labels[0])d2l.set_figsize()#设置图表尺寸
d2l.plt.scatter(features[:,1].detach().numpy(),labels.detach().numpy(),1);           

d2l.plt.scatter(,,),使用d2l库中的绘图函数来创建散点图。

这个函数接受三个参数:

  • features[:,1].detach().numpy() 是一个二维张量features的切片操作,选择了所有行的第二列数据。detach()函数用于将张量从计算图中分离,numpy()方法将张量转换为NumPy数组。这样得到的是一个NumPy数组,代表散点图中的x轴数据。

  • labels.detach().numpy() 是一个二维张量labels的分离和转换操作,得到一个NumPy数组,代表散点图中的y轴数据。

  • 1 是可选参数,用于设置散点的标记尺寸。在这里,设置为1表示每个散点的大小为1个点。

这里为什么要用detach()?

尝试去掉后结果是不变的,应对某些pytorch版本转numpy必须这样做。

def data_iter(batch_size, features, labels):num_examples = len(features)#创建一个包含0到num_examples-1的整数列表,表示样本索引。indices = list(range(num_examples))#随机打乱样本索引的顺序,样本是随机读取的,没有特定顺序。random.shuffle(indices)for i in range(0, num_examples, batch_size):# 根据当前批次的起始索引,创建一个包含当前批次样本索引的张量。min(i + batch_size, num_examples)确保最后一个批次的大小不超过剩余样本数量。batch_indices = torch.tensor(indices[i:min(i + batch_size,num_examples)])# 使用生成器返回当前批次的特征和标签。yield features[batch_indices],labels[batch_indices]batch_size = 10for X,y in data_iter(batch_size, features, labels):print(X,'\n',y)break

小插叙,synthetic_data()返回值中X敲成了小写,直接导致后面矩阵乘法形状对不上,找了半天错误。

yield 预备知识:

当一个函数包含 yield 语句时,它就变成了一个生成器函数。生成器函数用于生成一个序列的值,而不是一次性返回所有值。每次调用生成器函数时,它会暂停执行,并返回一个值。当下一次调用生成器函数时,它会从上次暂停的地方继续执行,直到遇到下一个 yield 语句或函数结束。

定义初始化模型参数

w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1,  requires_grad=True)

定义模型

def linreg(X, w, b):"""线性回归模型"""return torch.matmul(X, w) + b  #X, w进行矩阵乘法

定义损失函数

def squared_loss(y_hat,y): #(预测值,真实值)"""均方损失"""return (y_hat - y.reshape(y_hat.shape)) **2 / 2 

这就是数据是张量的好处,

M S E ( y , y ′ ) = ∑ i = 1 n ( y i − y i ′ ) 2 n MSE(y,y') = \frac{\sum_{i=1}^n(y_i-y_i')^2}{n} MSE(y,y)=ni=1n(yiyi)2

明明是含有求和操作的数学公式,在张量面前形同虚设,代码实现是这么简单。就像在写标量公式一样。

定义优化算法

def sgd(params, lr, batch_size):#一个包含待更新参数的列表,学习率,每个小批次中的样本数量)"""小批量随机梯度下降"""with torch.no_grad():for param in params:param -=lr * param.grad / batch_sizeparam.grad.zero_()

为什么执行的减法而不是加法?

梯度的负方向

优化算法是怎么跟损失函数合作来完成参数优化?

优化函数没有直接使用损失值,但通过使用损失函数和反向传播计算参数的梯度,并将这些梯度应用于参数更新,间接地优化了模型的损失。梯度下降算法利用了参数的梯度信息来更新参数,以使损失函数尽可能减小。

优化算法(例如随机梯度下降)是怎么拿到损失函数的梯度信息的?

损失函数梯度完整的说是 loss关于x,loss关于y的梯度 ,搞清楚这个概念就不难理解了【初学时,我误解成了损失值的梯度x,y的梯度是两个概念,显然后者是非常不准确的表述】,损失函数梯度就在 sgd的params中。

l = loss(net(X, w, b), y) 
l.sum().backward()#此时损失函数梯度【关于w,b的梯度】存在w.grad,b.grad中
sgd([w,b], lr, batch_size) #使用参数梯度更新参数

param.grad.zero_()在这里有什么意义?谁会干扰梯度的求解?

如果在循环的下一次迭代中不使用param.grad.zero_()来清零参数的梯度,那么参数将会保留上一次迭代计算得到的梯度值,继续沿用该梯度值来求解梯度。就是说上次for循环的param会对下次param的梯度求解产生影响,所以才要清空梯度。

训练过程

#超参数
lr =0.03 #学习率(learning rate),控制每次参数更新的步幅大小。
num_epochs = 3 #数据集的扫描次数,即要重复训练模型的次数。
net =linreg #表示模型,这里使用了一个名为linreg的线性回归模型。
loss = squared_loss#表示损失函数,这里使用了一个名为squared_loss的均方损失函数。for epoch in range(num_epochs):for X, y in data_iter(batch_size, features, labels):l = loss(net(X, w, b), y) # X 、y 的小批量损失#l形状是 (batch_size, 1),非标量l.sum().backward()sgd([w,b], lr, batch_size) #使用参数梯度更新参数with torch.no_grad():train_l = loss(net(features, w, b),labels)print('epoch ',epoch+1,'loss ',float(train_l.mean()))
epoch  1 loss  0.032808780670166016
epoch  2 loss  0.00011459046800155193
epoch  3 loss  5.012870315113105e-05

with torch.no_grad()有什么意义,毕竟没有backward()操作?

对于with torch.no_grad()块,在 PyTorch 中禁用梯度追踪和计算图的构建。在该块中执行的操作不会被记录到计算图中,因此不会生成梯度信息。其作用是告诉 PyTorch 不要跟踪计算梯度,这样可以节省计算资源。

简单说,就是计算损失值的张量运算不会记录到计算图中,因为没必要,而且不建立计算图,求损失值更快了。

代码存在的小问题

最后一批次可能不足batch_size,sgd 执行 param -=lr * param.grad / batch_size取平均是有问题的,修改后:

sgd([w,b], lr,min(batch_size, X.shape[0])) #使用参数梯度更新参数

比较真实参数与训练学到的参数评估训练成功程度

print('w的估计误差:',true_w - w.reshape(true_w.shape))
print('b的估计误差:',true_b - b)
w的估计误差: tensor([-6.1035e-05,  2.5797e-04], grad_fn=<SubBackward0>)
b的估计误差: tensor([0.0018], grad_fn=<RsubBackward1>)

3.1.2、简单实现线性回归

生成数据集

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2ltrue_w = torch.tensor([2,-3.4])
true_b = 4.2
#d2l 的人造数据集函数
features, labels = d2l.synthetic_data(true_w, true_b,1000)

读取数据集

def load_array(data_arrays, batch_size, is_train=True):"""构造一个Pytorch数据迭代器"""#PyTorch提供的一个用于封装多个张量数据的数据集对象,*data_arrays用于将数据数组解包为多个参数。#*data_arrays解包等价于 dataset = data.TensorDataset(features, labels)dataset = data.TensorDataset(*data_arrays)#PyTorch提供的一个用于批量加载数据的迭代器return data.DataLoader(dataset, batch_size, shuffle= is_train)batch_size = 10
data_iter = load_array((features,labels), batch_size)
#iter() 函数将数据迭代器转换为迭代器对象,而 next() 函数用于获取迭代器的下一个元素。
next(iter(data_iter))

解包操作(见 python 预备知识)

星号 *dataset = data.TensorDataset(*data_arrays) 中的作用是将元组或列表中的元素解包,并作为独立的参数传递给函数或构造函数。这样可以更方便地传递多个参数。

迭代器使用(见 python 预备知识)

iter() 函数的主要目的是将可迭代对象转换为迭代器对象,以便于使用 next() 函数逐个访问其中的元素。

使用框架预定好的层

from torch import nn
#线性回归就是一个简单的单层神经网络
#一个全连接层,它接受大小为 2 的输入特征,并输出大小为 1 的特征。
net = nn.Sequential(nn.Linear(2,1))

初始化模型参数

#net[0] 表示模型中的第一个层,weight权重参数,正态分布初始化
net[0].weight.data.normal_(0,0.01)
#第一层加入偏差
net[0].bias.data.fill_(0)

实例化损失函数

loss = nn.MSELoss()

实例化优化算法( SGD)

#net.parameters() 返回一个迭代器,该迭代器包含了模型中所有可训练的参数。
trainer = torch.optim.SGD(net.parameters(),lr=0.03)

训练过程

num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:l = loss(net(X), y)trainer.zero_grad()l.backward()trainer.step()l = loss(net(features),labels)print(f'epoch {epoch+1}, loss {l:f}')

关于输出格式,最后一个明显最好

print('epoch ',epoch+1,',loss ',l)       #epoch  1 ,loss  tensor(9.9119e-05, grad_fn=<MseLossBackward0>)print('epoch ',epoch+1,',loss ',float(l))# epoch  1 ,loss  9.872819646261632e-05print(f'epoch {epoch+1}, loss {l:f}')    # epoch 1, loss 0.000099

还可以自定义保留有限小数位

print(f'epoch {epoch+1}, loss {l:.4f}')# 保留4位。

相关文章:

线性神经网路——线性回归随笔【深度学习】【PyTorch】【d2l】

文章目录 3.1、线性回归3.1.1、PyTorch 从零实现线性回归3.1.2、简单实现线性回归 3.1、线性回归 线性回归是显式解&#xff0c;深度学习中绝大多数遇到的都是隐式解。 3.1.1、PyTorch 从零实现线性回归 %matplotlib inline import random import torch #d2l库中的torch模块&a…...

js实现多种按钮

你可以使用JavaScript来实现多种类型的按钮&#xff0c;以下是几个常见的示例&#xff1a; 普通按钮&#xff08;Normal Button&#xff09;&#xff1a; <button>Click me</button> 带图标的按钮&#xff08;Button with Icon&#xff09;&#xff1a; <bu…...

getopt函数(未更新完)

2023年7月28日&#xff0c;周五上午 这是我目前碰到过的比较复杂的函数之一&#xff0c; 为了彻底弄懂这个函数&#xff0c;我花了几个小时。 为了更好的说明这个函数&#xff0c;之后我可能会录制讲解视频并上传到B站&#xff0c; 如果我上传到B站&#xff0c;我会在文章添…...

SpringCloud学习路线(9)——服务异步通讯RabbitMQ

一、初见MQ &#xff08;一&#xff09;什么是MQ&#xff1f; MQ&#xff08;MessageQueue&#xff09;&#xff0c;意思是消息队列&#xff0c;也就是事件驱动架构中的Broker。 &#xff08;二&#xff09;同步调用 1、概念&#xff1a; 同步调用是指&#xff0c;某一服务…...

postcss-pxtorem适配插件动态配置rootValue(根据文件路径名称,动态改变vue.config里配置的值)

项目背景&#xff1a;一个项目里有两个分辨率的设计稿(1920和2400)&#xff0c;不能拆开来打包 参考&#xff1a; 是参考vant插件&#xff1a;移动端Vant组件库rem适配下大小异常的解决方案&#xff1a;https://github.com/youzan/vant/issues/1181 说明&#xff1a; 因为vue.c…...

代码随想录算法训练营第二十三天 | 额外题目系列

额外题目 1365. 有多少小于当前数字的数字借着本题&#xff0c;学习一下各种排序未看解答自己编写的青春版重点代码随想录的代码我的代码(当天晚上理解后自己编写) 941.有效的山脉数组未看解答自己编写的青春版重点代码随想录的代码我的代码(当天晚上理解后自己编写) 1207. 独一…...

UiAutomator

运行Espresso和UI Automator测试时要使用模拟器。国内手机的ROM大多进行过修改&#xff0c;可能加入很多限制&#xff0c;导致测试无法正常运行。 Espresso只支持一个活动内部交互行为的测试。跨越多个活动、多个应用的场景需要使用UI Automator。使用Espresso和UI Automator的…...

stm32标准库开发常用函数的使用和代码说明

文章目录 GPIO&#xff08;General Purpose Input/Output&#xff09;NVIC&#xff08;Nested Vectored Interrupt Controller&#xff09;DMA&#xff08;Direct Memory Access&#xff09;USART&#xff08;Universal Synchronous/Asynchronous Receiver/Transmitter&#xf…...

有关合泰BA45F5260中断的思考

最近看前辈写的代码&#xff0c;发现这样一段代码&#xff1a; #ifdef SUPPORT_RF_NET_FUNCTION if(UART_INT_is_L()) { TmrInsertTimer(eTmrHdlUartRxDelay,TMR_PERIOD(2000),NULL); break; } #endif 其中UART_INT_is_L&am…...

Numpy-算数函数与数学函数

⛳算数函数 如果参与运算的两个对象都是ndarray&#xff0c;并且形状相同&#xff0c;那么会对位彼此之间进 第 30 页 行&#xff08; - * /&#xff09;运算。NumPy 算术函数包含简单的加减乘除: add()&#xff0c;subtract()&#xff0c;multiply() 和divide()。 &#x1f…...

Nginx在springboot中起到的作用

面试时这样回答&#xff1a; 在Spring Boot项目中使用Nginx可以有以下用途&#xff1a; 1. 反向代理&#xff1a;Nginx可以作为反向代理服务器&#xff0c;将外部请求转发到后端的Spring Boot应用&#xff0c;并可以实现负载均衡、高可用、缓存等功能&#xff0c;提高系统的性…...

12.(开发工具篇vscode+git)vscode 不能识别npm命令

1&#xff1a;vscode 不能识别npm命令 问题描述&#xff1a; 解决方式&#xff1a; &#xff08;1&#xff09;右击VSCode图标&#xff0c;选择以管理员身份运行&#xff1b; &#xff08;2&#xff09;在终端中执行get-ExecutionPolicy&#xff0c;显示Restricted&#xff…...

如何在MacBook上彻底删除mysql

好久以前安装过&#xff0c;但是现在配置mysql一直出错&#xff0c;索性全部删掉重新配置。 一、停止MySQL服务 首先&#xff0c;请确保 MySQL 服务器已经停止运行&#xff0c;以免影响后续的删除操作。 sudo /usr/local/mysql/support-files/mysql.server stop如果你输入之…...

web攻击面试|网络渗透面试(一)

Web攻击面试大纲 常见Web攻击类型 1.1 SQL注入攻击 1.2 XSS攻击 1.3 CSRF攻击 1.4 命令注入攻击SQL注入攻击 2.1 基本概念 2.2 攻击原理 2.3 防御措施XSS攻击 3.1 基本概念 3.2 攻击原理 3.3 防御措施CSRF攻击 4.1 基本概念 4.2 攻击原理 4.3 防御措施命令注入攻击 5.1 基本概…...

VBA操作WORD(六)另存为不含宏的文档

Sub 另存为不含宏的文档()Application.DisplayAlerts False Application.ScreenUpdating FalseDim oDoc As DocumentSet oDoc Word.ActiveDocumentDim oRng As RangeSet oRng oDoc.ContentDim sPath As String默认存储路径&#xff0c;当前用户桌面&#xff0c;注释掉的是当…...

分享69个Java源码,总有一款适合您

Java源码 分享69个Java源码&#xff0c;总有一款适合您 下面是文件的名字&#xff0c;我放了一些图片&#xff0c;文章里不是所有的图主要是放不下...&#xff0c;大家下载后可以看到。 源码下载链接&#xff1a; https://pan.baidu.com/s/1ZgbJhMNwIyFyqFzHsDdL5w 提取码&a…...

《cool! autodistill帮你标注数据训练yolov8模型》学习笔记

《cool! autodistill帮你标注数据训练yolov8模型》 Summary Autodistill是一个用于自动标注数据训练边缘模型的工具。 Highlights &#x1f4a1; Autodistill由Robotflow推出&#xff0c;用于训练建立部署计算机视觉模型。&#x1f4bb; 通过使用大模型自动标注和训练小模型…...

Rust vs Go:常用语法对比(十)

题图来自 Rust vs. Golang: Which One is Better?[1] 182. Quine program Output the source of the program. 输出程序的源代码 package mainimport "fmt"func main() { fmt.Printf("%s%c%s%c\n", s, 0x60, s, 0x60)}var s package mainimport "fm…...

SliverPersistentHeader组件 实现Flutter吸顶效果

效果&#xff1a; 20230723-212152-73_Trim 代码&#xff1a; import package:flutter/cupertino.dart; import package:flutter/material.dart;class StickHeaderPage extends StatefulWidget {overrideState<StatefulWidget> createState() {// TODO: implement creat…...

Nginx性能优化配置

一、全局优化 # 工作进程数 worker_processes auto; # 建议 CPU核心数|CPU线程数# 最大支持的连接(open-file)数量&#xff1b;最大值受限于 Linux open files (ulimit -n) # 建议公式&#xff1a;worker_rlimit_nofile > worker_processes * worker_connections…...

杭州多校2023“钉耙编程”中国大学生算法设计超级联赛(4)

1003Simple Set Problem 首先将元素的值 x 以及所属集合的编号 y 作为二元组 (x,y) 存入数组&#xff0c;然后按照 x 升序排列&#xff0c; 之后使用双指针扫描数组&#xff08;尺取法&#xff09;&#xff0c;当区间内出现了所有编号时更新答案的最小值&#xff0c; #includ…...

音视频入门之音频采集、编码、播放

作者&#xff1a;花海blog 今天我们学习音频的采集、编码、生成文件、转码等操作&#xff0c;我们生成三种格式的文件格式&#xff0c;pcm、wav、aac 三种格式&#xff0c;并且我们用 AudioStack 来播放音频&#xff0c;最后我们播放这个音频。 使用 AudioRecord 实现录音生成…...

在 Linux 系统中,如何发起POST/GET请求

在 Linux 系统中&#xff0c;可以使用命令行工具 curl 或者 wget 来发送 POST 请求。这两个工具都是非常常用的命令行工具&#xff0c;可以通过命令行直接发送 HTTP 请求。 1. 使用 curl 发送 POST 请求&#xff1a; curl -X POST -H "Content-Type: application/json&q…...

文心一言大数据模型-文心千帆大模型平台

官网&#xff1a; 文心千帆大模型平台 (baidu.com) 文心千帆大模型 (baidu.com) 模型优势 1、模型效果优&#xff1a;所需标注数据少&#xff0c;在各场景上的效果处于业界领先水平 2、生成能力强&#xff1a;拥有丰富的AI内容生成&#xff08;AIGC&#xff09;能力 3、应用…...

django学习笔记(1)

django创建项目 先创建一个文件夹用来放django的项目&#xff0c;我这里是My_Django_it 之后打开到该文件下&#xff0c;并用下面的指令来创建myDjango1项目 D:\>cd My_Django_itD:\My_Django_it>"D:\zzu_it\Django_learn\Scripts\django-admin.exe" startpr…...

postgresql主从搭建

postgresql主从搭建 主从服务器分别安装好postgresql 主库 创建数据库热备帐号replica&#xff0c;密码123456为例&#xff0c;则执行以下命令 create role replica login replication encrypted password 123456;打开 pg_hba.conf 配置文件&#xff0c;设置 replica 用户白…...

将Parasoft和ChatGPT相结合会如何?

ChatGPT是2023年最热门的话题之一&#xff0c;是OpenAI训练的语言模型。它能够理解和生成自然语言文本&#xff0c;并接受过大量数据的训练&#xff0c;包括用各种编程语言编写的许多开源项目的源代码。 软件开发人员可以利用大量的知识库来协助他们的工作&#xff0c;因为它具…...

Go text/template详解:使用指南与最佳实践

I. 简介 A. 什么是 Go text/template Go text/template 是 Go 语言标准库中的一个模板引擎&#xff0c;用于生成文本输出。它使用类似于 HTML 的模板语言&#xff0c;可以将数据和模板结合起来&#xff0c;生成最终的文本输出。 B. Go text/template 的优点 Go text/templa…...

Stable Diffusion在各种显卡上的加速方式测试,最高可以提速211.2%

Stable Diffusion是一种基于扩散模型的图像生成技术&#xff0c;能够从文本生成高质量的图像&#xff0c;适用于CG&#xff0c;插图和高分辨率壁纸等领域。 但是它计算过程复杂&#xff0c;使得它的生成速度较慢。所以研究人员就创造了各种提高其速度的方式&#xff0c;比如Xf…...

Java读取外链图片忽略ssl验证转为base64

最近在对接外部接口时遇到返回的图片所在的服务器全都没有ssl证书&#xff0c;导致在前端直接用img标签展示时图片开裂。于是转为通过后端获取&#xff0c;绕过ssl验证之后转为base64返回。记录一下代码段。 package com.sy.ai.common.utils;import cn.hutool.core.codec.Base…...