当前位置: 首页 > news >正文

1-4 动手学深度学习v2-线性回归的简洁实现-笔记

通过使用深度学习框架来简洁地实现 线性回归模型 生成数据集

import numpy as np
import torch
from torch.utils import data  # 从torch.utils中引入一些处理数据的模块
from d2l import torch as d2ltrue_w = torch.tensor([2,-3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)

调用框架中现有的API来读取数据

# 假设我们已经有了features和labels,我们可以把它们作为一个list传到TensorDataset里面
# 会得到一个Pytorch的一个dataset
# dataset拿到数据集之后,我们可以调用data.DataLoader这个函数,每一次从里面随机的挑选batch_size个样本
# shuffle表示是不是要随机去打乱它的顺序def load_array(data_arrays, batch_size, is_train=True):# 构造一个PyTorch数据迭代器dataset = data.TensorDataset(*data_arrays)return data.DataLoader(dataset, batch_size, shuffle=is_train)batch_size = 10
data_iter = load_array((features, labels), batch_size)next(iter(data_iter))
# 先把data_iter转成python iter,通过next函数来得到一个X和一个y

控制台输出:

[tensor([[ 0.7835,  2.1334],[ 1.2522,  0.8810],[ 2.7299, -0.8198],[-0.7226,  0.1075],[ 0.1661,  3.0323],[ 1.5635,  0.4247],[-0.1221,  0.3311],[ 0.3448,  0.8430],[-0.9672,  1.3701],[ 0.8360,  2.6205]]),tensor([[-1.4739],[ 3.7078],[12.4481],[ 2.4049],[-5.7701],[ 5.8812],[ 2.8200],[ 2.0200],[-2.3958],[-3.0526]])]

使用框架预定义好的层

# nn是神经网络的缩写
from torch import nn
# nn 中定义了大量的定义好的层 对我们的线性回归来说 等价于它的线性层或者说全连接层net = nn.Sequential(nn.Linear(2,1))  # 唯一需要指定的是 输入的维度是多少 输出的维度是多少
# 输入维度:每个样本它的特征个数,即为输入维度# 一般直接用nn.Linear就行(线性回归就是一个简单的单层神经网络),
# 但是我们为了后面的方便,把它放到一个容器Sequential里面
# 可以理解容器Sequential为 list of layers 我们把层按顺序 一个一个的放到一起

在这段代码中,nn.Linear(2,1) 创建了一个线性层(或称为全连接层),这是神经网络中最基本的组成单元之一。
nn.Linear 需要两个主要的参数:输入维度和输出维度,这两个参数决定了层的结构。

输入维度(Input Dimension)

  • 输入维度指的是每个输入数据向量的大小或长度。在神经网络中,每个样本通常被表示为一个向量,向量中的每个元素对应一个特征。因此,输入维度就是每个样本的特征数量
  • 例如,nn.Linear(2,1) 中的 2 表示每个输入样本有2个特征。如果你的数据是二维空间中的点,那么每个点由两个坐标值表示(例如,x 和 y),因此输入维度是2

输出维度(Output Dimension)

  • 输出维度指的是经过这一层计算后输出数据向量的大小或长度。输出维度取决于你希望这一层神经网络输出多少个数值
  • nn.Linear(2,1)中,1 表示这个线性层的输出是一个单一的值。这常见于执行回归任务(如线性回归)时,你可能只需要预测一个连续值(比如房价),因此输出维度是1

nn.Sequential 容器

nn.Sequential 是一个容器,用于按顺序包装多个层。它允许你将多个层组合成一个模块,这样数据就可以依次通过这些层进行处理。在这个例子中,虽然只有一个 nn.Linear(2,1) 层,将其放入 nn.Sequential 中可能看起来多此一举,但这种做法为添加更多层提供了灵活性,便于后续模型的扩展。


初始化模型参数

net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)
# 因为我是放在容器里,也就它一个模型,所以可以用索引0访问到
# weight w
# bias b
# data 真实data
# normal_ 使用正态分布来填充data的值 均值为0,标准差为0.01

计算均方误差使用的是MSELoss类,也称为平方 L 2 L_{2} L2范数

loss = nn.MSELoss()

实例化SGD实例

trainer = torch.optim.SGD(net.parameters(), lr=0.03)
# net.parameters() w和b

那么相比从0开始实现,那里的除以batch_size这里怎么没有了呢?

在PyTorch中,损失函数的默认行为对于批处理的数据是计算并返回批处理中所有样本损失的平均值。因此,当你在训练循环中调用loss(net(X), y)时,得到的损失l实际上已经是当前批次内所有样本损失的平均值

这意味着,即使在代码中没有显式地除以批大小(batch_size)来计算平均误差,这个计算过程实际上已经在损失函数内部自动完成了。这是为了简化训练过程,并使代码更加简洁。

例如,如果你使用的是torch.nn.CrossEntropyLosstorch.nn.MSELoss作为你的损失函数,这些函数默认就是计算批次中所有样本损失的平均值。如果你希望损失函数返回批次中所有样本损失的总和,而不是平均值,你可以在初始化损失函数时通过设置reduction='sum'参数来实现。默认情况下,reduction参数的值是'mean',即计算平均值。

总结来说,你看不到代码中显式地除以batch_size来计算平均误差的原因是因为这一计算步骤已经被内嵌在了损失函数中。


训练过程代码与我们从零开始实现时所做的非常相似

num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:    # 依次拿出小批次l = loss(net(X) ,y)   # net本身自带模型参数了,不需要显示的去写w和b了,# 而且这里算的确实已经是平均值,相当于上一节从0开始里面,除以batch_size的操作,它这里是放在计算损失函数里面已经做了trainer.zero_grad()   # 梯度清零 防止梯度累加l.backward()          # 这里pytorch 自动会计算sum,不需要我们自己再去手动计算sumtrainer.step()        # 模型更新l = loss(net(features), labels) # 扫完一遍数据之后,把整个特征矩阵带进去,评估效果,算损失值print(f'epoch {epoch + 1}, loss {l:f}')

控制台输出:

epoch 1, loss 0.000251
epoch 2, loss 0.000111
epoch 3, loss 0.000110

我们可以发现,随着轮次的增大,损失值在降低。

相关文章:

1-4 动手学深度学习v2-线性回归的简洁实现-笔记

通过使用深度学习框架来简洁地实现 线性回归模型 生成数据集 import numpy as np import torch from torch.utils import data # 从torch.utils中引入一些处理数据的模块 from d2l import torch as d2ltrue_w torch.tensor([2,-3.4]) true_b 4.2 features, labels d2l.syn…...

SQL如何实现数据表行转列、列转行?

SQL行转列、列转行可以帮助我们更方便地处理数据,生成需要的报表和结果集。本文将介绍在SQL中如何实现数据表地行转列、列转行操作,以及实际应用示例。 这里通过表下面三张表进行举例 SQL创建数据库和数据表 数据表示例数据分别如下: data_…...

【React】redux状态管理、react-redux状态管理高级封装模块化

【React】react组件传参、redux状态管理 一、redux全局状态管理1、redux概述2、redux的组成1.1 State-状态1.2 Action-事件1.3 Reducer1.4 Store 3、redux入门案例1.1 前期准备1.2 构建store1.2.1 在src下新建store文件夹1.2.2 在store文件夹下新建index.ts文件1.2.3 在index.t…...

HAProxy 和负载均衡概念简介

简介 HAProxy,全称高可用代理,是一款流行的开源软件 TCP/HTTP 负载均衡器和代理解决方案,可在 Linux、macOS 和 FreeBSD 上运行。它最常见的用途是通过将工作负载分布到多台服务器(例如 Web、应用程序、数据库)上来提…...

【go】ent操作之CRUD与联表查询

文章目录 1 CRUD1.1 创建1.1.1 单条创建1.1.2 批量创建 1.2 查找1.2.1 查询单条 / 条件准确查询1.2.2 查询单条 / 条件模糊查询1.2.3 查询单条 / In1.2.4 查询全部 1.3 更新1.4 删除 2 联表查询2.1 O2M(一对多查询)2.1.1 增加Edge2.1.2 查询方法2.1.2.1 …...

服务器性能监控管理方法及工具

服务器是组织数据中心的主干,无论是优化的用户体验,还是管理良好的资源,服务器都能为您完成所有工作,保持服务器随时可用和可访问对于面向业务的应用程序和服务以最佳水平运行至关重要。 理想的服务器性能需要主动监控物理和虚拟…...

AUTOSAR汽车电子嵌入式编程精讲300篇-基于FPGA和CAN协议2.0B的总线控制器研究与设计

目录 前言 研究现状分析 2 CAN总线协议 2.1 CAN总线基本概念 2.2 物理层...

14.1 Ajax与JSON应用(❤❤)

14.1 Ajax与JSON应用 1. Ajax1.1 简介1.2 Ajax使用流程1. 前端创建XMLHttpRequest对象2. 发送Ajax请求3. 处理服务器响应4. 代码2. JSON2.1 简介2.2 JS解析JSON3. Ajax与JSON开发3.1 后端:用Jackson实现JSON序列化输出3.2 前端Ajax处理JSON3.3 Ajax工具...

ffmpeg命令生成器

FFmpeg 快速入门:命令行详解、工具、教程、电子书 – 码中人的博客FFmpeg 是一个强大的命令行工具,可以用来处理音频、视频、字幕等多媒体文件。本文介绍了 FFmpeg 的基本用法、一些常用的命令行参数,以及常用的可视化工具。https://blog.mzh…...

JavaScript基础速成

由于学web时只学了后端,现在到了前后端联调的场景发现看不懂前端代码,于是开始恶补 看了下基础内容发现html和css比较好看懂,但JavaScript比较迷,大概知道组件id绑定事件 下面选取看菜鸟教程补充的JS知识 JS的作用 JS是在html…...

openGauss学习笔记-215 openGauss性能调优-确定性能调优范围-性能日志

文章目录 openGauss学习笔记-215 openGauss性能调优-确定性能调优范围-性能日志215.1 性能日志概述215.2 性能日志收集的配置参数 openGauss学习笔记-215 openGauss性能调优-确定性能调优范围-性能日志 215.1 性能日志概述 性能日志主要关注外部资源的访问性能问题。 性能日…...

在vs code的terminal,debug执行python main.py --train True

GPT4告诉我: 在VS Code中以debug状态执行带有参数(如--train)的main.py文件,你需要在launch.json配置文件中正确设置参数。以下是详细步骤: 打开你的main.py文件:确保你的main.py文件已经在VS Code中打开…...

docker 简单项目

要将服务器端口映射到容器端口,你可以使用 Docker 命令的 -p 选项。以下是基本的步骤: 1. **拉取镜像:** 在服务器上运行以下命令拉取你想要的 Docker 镜像,例如 Nginx: bash docker pull nginx 2. **运行容器…...

计算机毕业设计 基于SpringBoot的线上教育培训办公系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍:✌从事软件开发10年之余,专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ 🍅文末获取源码联系🍅 👇🏻 精…...

四、机器学习基础概念介绍

四、机器学习基础概念介绍 1_机器学习基础概念机器学习分类1.1 有监督学习1.2 无监督学习 2_有监督机器学习—常见评估方法数据集的划分2.1 留出法2.2 校验验证法(重点方法)简单交叉验证K折交叉验证(单独流出测试集)(常…...

Excel设置单元格下拉框(poi)

前言 年关在即,还在最后的迭代处理,还分了个其他同事的单,说是导出的Excel模版的2列要修改为下拉选项,过程很曲折,不说,以下其实就是一个笔记而已! 其实之前分享过阿里的EasyExcel设置单…...

api接口是什么意思,api接口该如何防护呢?

API接口:应用程序与服务之间的接口 什么是API接口 API是应用程序接口的缩写,指的是能够让不同的应用程序之间交换数据的一种方式。一个API接口就是应用程序与服务之间的接口,它定义了服务提供的功能和数据,以及应用程序如何访问这…...

PMP资料怎么学?PMP备考经验分享

PMP考试前大家大多都是提前备考个一两个月,但是有些朋友喜欢“不走寻常路”,并不打算去考PMP认证,想要单纯了解PMP,不管要不要考证,即使是仅仅学习了解一下我个人都非常支持,因为专业的基础的确能提高工作效…...

partition by list(msn_id)子句的含义

在数据库查询中,特别是在使用SQL语言时,"PARTITION BY" 子句用于对结果集进行分区,以便可以对每个分区进行单独的聚合操作。这是在执行窗口函数(如 ROW_NUMBER(), RANK(), SUM(), AVG() 等)时特别有用的。 …...

【C++】I/O多路转接详解(二)

在上一篇文章【C】I/O多路转接详解(一) 在出现EPOLL之后,随之而来的是两种事件处理模式的应运而生:Reator 和 Proactor,同步IO模型常用于Reactor模式,异步IO常用于Proactor. 目录 1. 服务器编程框架简介2. IO处理1. R…...

springboot 百货中心供应链管理系统小程序

一、前言 随着我国经济迅速发展,人们对手机的需求越来越大,各种手机软件也都在被广泛应用,但是对于手机进行数据信息管理,对于手机的各种软件也是备受用户的喜爱,百货中心供应链管理系统被用户普遍使用,为方…...

SciencePlots——绘制论文中的图片

文章目录 安装一、风格二、1 资源 安装 # 安装最新版 pip install githttps://github.com/garrettj403/SciencePlots.git# 安装稳定版 pip install SciencePlots一、风格 简单好用的深度学习论文绘图专用工具包–Science Plot 二、 1 资源 论文绘图神器来了:一行…...

三维GIS开发cesium智慧地铁教程(5)Cesium相机控制

一、环境搭建 <script src"../cesium1.99/Build/Cesium/Cesium.js"></script> <link rel"stylesheet" href"../cesium1.99/Build/Cesium/Widgets/widgets.css"> 关键配置点&#xff1a; 路径验证&#xff1a;确保相对路径.…...

java调用dll出现unsatisfiedLinkError以及JNA和JNI的区别

UnsatisfiedLinkError 在对接硬件设备中&#xff0c;我们会遇到使用 java 调用 dll文件 的情况&#xff0c;此时大概率出现UnsatisfiedLinkError链接错误&#xff0c;原因可能有如下几种 类名错误包名错误方法名参数错误使用 JNI 协议调用&#xff0c;结果 dll 未实现 JNI 协…...

Keil 中设置 STM32 Flash 和 RAM 地址详解

文章目录 Keil 中设置 STM32 Flash 和 RAM 地址详解一、Flash 和 RAM 配置界面(Target 选项卡)1. IROM1(用于配置 Flash)2. IRAM1(用于配置 RAM)二、链接器设置界面(Linker 选项卡)1. 勾选“Use Memory Layout from Target Dialog”2. 查看链接器参数(如果没有勾选上面…...

从零实现STL哈希容器:unordered_map/unordered_set封装详解

本篇文章是对C学习的STL哈希容器自主实现部分的学习分享 希望也能为你带来些帮助~ 那咱们废话不多说&#xff0c;直接开始吧&#xff01; 一、源码结构分析 1. SGISTL30实现剖析 // hash_set核心结构 template <class Value, class HashFcn, ...> class hash_set {ty…...

【JVM面试篇】高频八股汇总——类加载和类加载器

目录 1. 讲一下类加载过程&#xff1f; 2. Java创建对象的过程&#xff1f; 3. 对象的生命周期&#xff1f; 4. 类加载器有哪些&#xff1f; 5. 双亲委派模型的作用&#xff08;好处&#xff09;&#xff1f; 6. 讲一下类的加载和双亲委派原则&#xff1f; 7. 双亲委派模…...

接口自动化测试:HttpRunner基础

相关文档 HttpRunner V3.x中文文档 HttpRunner 用户指南 使用HttpRunner 3.x实现接口自动化测试 HttpRunner介绍 HttpRunner 是一个开源的 API 测试工具&#xff0c;支持 HTTP(S)/HTTP2/WebSocket/RPC 等网络协议&#xff0c;涵盖接口测试、性能测试、数字体验监测等测试类型…...

nnUNet V2修改网络——暴力替换网络为UNet++

更换前,要用nnUNet V2跑通所用数据集,证明nnUNet V2、数据集、运行环境等没有问题 阅读nnU-Net V2 的 U-Net结构,初步了解要修改的网络,知己知彼,修改起来才能游刃有余。 U-Net存在两个局限,一是网络的最佳深度因应用场景而异,这取决于任务的难度和可用于训练的标注数…...

LCTF液晶可调谐滤波器在多光谱相机捕捉无人机目标检测中的作用

中达瑞和自2005年成立以来&#xff0c;一直在光谱成像领域深度钻研和发展&#xff0c;始终致力于研发高性能、高可靠性的光谱成像相机&#xff0c;为科研院校提供更优的产品和服务。在《低空背景下无人机目标的光谱特征研究及目标检测应用》这篇论文中提到中达瑞和 LCTF 作为多…...