当前位置: 首页 > news >正文

1-4 动手学深度学习v2-线性回归的简洁实现-笔记

通过使用深度学习框架来简洁地实现 线性回归模型 生成数据集

import numpy as np
import torch
from torch.utils import data  # 从torch.utils中引入一些处理数据的模块
from d2l import torch as d2ltrue_w = torch.tensor([2,-3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)

调用框架中现有的API来读取数据

# 假设我们已经有了features和labels,我们可以把它们作为一个list传到TensorDataset里面
# 会得到一个Pytorch的一个dataset
# dataset拿到数据集之后,我们可以调用data.DataLoader这个函数,每一次从里面随机的挑选batch_size个样本
# shuffle表示是不是要随机去打乱它的顺序def load_array(data_arrays, batch_size, is_train=True):# 构造一个PyTorch数据迭代器dataset = data.TensorDataset(*data_arrays)return data.DataLoader(dataset, batch_size, shuffle=is_train)batch_size = 10
data_iter = load_array((features, labels), batch_size)next(iter(data_iter))
# 先把data_iter转成python iter,通过next函数来得到一个X和一个y

控制台输出:

[tensor([[ 0.7835,  2.1334],[ 1.2522,  0.8810],[ 2.7299, -0.8198],[-0.7226,  0.1075],[ 0.1661,  3.0323],[ 1.5635,  0.4247],[-0.1221,  0.3311],[ 0.3448,  0.8430],[-0.9672,  1.3701],[ 0.8360,  2.6205]]),tensor([[-1.4739],[ 3.7078],[12.4481],[ 2.4049],[-5.7701],[ 5.8812],[ 2.8200],[ 2.0200],[-2.3958],[-3.0526]])]

使用框架预定义好的层

# nn是神经网络的缩写
from torch import nn
# nn 中定义了大量的定义好的层 对我们的线性回归来说 等价于它的线性层或者说全连接层net = nn.Sequential(nn.Linear(2,1))  # 唯一需要指定的是 输入的维度是多少 输出的维度是多少
# 输入维度:每个样本它的特征个数,即为输入维度# 一般直接用nn.Linear就行(线性回归就是一个简单的单层神经网络),
# 但是我们为了后面的方便,把它放到一个容器Sequential里面
# 可以理解容器Sequential为 list of layers 我们把层按顺序 一个一个的放到一起

在这段代码中,nn.Linear(2,1) 创建了一个线性层(或称为全连接层),这是神经网络中最基本的组成单元之一。
nn.Linear 需要两个主要的参数:输入维度和输出维度,这两个参数决定了层的结构。

输入维度(Input Dimension)

  • 输入维度指的是每个输入数据向量的大小或长度。在神经网络中,每个样本通常被表示为一个向量,向量中的每个元素对应一个特征。因此,输入维度就是每个样本的特征数量
  • 例如,nn.Linear(2,1) 中的 2 表示每个输入样本有2个特征。如果你的数据是二维空间中的点,那么每个点由两个坐标值表示(例如,x 和 y),因此输入维度是2

输出维度(Output Dimension)

  • 输出维度指的是经过这一层计算后输出数据向量的大小或长度。输出维度取决于你希望这一层神经网络输出多少个数值
  • nn.Linear(2,1)中,1 表示这个线性层的输出是一个单一的值。这常见于执行回归任务(如线性回归)时,你可能只需要预测一个连续值(比如房价),因此输出维度是1

nn.Sequential 容器

nn.Sequential 是一个容器,用于按顺序包装多个层。它允许你将多个层组合成一个模块,这样数据就可以依次通过这些层进行处理。在这个例子中,虽然只有一个 nn.Linear(2,1) 层,将其放入 nn.Sequential 中可能看起来多此一举,但这种做法为添加更多层提供了灵活性,便于后续模型的扩展。


初始化模型参数

net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)
# 因为我是放在容器里,也就它一个模型,所以可以用索引0访问到
# weight w
# bias b
# data 真实data
# normal_ 使用正态分布来填充data的值 均值为0,标准差为0.01

计算均方误差使用的是MSELoss类,也称为平方 L 2 L_{2} L2范数

loss = nn.MSELoss()

实例化SGD实例

trainer = torch.optim.SGD(net.parameters(), lr=0.03)
# net.parameters() w和b

那么相比从0开始实现,那里的除以batch_size这里怎么没有了呢?

在PyTorch中,损失函数的默认行为对于批处理的数据是计算并返回批处理中所有样本损失的平均值。因此,当你在训练循环中调用loss(net(X), y)时,得到的损失l实际上已经是当前批次内所有样本损失的平均值

这意味着,即使在代码中没有显式地除以批大小(batch_size)来计算平均误差,这个计算过程实际上已经在损失函数内部自动完成了。这是为了简化训练过程,并使代码更加简洁。

例如,如果你使用的是torch.nn.CrossEntropyLosstorch.nn.MSELoss作为你的损失函数,这些函数默认就是计算批次中所有样本损失的平均值。如果你希望损失函数返回批次中所有样本损失的总和,而不是平均值,你可以在初始化损失函数时通过设置reduction='sum'参数来实现。默认情况下,reduction参数的值是'mean',即计算平均值。

总结来说,你看不到代码中显式地除以batch_size来计算平均误差的原因是因为这一计算步骤已经被内嵌在了损失函数中。


训练过程代码与我们从零开始实现时所做的非常相似

num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:    # 依次拿出小批次l = loss(net(X) ,y)   # net本身自带模型参数了,不需要显示的去写w和b了,# 而且这里算的确实已经是平均值,相当于上一节从0开始里面,除以batch_size的操作,它这里是放在计算损失函数里面已经做了trainer.zero_grad()   # 梯度清零 防止梯度累加l.backward()          # 这里pytorch 自动会计算sum,不需要我们自己再去手动计算sumtrainer.step()        # 模型更新l = loss(net(features), labels) # 扫完一遍数据之后,把整个特征矩阵带进去,评估效果,算损失值print(f'epoch {epoch + 1}, loss {l:f}')

控制台输出:

epoch 1, loss 0.000251
epoch 2, loss 0.000111
epoch 3, loss 0.000110

我们可以发现,随着轮次的增大,损失值在降低。

相关文章:

1-4 动手学深度学习v2-线性回归的简洁实现-笔记

通过使用深度学习框架来简洁地实现 线性回归模型 生成数据集 import numpy as np import torch from torch.utils import data # 从torch.utils中引入一些处理数据的模块 from d2l import torch as d2ltrue_w torch.tensor([2,-3.4]) true_b 4.2 features, labels d2l.syn…...

SQL如何实现数据表行转列、列转行?

SQL行转列、列转行可以帮助我们更方便地处理数据,生成需要的报表和结果集。本文将介绍在SQL中如何实现数据表地行转列、列转行操作,以及实际应用示例。 这里通过表下面三张表进行举例 SQL创建数据库和数据表 数据表示例数据分别如下: data_…...

【React】redux状态管理、react-redux状态管理高级封装模块化

【React】react组件传参、redux状态管理 一、redux全局状态管理1、redux概述2、redux的组成1.1 State-状态1.2 Action-事件1.3 Reducer1.4 Store 3、redux入门案例1.1 前期准备1.2 构建store1.2.1 在src下新建store文件夹1.2.2 在store文件夹下新建index.ts文件1.2.3 在index.t…...

HAProxy 和负载均衡概念简介

简介 HAProxy,全称高可用代理,是一款流行的开源软件 TCP/HTTP 负载均衡器和代理解决方案,可在 Linux、macOS 和 FreeBSD 上运行。它最常见的用途是通过将工作负载分布到多台服务器(例如 Web、应用程序、数据库)上来提…...

【go】ent操作之CRUD与联表查询

文章目录 1 CRUD1.1 创建1.1.1 单条创建1.1.2 批量创建 1.2 查找1.2.1 查询单条 / 条件准确查询1.2.2 查询单条 / 条件模糊查询1.2.3 查询单条 / In1.2.4 查询全部 1.3 更新1.4 删除 2 联表查询2.1 O2M(一对多查询)2.1.1 增加Edge2.1.2 查询方法2.1.2.1 …...

服务器性能监控管理方法及工具

服务器是组织数据中心的主干,无论是优化的用户体验,还是管理良好的资源,服务器都能为您完成所有工作,保持服务器随时可用和可访问对于面向业务的应用程序和服务以最佳水平运行至关重要。 理想的服务器性能需要主动监控物理和虚拟…...

AUTOSAR汽车电子嵌入式编程精讲300篇-基于FPGA和CAN协议2.0B的总线控制器研究与设计

目录 前言 研究现状分析 2 CAN总线协议 2.1 CAN总线基本概念 2.2 物理层...

14.1 Ajax与JSON应用(❤❤)

14.1 Ajax与JSON应用 1. Ajax1.1 简介1.2 Ajax使用流程1. 前端创建XMLHttpRequest对象2. 发送Ajax请求3. 处理服务器响应4. 代码2. JSON2.1 简介2.2 JS解析JSON3. Ajax与JSON开发3.1 后端:用Jackson实现JSON序列化输出3.2 前端Ajax处理JSON3.3 Ajax工具...

ffmpeg命令生成器

FFmpeg 快速入门:命令行详解、工具、教程、电子书 – 码中人的博客FFmpeg 是一个强大的命令行工具,可以用来处理音频、视频、字幕等多媒体文件。本文介绍了 FFmpeg 的基本用法、一些常用的命令行参数,以及常用的可视化工具。https://blog.mzh…...

JavaScript基础速成

由于学web时只学了后端,现在到了前后端联调的场景发现看不懂前端代码,于是开始恶补 看了下基础内容发现html和css比较好看懂,但JavaScript比较迷,大概知道组件id绑定事件 下面选取看菜鸟教程补充的JS知识 JS的作用 JS是在html…...

openGauss学习笔记-215 openGauss性能调优-确定性能调优范围-性能日志

文章目录 openGauss学习笔记-215 openGauss性能调优-确定性能调优范围-性能日志215.1 性能日志概述215.2 性能日志收集的配置参数 openGauss学习笔记-215 openGauss性能调优-确定性能调优范围-性能日志 215.1 性能日志概述 性能日志主要关注外部资源的访问性能问题。 性能日…...

在vs code的terminal,debug执行python main.py --train True

GPT4告诉我: 在VS Code中以debug状态执行带有参数(如--train)的main.py文件,你需要在launch.json配置文件中正确设置参数。以下是详细步骤: 打开你的main.py文件:确保你的main.py文件已经在VS Code中打开…...

docker 简单项目

要将服务器端口映射到容器端口,你可以使用 Docker 命令的 -p 选项。以下是基本的步骤: 1. **拉取镜像:** 在服务器上运行以下命令拉取你想要的 Docker 镜像,例如 Nginx: bash docker pull nginx 2. **运行容器…...

计算机毕业设计 基于SpringBoot的线上教育培训办公系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍:✌从事软件开发10年之余,专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ 🍅文末获取源码联系🍅 👇🏻 精…...

四、机器学习基础概念介绍

四、机器学习基础概念介绍 1_机器学习基础概念机器学习分类1.1 有监督学习1.2 无监督学习 2_有监督机器学习—常见评估方法数据集的划分2.1 留出法2.2 校验验证法(重点方法)简单交叉验证K折交叉验证(单独流出测试集)(常…...

Excel设置单元格下拉框(poi)

前言 年关在即,还在最后的迭代处理,还分了个其他同事的单,说是导出的Excel模版的2列要修改为下拉选项,过程很曲折,不说,以下其实就是一个笔记而已! 其实之前分享过阿里的EasyExcel设置单…...

api接口是什么意思,api接口该如何防护呢?

API接口:应用程序与服务之间的接口 什么是API接口 API是应用程序接口的缩写,指的是能够让不同的应用程序之间交换数据的一种方式。一个API接口就是应用程序与服务之间的接口,它定义了服务提供的功能和数据,以及应用程序如何访问这…...

PMP资料怎么学?PMP备考经验分享

PMP考试前大家大多都是提前备考个一两个月,但是有些朋友喜欢“不走寻常路”,并不打算去考PMP认证,想要单纯了解PMP,不管要不要考证,即使是仅仅学习了解一下我个人都非常支持,因为专业的基础的确能提高工作效…...

partition by list(msn_id)子句的含义

在数据库查询中,特别是在使用SQL语言时,"PARTITION BY" 子句用于对结果集进行分区,以便可以对每个分区进行单独的聚合操作。这是在执行窗口函数(如 ROW_NUMBER(), RANK(), SUM(), AVG() 等)时特别有用的。 …...

【C++】I/O多路转接详解(二)

在上一篇文章【C】I/O多路转接详解(一) 在出现EPOLL之后,随之而来的是两种事件处理模式的应运而生:Reator 和 Proactor,同步IO模型常用于Reactor模式,异步IO常用于Proactor. 目录 1. 服务器编程框架简介2. IO处理1. R…...

golang循环变量捕获问题​​

在 Go 语言中,当在循环中启动协程(goroutine)时,如果在协程闭包中直接引用循环变量,可能会遇到一个常见的陷阱 - ​​循环变量捕获问题​​。让我详细解释一下: 问题背景 看这个代码片段: fo…...

VB.net复制Ntag213卡写入UID

本示例使用的发卡器:https://item.taobao.com/item.htm?ftt&id615391857885 一、读取旧Ntag卡的UID和数据 Private Sub Button15_Click(sender As Object, e As EventArgs) Handles Button15.Click轻松读卡技术支持:网站:Dim i, j As IntegerDim cardidhex, …...

VTK如何让部分单位不可见

最近遇到一个需求&#xff0c;需要让一个vtkDataSet中的部分单元不可见&#xff0c;查阅了一些资料大概有以下几种方式 1.通过颜色映射表来进行&#xff0c;是最正规的做法 vtkNew<vtkLookupTable> lut; //值为0不显示&#xff0c;主要是最后一个参数&#xff0c;透明度…...

Map相关知识

数据结构 二叉树 二叉树&#xff0c;顾名思义&#xff0c;每个节点最多有两个“叉”&#xff0c;也就是两个子节点&#xff0c;分别是左子 节点和右子节点。不过&#xff0c;二叉树并不要求每个节点都有两个子节点&#xff0c;有的节点只 有左子节点&#xff0c;有的节点只有…...

Typeerror: cannot read properties of undefined (reading ‘XXX‘)

最近需要在离线机器上运行软件&#xff0c;所以得把软件用docker打包起来&#xff0c;大部分功能都没问题&#xff0c;出了一个奇怪的事情。同样的代码&#xff0c;在本机上用vscode可以运行起来&#xff0c;但是打包之后在docker里出现了问题。使用的是dialog组件&#xff0c;…...

【Go语言基础【13】】函数、闭包、方法

文章目录 零、概述一、函数基础1、函数基础概念2、参数传递机制3、返回值特性3.1. 多返回值3.2. 命名返回值3.3. 错误处理 二、函数类型与高阶函数1. 函数类型定义2. 高阶函数&#xff08;函数作为参数、返回值&#xff09; 三、匿名函数与闭包1. 匿名函数&#xff08;Lambda函…...

搭建DNS域名解析服务器(正向解析资源文件)

正向解析资源文件 1&#xff09;准备工作 服务端及客户端都关闭安全软件 [rootlocalhost ~]# systemctl stop firewalld [rootlocalhost ~]# setenforce 0 2&#xff09;服务端安装软件&#xff1a;bind 1.配置yum源 [rootlocalhost ~]# cat /etc/yum.repos.d/base.repo [Base…...

MySQL 部分重点知识篇

一、数据库对象 1. 主键 定义 &#xff1a;主键是用于唯一标识表中每一行记录的字段或字段组合。它具有唯一性和非空性特点。 作用 &#xff1a;确保数据的完整性&#xff0c;便于数据的查询和管理。 示例 &#xff1a;在学生信息表中&#xff0c;学号可以作为主键&#xff…...

MyBatis中关于缓存的理解

MyBatis缓存 MyBatis系统当中默认定义两级缓存&#xff1a;一级缓存、二级缓存 默认情况下&#xff0c;只有一级缓存开启&#xff08;sqlSession级别的缓存&#xff09;二级缓存需要手动开启配置&#xff0c;需要局域namespace级别的缓存 一级缓存&#xff08;本地缓存&#…...

在鸿蒙HarmonyOS 5中使用DevEco Studio实现指南针功能

指南针功能是许多位置服务应用的基础功能之一。下面我将详细介绍如何在HarmonyOS 5中使用DevEco Studio实现指南针功能。 1. 开发环境准备 确保已安装DevEco Studio 3.1或更高版本确保项目使用的是HarmonyOS 5.0 SDK在项目的module.json5中配置必要的权限 2. 权限配置 在mo…...