当前位置: 首页 > news >正文

深度学习——线性神经网络(三、线性回归的简洁实现)

目录

  • 3.1 生成数据集
  • 3.2 读取数据集
  • 3.3 定义模型
  • 3.4 初始化模型参数
  • 3.5 定义损失函数
  • 3.6 定义优化算法
  • 3.7 训练

  在上一节中,我们通过张量来自定义式地进行数据存储和线性代数运算,并通过自动微分来计算梯度。实际上,由于数据迭代器、损失函数、优化器和神经网络层很常用,现代深度学习框架已经为我们实现了这些组件,只需要调用即可。

3.1 生成数据集

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2ltrue_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b,1000)
# 可以打印出来看一下
print(features,labels)

在这里插入图片描述

3.2 读取数据集

  我们可以通过调用框架中现有的API来读取数据,将features和labels作为API的参数传递,并通过数据迭代器指定batch_size,此外,布尔值is_train表示是否希望数据迭代器对象在每轮内打乱数据。

def load_array(data_arrays, batch_size, is_train=True):"""构造一个Python数据迭代器"""dataset = data.TensorDataset(*data_arrays)return data.DataLoader(dataset, batch_size, shuffle=is_train)batch_size = 10
data_iter = load_array((features,labels), batch_size)

  提到的data.TensorDataset(*data_arrays)中,*号的用法与函数定义中的类似,它表示TensorDataset可以接受任意数量的参数。这些参数通常是torch.Tensor对象,其中最后一个参数默认被视为标签,其余的参数被视为特征。

  使用iter函数构造Python迭代器,并使用next函数从迭代器中获取第一项。

print(next(iter(data_iter)))

  我是用pycharm写的代码,和jupyter中有些不一样,jupyter中直接写next(iter(data_iter))就可以打印出来了,pycharm中必须要加上print

在这里插入图片描述

  因为布尔值shuffle=is_train表示数据迭代器对象在每轮内打乱数据,所以next函数取出来的第一批量10项数据,并不直接是生成的数据集中的前10项数据。这点大家可以注意一下!

3.3 定义模型

  对于标准深度学习模型,我们可以使用框架已经预定义好的层,这使我们只需关注使用哪些层来构造模型,而不必关注层的实现细节。
  我们先定义一个模型变量net,它是一个Sequential类的实例。Sequential类将多个层串联在一起,当给定输入数据时,Sequential实例将数据传入第一层,然后将第一层的输出作为第二层的输入,以此类推。
  在线性神经网络中,模型只包含一个层,因此实际上不需要Sequential,但是由于以后几乎所有的模型都是多个层的,在这里使用Sequential类更方便理解“标准的流水线”。
在这里插入图片描述
  在单层网络架构中,这一单层称为“全连接层”,因为它的每个输入都通过矩阵-向量乘法得到它的每个输出。
  在pytorch中,全连接层在Linear类中定义,我们将两个参数传递到nn.Linear中,第一个参数指定输入特征的形状,即2;第二个参数指定输出特征形状,输出特征形状为单个标量,因此为1。

# nn是神经网络的缩写
from torch import nnnet = nn.Sequential(nn.Linear(2,1))

3.4 初始化模型参数

  在使用net之前,我们需要初始化模型参数,如在线性回归模型中的权重和偏置。深度学习框架通常由预定义的方法来初始化参数。
  在这里,我们指定每个权重系数应该从均值为0,标准差为0.01的正态分布中随机抽样,偏置参数将初始化为0.
  我们在构造nn.Linear时指定了输入和输出的尺寸,现在可以直接访问参数以设定它们的初始值。通过net[0]选择网络中的第一层,然后使用weight.data和bias.data方法访问函数。我们还可以使用替换方法normal_和fill_来重写参数值。

# 重写参数值之前的对比
print(net[0].weight.data)
print(net[0].bias.data)net[0].weight.data.normal_(0,0.01)
net[0].bias.data.fill_(0)
print(net[0].weight.data)
print(net[0].bias.data)

  下面是重写参数值之前的对比在这里插入图片描述

3.5 定义损失函数

  计算均方误差使用的是MSELoss类,也称为平方 L 2 L_2 L2范数。默认情况下,它返回所有样本损失的平均值

loss = nn.MSELoss()

3.6 定义优化算法

  小批量随机梯度下降算法是一种优化神经网络的标准工具,Pytorch在optim模块中实现了该算法的许多变体。当我们实例化一个SGD实例时,我们要指定优化的参数(可以通过net.parameters()从我们的模型中获得)以及优化算法所需的超参数字典。小批量随机梯度下降只需要设置lr的值,这里设置为0.03.

trainer = torch.optim.SGD(net.parameters(), lr=0.03)

3.7 训练

  在每轮里,我们将完整遍历一次数据集(train_data),不断地从中获取一个小批量的输入和相应的标签。对于每个小批量,将执行以下步骤:

  • 通过调用net(X)生成预测并计算损失l(前向传播)
  • 通过反向传播来计算梯度
  • 通过调用优化器来更新模型参数

  为了 更好地度量训练效果,我们计算每轮后的损失,并打印出来监控训练过程。

num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:l = loss(net(X), y)trainer.zero_grad()l.backward()trainer.step()l = loss(net(features),labels)print(f'epoch{epoch + 1}, loss {l:f}')

在这里插入图片描述

几点注意:
l = loss(net(X), y)
loss函数中已经有了sum()操作,省略了原来实现过程中的 l.sum() 这一步骤
net(X)
net()本身就带了模型中的参数,就不需要把W,b写进去了
trainer.zero_grad()
优化器需要先把梯度清零
trainer.step()
调用step()函数进行模型更新
l = loss(net(features),labels)
模型参数更新完之后,再计算一遍均方误差

  下面比较一下生成数据集的真实参数和通过有限数据训练获得的模型参数。要访问参数,我们首先从net访问所需的层,然后读取该层的权重和偏置。如下所示,我们估计得到的参数与生成数据集的真实参数非常接近。

w = net[0].weight.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net[0].bias.data
print('b的估计误差:', true_b - b)

在这里插入图片描述

小结:

  • 我们可以使用Pytorch中的高级API更简洁地实现模型;
  • 在Pytorch中,data模块提供了数据处理工具,nn 模块定义了大量的神经网络层和常见的损失函数;
  • 我们可以通过以"_"结尾的方法将参数替换,从而自定义初始化参数。

以下是完整代码:

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l
# nn是神经网络的缩写
from torch import nntrue_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b,1000)
# 可以打印出来看一下
# print(features,labels)def load_array(data_arrays, batch_size, is_train=True):"""构造一个Python数据迭代器"""dataset = data.TensorDataset(*data_arrays)return data.DataLoader(dataset, batch_size, shuffle=is_train)batch_size = 10
data_iter = load_array((features,labels), batch_size)
# print(next(iter(data_iter)))net = nn.Sequential(nn.Linear(2,1))# 重写参数值之前的对比
# print(net[0].weight.data)
# print(net[0].bias.data)net[0].weight.data.normal_(0,0.01)
net[0].bias.data.fill_(0)
# print(net[0].weight.data)
# print(net[0].bias.data)
loss = nn.MSELoss()trainer = torch.optim.SGD(net.parameters(), lr=0.03)num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:l = loss(net(X), y)trainer.zero_grad()l.backward()trainer.step()l = loss(net(features),labels)# print(f'epoch{epoch + 1}, loss {l:f}')w = net[0].weight.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net[0].bias.data
print('b的估计误差:', true_b - b)

相关文章:

深度学习——线性神经网络(三、线性回归的简洁实现)

目录 3.1 生成数据集3.2 读取数据集3.3 定义模型3.4 初始化模型参数3.5 定义损失函数3.6 定义优化算法3.7 训练 在上一节中,我们通过张量来自定义式地进行数据存储和线性代数运算,并通过自动微分来计算梯度。实际上,由于数据迭代器、损失函数…...

本地部署 Milvus

本地部署 Milvus 1. Install Milvus in Docker2. Install Attu, an open-source GUI tool 1. Install Milvus in Docker curl -sfL https://raw.githubusercontent.com/milvus-io/milvus/master/scripts/standalone_embed.sh -o standalone_embed.shbash standalone_embed.sh …...

Git基础-配置http链接的免密登录

问题描述 当我们在使用 git pull 或者 git push 进行代码拉取或代码提交时, 若我们的远程代码仓库是 http协议的链接时,就是就会提示我们进行账号密码的登录。 每次都要登录,这未免有些麻烦。 本文介绍一下免密登录的配置。解决方案 1 执行…...

华为OD机试真题-编码能力提升-2024年OD统一考试(E卷)

最新华为OD机试考点合集:华为OD机试2024年真题题库(E卷+D卷+C卷)_华为od机试题库-CSDN博客 每一题都含有详细的解题思路和代码注释,精编c++、JAVA、Python三种语言解法。帮助每一位考生轻松、高效刷题。订阅后永久可看,持续跟新。 题目描述 为了提升软件编码能力,小…...

高被引算法GOA优化VMD,结合Transformer-SVM的轴承诊断,保姆级教程!

本期采用2023年瞪羚优化算法优化VMD,并结合Transformer-SVM实现轴承诊断,算是一个小创新方法了。需要水论文的童鞋尽快! 瞪羚优化算法之前推荐过,该成果于2023年发表在计算机领域三区SCI期刊“Neural Computing and Applications”…...

半小时速通RHCSA

1-7章: #01创建以上目录和文件结构,并将/yasuo目录拷贝4份到/目录下 #02查看系统合法shell #03查看系统发行版版本 #04查看系统内核版本 #05临时修改主机名 #06查看系统指令的查找路径 #07查看passwd指令的执行路径 #08为/yasuo/ssh_config文件在/mulu目录下创建软链…...

人工智能和机器学习之线性代数(一)

人工智能和机器学习之线性代数(一) 人工智能和机器学习之线性代数一将介绍向量和矩阵的基础知识以及开源的机器学习框架PyTorch。 文章目录 人工智能和机器学习之线性代数(一)基本定义标量(Scalar)向量&a…...

STM32外设应用详解

STM32外设应用详解 STM32微控制器是意法半导体(STMicroelectronics)推出的一系列基于ARM Cortex-M内核的高性能、低功耗32位微控制器。它们拥有丰富的外设接口和功能模块,可以满足各种嵌入式应用需求。本文将详细介绍STM32的外设及其应用&am…...

docker详解介绍+基础操作 (三)优化配置

1.docker 存储引擎 Overlay: 一种Union FS文件系统,Linux 内核3.18后支持 Overlay2:Overlay的升级版,docker的默认存储引擎,需要磁盘分区支持d-type功能,因此需要系统磁盘的额外支持。 关于 d-type 传送…...

细说Qt的状态机框架及其用法

文章目录 使用场景基本用法状态定义添加转换历史状态QStateMachine是Qt框架中用于构建状态机的一个类,它属于Qt的状态机框架(State Machine Framework)。这个框架提供了一种模型,用于设计响应不同事件(如用户输入、文件I/O或网络活动)的应用程序的行为。通过使用状态机,开发…...

Oracle-表空间与数据文件操作

目录 1、表空间创建 2、表空间修改 3、数据文件可用性切换操作 4、数据文件和表空间删除 1、表空间创建 (1)为 ORCL 数据库创建一个名为 BOOKTBS1 的永久表空间,数据文件为d:\bt01.dbf ,大小为100M,区采用自动扩展…...

C# WinForm实现画笔签名及解决MemoryBmp格式问题

目录 需求 实现效果 开发运行环境 设计实现 界面布局 初始化 画笔绘图 清空画布 导出位图数据 小结 需求 我的文章 《C# 结合JavaScript实现手写板签名并上传到服务器》主要介绍了 web 版的需求实现,本文应项目需求介绍如何通过 C# WinForm 通过画布画笔…...

GC1272替代APX9172/茂达中可应用于电脑散热风扇应用分析

在电脑散热风扇应用中,选择合适的驱动器件对于风扇的性能和效率至关重要。以下是对GC1272替代APX9172/茂达在此类应用中的分析: 1. 功能比较 GC1272: 主要用于驱动直流风扇,具有高效的电流控制和调速功能。支持PWM调速&#xff0…...

《Linux从小白到高手》综合应用篇:详解Linux系统调优之服务器硬件优化

List item 本篇介绍Linux服务器硬件调优。硬件调优主要包括CPU、内存、磁盘、网络等关键硬件组。 1. CPU优化 选择适合的CPU: –根据应用需求选择多核、高频的CPU,以满足高并发和计算密集型任务的需求。CPU缓存优化: –确保CPU缓存&#x…...

PHP政务招商系统——高效连接共筑发展蓝图

政务招商系统——高效连接,共筑发展蓝图 🏛️ 一、政务招商系统:开启智慧招商新篇章 在当今经济全球化的背景下,政务招商成为了推动地方经济发展的重要引擎。而政务招商系统的出现,更是为这一进程注入了新的活力。它…...

Linux 命令行

这学期是我第一次正式学习 linux ,是在 VMware 里创建了 openEuler 的虚拟机练习 linux 的常用命令。 目前主要在学习 linux 的常用命令,因此这篇博客主要介绍一些常用的命令。 本文将持续更新… 阅读建议 Linux 是一个倒置的树结构(文件系…...

每日一题:单例模式

每日一题:单例模式 ❝ 单例模式是确保一个类只有一个实例,并提供一个全局访问点 1.饿汉式(静态常量) 特点:在类加载时就创建了实例。优点:简单易懂,线程安全。缺点:无论是否使用&…...

前端_001_html扫盲

文章目录 概念标签及属性常用全局属性head里常用标签body里常用标签表情符号 url编码 概念 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</title> </head> <body></bod…...

49 | 桥接模式:如何实现支持不同类型和渠道的消息推送系统?

上一篇文章我们学习了第一种结构型模式&#xff1a;代理模式。它在不改变原始类&#xff08;或者叫被代理类&#xff09;代码的情况下&#xff0c;通过引入代理类来给原始类附加功能。代理模式在平时的开发经常被用到&#xff0c;常用在业务系统中开发一些非功能性需求&#xf…...

使用js和canvas实现简单的网页贪吃蛇小游戏

玩法介绍 点击开始游戏后&#xff0c;使用键盘上的↑↓←→控制移动&#xff0c;吃到食物增加长度&#xff0c;碰到墙壁或碰到自身就游戏结束 代码实现 代码比较简单&#xff0c;直接阅读注释即可&#xff0c;复制即用 <!DOCTYPE html> <html lang"en"…...

JavaSec-RCE

简介 RCE(Remote Code Execution)&#xff0c;可以分为:命令注入(Command Injection)、代码注入(Code Injection) 代码注入 1.漏洞场景&#xff1a;Groovy代码注入 Groovy是一种基于JVM的动态语言&#xff0c;语法简洁&#xff0c;支持闭包、动态类型和Java互操作性&#xff0c…...

Vue记事本应用实现教程

文章目录 1. 项目介绍2. 开发环境准备3. 设计应用界面4. 创建Vue实例和数据模型5. 实现记事本功能5.1 添加新记事项5.2 删除记事项5.3 清空所有记事 6. 添加样式7. 功能扩展&#xff1a;显示创建时间8. 功能扩展&#xff1a;记事项搜索9. 完整代码10. Vue知识点解析10.1 数据绑…...

React Native 导航系统实战(React Navigation)

导航系统实战&#xff08;React Navigation&#xff09; React Navigation 是 React Native 应用中最常用的导航库之一&#xff0c;它提供了多种导航模式&#xff0c;如堆栈导航&#xff08;Stack Navigator&#xff09;、标签导航&#xff08;Tab Navigator&#xff09;和抽屉…...

【Java学习笔记】Arrays类

Arrays 类 1. 导入包&#xff1a;import java.util.Arrays 2. 常用方法一览表 方法描述Arrays.toString()返回数组的字符串形式Arrays.sort()排序&#xff08;自然排序和定制排序&#xff09;Arrays.binarySearch()通过二分搜索法进行查找&#xff08;前提&#xff1a;数组是…...

基于Uniapp开发HarmonyOS 5.0旅游应用技术实践

一、技术选型背景 1.跨平台优势 Uniapp采用Vue.js框架&#xff0c;支持"一次开发&#xff0c;多端部署"&#xff0c;可同步生成HarmonyOS、iOS、Android等多平台应用。 2.鸿蒙特性融合 HarmonyOS 5.0的分布式能力与原子化服务&#xff0c;为旅游应用带来&#xf…...

uniapp微信小程序视频实时流+pc端预览方案

方案类型技术实现是否免费优点缺点适用场景延迟范围开发复杂度​WebSocket图片帧​定时拍照Base64传输✅ 完全免费无需服务器 纯前端实现高延迟高流量 帧率极低个人demo测试 超低频监控500ms-2s⭐⭐​RTMP推流​TRTC/即构SDK推流❌ 付费方案 &#xff08;部分有免费额度&#x…...

RabbitMQ入门4.1.0版本(基于java、SpringBoot操作)

RabbitMQ 一、RabbitMQ概述 RabbitMQ RabbitMQ最初由LShift和CohesiveFT于2007年开发&#xff0c;后来由Pivotal Software Inc.&#xff08;现为VMware子公司&#xff09;接管。RabbitMQ 是一个开源的消息代理和队列服务器&#xff0c;用 Erlang 语言编写。广泛应用于各种分布…...

腾讯云V3签名

想要接入腾讯云的Api&#xff0c;必然先按其文档计算出所要求的签名。 之前也调用过腾讯云的接口&#xff0c;但总是卡在签名这一步&#xff0c;最后放弃选择SDK&#xff0c;这次终于自己代码实现。 可能腾讯云翻新了接口文档&#xff0c;现在阅读起来&#xff0c;清晰了很多&…...

抽象类和接口(全)

一、抽象类 1.概念&#xff1a;如果⼀个类中没有包含⾜够的信息来描绘⼀个具体的对象&#xff0c;这样的类就是抽象类。 像是没有实际⼯作的⽅法,我们可以把它设计成⼀个抽象⽅法&#xff0c;包含抽象⽅法的类我们称为抽象类。 2.语法 在Java中&#xff0c;⼀个类如果被 abs…...

redis和redission的区别

Redis 和 Redisson 是两个密切相关但又本质不同的技术&#xff0c;它们扮演着完全不同的角色&#xff1a; Redis: 内存数据库/数据结构存储 本质&#xff1a; 它是一个开源的、高性能的、基于内存的 键值存储数据库。它也可以将数据持久化到磁盘。 核心功能&#xff1a; 提供丰…...