当前位置: 首页 > news >正文

Pytorch神经网络模型nn.Sequential与nn.Linear

1、定义模型

对于标准深度学习模型,我们可以使用框架的预定义好的层。这使我们只需关注使用哪些层来构造模型,而不必关注层的实现细节。

我们首先定义一个模型变量net,它是一个Sequential类的实例。 Sequential类将多个层串联在一起。 当给定输入数据时,Sequential实例将数据传入到第一层, 然后将第一层的输出作为第二层的输入,以此类推。

在下面的例子中,我们的模型只包含一个层,因此实际上不需要Sequential。 但是由于以后几乎所有的模型都是多层的,在这里使用Sequential会让你熟悉“标准的流水线”。

回顾 图3.1.2中的单层网络架构, 这一单层被称为全连接层(fully-connected layer), 因为它的每一个输入都通过矩阵-向量乘法得到它的每个输出。

在PyTorch中,全连接层在Linear类中定义。 值得注意的是,我们将两个参数传递到nn.Linear中。 第一个指定输入特征形状,即2,第二个指定输出特征形状,输出特征形状为单个标量,因此为1。

import torch# nn是神经网络的缩写
from torch import nnnet = nn.Sequential(nn.Linear(2, 1))

2、初始化模型参数

在使用net之前,我们需要初始化模型参数。 如权重和偏置。 深度学习框架通常有预定义的方法来初始化参数。 在这里,我们指定每个权重参数应该从均值为0、标准差为0.01的正态分布中随机采样, 偏置参数将初始化为零。

net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)

3、定义损失函数

计算均方误差使用的是MSELoss类,也称为平方L2范数。 默认情况下,它返回所有样本损失的平均值。

loss = nn.MSELoss()

4、定义优化算法

小批量随机梯度下降算法是一种优化神经网络的标准工具, PyTorch在optim模块中实现了该算法的许多变种。 当我们实例化一个SGD实例时,我们要指定优化的参数 (可通过net.parameters()从我们的模型中获得)以及优化算法所需的超参数字典。 小批量随机梯度下降只需要设置lr值,这里设置为0.03。

trainer = torch.optim.SGD(net.parameters(), lr=0.03)

5、训练

通过深度学习框架的高级API来实现我们的模型只需要相对较少的代码。 我们不必单独分配参数、不必定义我们的损失函数,也不必手动实现小批量随机梯度下降。 

在每个迭代周期里,我们将完整遍历一次数据集(train_data), 不停地从中获取一个小批量的输入和相应的标签。 对于每一个小批量,我们会进行以下步骤:

  • 通过调用net(X)生成预测并计算损失l(前向传播)。
  • 通过进行反向传播来计算梯度。
  • 通过调用优化器来更新模型参数

为了更好的衡量训练效果,我们计算每个迭代周期后的损失,并打印它来监控训练过程

num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:l = loss(net(X) ,y)trainer.zero_grad()l.backward()trainer.step()l = loss(net(features), labels)print(f'epoch {epoch + 1}, loss {l:f}')# epoch 1, loss 0.000248
# epoch 2, loss 0.000103
# epoch 3, loss 0.000103

下面比较生成数据集的真实参数通过有限数据训练获得的模型参数。要访问参数,我们首先从net访问所需的层,然后读取该层的权重和偏置。

w = net[0].weight.data
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net[0].bias.data
print('b的估计误差:', true_b - b)# w的估计误差: tensor([-0.0010, -0.0003])
# b的估计误差: tensor([-0.0003])

batchsize的选择和学习率调整

batchsize的选择和学习率调整_batchsize和学习率-CSDN博客

相关文章:

Pytorch神经网络模型nn.Sequential与nn.Linear

1、定义模型 对于标准深度学习模型,我们可以使用框架的预定义好的层。这使我们只需关注使用哪些层来构造模型,而不必关注层的实现细节。 我们首先定义一个模型变量net,它是一个Sequential类的实例。 Sequential类将多个层串联在一起。 当给…...

C++-gdb调试常用功能

文章目录 启动gdb运行程序设置断点运行控制查看源码查看信息查看变量线程相关 gdb调试常用功能如下,其中bin为要调试的程序,arg为参数 启动gdb 启动调试 gdb bin带参数启动 gdb --args bin arg1 arg2so预加载LD_PRELOAD/path/to/lib.so && gdb …...

快速上手的AI工具-文心一言辅助学习

前言 大家好晚上好,现在AI技术的发展,它已经渗透到我们生活的各个层面。对于普通人来说,理解并有效利用AI技术不仅能增强个人竞争力,还能在日常生活中带来便利。无论是提高工作效率,还是优化日常任务,AI工…...

Boost 适用 filesystem 库,statx 函数无法找到引用问题的解决方案。

1、boost 高版本使用了 statx 函数,这个函数是在 Linux 内核版本 4.11 之后引入的。 所以:可以升级 Linux 内核版本到4.11之后即可。 2、降低 boost 库版本到 1.70 以下 3、正确的路,改 boost 的编译代码 先看这个: Filesyste…...

MyBatis中一级缓存是什么?SqlSession一级缓存失效的原因?如何理解一级缓存?

一级缓存是SqlSession级别的,通过同一个SqlSession查询的数据会被缓存,下次查询相同的数据,就 会从缓存中直接获取,不会从数据库重新访问 使一级缓存失效的四种情况: 1) 不同的SqlSession对应不同的一级缓存 2) 同一…...

项目解决方案:多地医馆的高清视频监控接入汇聚联网

目 录 一、背景 二、建设目标及需求 1.建设目标 2.现状分析 3.需求分析 三、方案设计 1.设计依据 2.设计原则 3.方案设计 3.1 方案描述 3.2 组网说明 四、产品介绍 1.视频监控综合资源管理平台介绍 2.视频录像服务器和存储 2.1概述 2.2存储设计 …...

【前端基础--2】

选择器优先级 style标签中&#xff1a; .text{color: pink;}div{color: red;}#box{color: skyblue;} body标签中&#xff1a; <div class"text" id"box">猜猜我是什么颜色的</div> 运行结果&#xff1a; 选择器优先级权重&#xff1a; id选…...

【GitHub项目推荐--提取文字】【转载】

提取视频中的字幕 这个开源项目是提取视频中字幕的开源项目&#xff0c;提取视频中的关键帧&#xff0c;检测视频帧中文本的所在位置&#xff0c;识别视频帧中文本的内容。 不知道大家有没有做笔记的习惯&#xff0c;这个开源项目就很方便的把你一个视频中的字幕提取出来&…...

WebSocket与Shiro认证信息传递的实现与安全性探讨

在现代Web应用程序中&#xff0c;WebSocket已经成为实时双向通信的重要组件。而Shiro作为一个强大的Java安全框架&#xff0c;用于处理身份验证、授权和会话管理。本文将探讨如何通过WebSocket与Shiro集成&#xff0c;实现认证信息的传递&#xff0c;并关注在这一过程中确保安全…...

QT 实现自动生成小学两位数加减法算式

小学生加减法训练 QT实现–自动生成两位数加减法算式&#xff0c;并输出txt文件 可以copy到word文件&#xff0c;设置适当字体大小和行间距&#xff0c;带回家给娃做做题 void MainWindow::test(int answerMax, int count) {// 创建一个随机数生成器QRandomGenerator *gener…...

小程序学习-20

建议每次构建npm之前都先删除miniprogram_npm...

面试题-【消息队列】

消息队列 问题1 如何进行消息队列的技术选型优点解耦 &#xff08;pub/sub模型&#xff09;异步&#xff08;异步接口性能优化&#xff09;削峰 使用消息队列的缺点几种消息队列的特性 问题2 引入消息队列之后该如何保证其高可用性RabbitMQ的高可用kafka高可用 问题3 在消息队列…...

【江科大】STM32:I2C通信外设(硬件)

在将2C通信外设之前&#xff0c;我们先捋一捋&#xff0c;串口的相关特点来和I2C进行一个对北比。 首先&#xff1a; 1,大部分单片机&#xff0c;设计的PCB板均带有串口通信的引脚&#xff08;也就是通信基本都借助硬件收发器来实现&#xff09; 2.对于串口的异步时序&#xff…...

【机器学习300问】15、什么是逻辑回归模型?

一、逻辑回归模型是为了解决什么问题&#xff1f; 逻辑回归&#xff08;Logistic Regression&#xff09;是一种广义线性回归分析模型&#xff0c;尤其适用于解决二分类问题&#xff08;输出为两个类别&#xff09;。 &#xff08;1&#xff09;二分类举例 邮件过滤&#xff…...

C#调用C动态链接库

前言 已经没写过博客好久了&#xff0c;上一篇还是1年半前写的LTE Gold序列学习笔记&#xff0c;因为工作是做通信协议的&#xff0c;然后因为大学时没好好学习专业课&#xff0c;现在理论还不扎实&#xff0c;不敢瞎写&#xff1b; 因为工作原因&#xff0c;经常需要分析一些字…...

前端实现转盘抽奖 - 使用 lucky-canvas 插件

目录 需求背景需求实现实现过程图片示意实现代码 页面效果lucky-canvas 插件官方文档 需求背景 要求实现转盘转动抽奖的功能&#xff1a; 只有正确率大于等于 80% 才可以进行抽奖&#xff1b;“谢谢参与”概率为 90%&#xff0c;“恭喜中奖”概率为 10%&#xff1b; 需求实现 实…...

2024.1.23力扣每日一题——最长交替子数组

2024.1.23 题目来源我的题解方法一 枚举 题目来源 力扣每日一题&#xff1b;题序&#xff1a;2765 我的题解 方法一 枚举 每次都以两个相邻作为满足要求的循环数据&#xff0c;并且以一个布尔变量控制循环的位置 时间复杂度&#xff1a;O(n) 空间复杂度&#xff1a;O(1) pub…...

C语言王道练习题第七周两题

第一题 Description 输入一个学生的学号&#xff0c;姓名&#xff0c;性别&#xff0c;用结构体存储&#xff0c;通过 scanf 读取后&#xff0c;然后再 通过 printf 打印输出 Input 学号&#xff0c;姓名&#xff0c;性别&#xff0c;例如输入 101 xiongda m Output 输出…...

某马头条——day11+day12

实时计算和定时计算 流式计算 kafkaStream 入门案例 导入依赖 <dependency><groupId>org.apache.kafka</groupId><artifactId>kafka-streams</artifactId><exclusions><exclusion><artifactId>connect-json</artifactId&…...

springboot实现aop

目录 AOP(术语)引入依赖实现步骤测试验证感谢阅读 AOP(术语) 连接点 类里面哪些方法可以增强&#xff0c;这些点被称为连接点 切入点 实际被真正增强的方法 通知&#xff08;增强&#xff09; 实际增强的逻辑部分称为通知&#xff08;增强&#xff09; 通知&#xff08;增强&…...

从CANoe实战出发:深度解析UDS网络层诊断中的流控帧(FC)与时间参数STmin

从CANoe实战解析UDS流控帧&#xff1a;FC与STmin参数调优指南 在汽车电子测试领域&#xff0c;UDS诊断协议的网络层流控机制直接影响着ECU通信的可靠性与效率。当测试工程师在CANoe环境中模拟诊断会话时&#xff0c;经常会遇到因流控帧参数配置不当导致的报文丢失、响应超时等问…...

VMDE终极指南:如何快速检测虚拟机环境的完整教程

VMDE终极指南&#xff1a;如何快速检测虚拟机环境的完整教程 【免费下载链接】VMDE Source from VMDE paper, adapted to 2015 项目地址: https://gitcode.com/gh_mirrors/vm/VMDE VMDE&#xff08;Virtual Machine Detection Enhanced&#xff09;是一款强大的开源虚拟…...

Fabric 结合IPFS 链码示例

购买专栏前请认真阅读:《Fabric项目学习笔记》专栏介绍 package mainimport ("bytes""encoding/json""fmt""time""github.com/hyperledger/fabric/core/chaincode/shim"sc "github.com/hyperledger/fabric/protos/pee…...

Diem存储协议终极指南:如何构建高性能分布式文件存储系统

Diem存储协议终极指南&#xff1a;如何构建高性能分布式文件存储系统 【免费下载链接】diem Diem’s mission is to build a trusted and innovative financial network that empowers people and businesses around the world. 项目地址: https://gitcode.com/gh_mirrors/di…...

AI 驱动单元测试生成:智能优先级与自动化验证实践

1. 项目概述如果你和我一样&#xff0c;长期在维护一个中大型的 TypeScript 项目&#xff0c;那么“补单元测试”这件事&#xff0c;大概率是你技术债清单上那个永远在滚动、却很少被真正划掉的任务。手动写测试枯燥耗时&#xff0c;尤其是面对那些遗留的、逻辑复杂的业务函数时…...

从文献检索到论文写作:Perplexity与Zotero构建AI-native科研流水线(实测单篇综述效率提升3.8倍)

更多请点击&#xff1a; https://intelliparadigm.com 第一章&#xff1a;从文献检索到论文写作&#xff1a;Perplexity与Zotero构建AI-native科研流水线&#xff08;实测单篇综述效率提升3.8倍&#xff09; 在AI-native科研范式下&#xff0c;传统文献管理与写作流程正被重构…...

ARM MPMC内存控制器架构与优化策略

1. ARM MPMC内存控制器架构解析在嵌入式系统设计中&#xff0c;内存控制器作为处理器与存储设备之间的桥梁&#xff0c;其性能直接影响整个系统的运行效率。ARM PrimeCell多端口内存控制器(MPMC)是一种高度可配置的IP核&#xff0c;支持与多种类型存储设备的连接&#xff0c;包…...

Ante语言:精化类型与生命周期推断在系统编程中的实践探索

1. 项目概述&#xff1a;Ante&#xff0c;一个探索系统编程新范式的语言 最近在关注系统级编程语言的发展&#xff0c;发现了一个很有意思的项目&#xff1a;Ante。这并非一个成熟的生产级工具&#xff0c;而更像是一个充满野心的“实验室”。它的核心目标&#xff0c;是尝试将…...

LangChain+FAISS 向量数据库搭建轻量化 RAG 应用

&#x1f4dd; 本章学习目标&#xff1a;本章聚焦企业轻量化落地&#xff0c;帮助读者快速掌握基于 LangChainFAISS 的私有化 RAG 开发流程。通过本章学习&#xff0c;你将从零搭建一套无需 GPU、无外网依赖、纯本地运行、代码极简、可直接上线的轻量化 RAG 应用。 一、引言&a…...

Pega Helm Charts:Kubernetes上自动化部署Pega平台的完整指南

1. 项目概述与核心价值如果你正在或即将在Kubernetes上部署Pega Platform&#xff0c;那么pegasystems/pega-helm-charts这个项目绝对是你绕不开的“官方说明书”和“自动化工具箱”。简单来说&#xff0c;这是Pega官方维护的一套Helm Chart&#xff0c;专门用于将Pega Platfor…...