当前位置: 首页 > news >正文

【PyTorch】线性回归

文章目录

  • 1. 模型与代码实现
  • 2. Q&A

1. 模型与代码实现

  • 模型
    y ^ = w 1 x 1 + . . . + w d x d + b = w ⊤ x + b . \hat{y} = w_1 x_1 + ... + w_d x_d + b = \mathbf{w}^\top \mathbf{x} + b. y^=w1x1+...+wdxd+b=wx+b.

  • 代码实现

import torch
from torch import nn
from torch.utils import data
from d2l import torch as d2l# 全局参数设置
batch_size = 10
num_epochs = 3
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 生成数据集
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)
features, labels = features.to(device), labels.to(device)# 加载数据集
dataset = data.TensorDataset(features, labels)
dataloader = data.DataLoader(dataset, batch_size, shuffle=True)# 创建神经网络
net = nn.Linear(2, 1).to(device)# 初始化模型参数
nn.init.normal_(net.weight, mean=0, std=0.01)
nn.init.constant_(net.bias, val=0)# 设置损失函数
criterion = nn.MSELoss()# 设置优化器
optimizer = torch.optim.SGD(net.parameters(), lr=0.03)# 训练模型
for epoch in range(num_epochs):for X, y in dataloader:X, y = X.to(device), y.to(device)loss = criterion(net(X) ,y)optimizer.zero_grad()loss.backward()optimizer.step()loss = criterion(net(features), labels)print(f'epoch {epoch + 1}, loss {loss:f}')# 评估训练结果
w = net.weight.data.cpu()
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net.bias.data.cpu()
print('b的估计误差:', true_b - b)

输出结果

epoch 1, loss 0.000211
epoch 2, loss 0.000099
epoch 3, loss 0.000099
w的估计误差: tensor([ 4.2558e-04, -5.3167e-05])
b的估计误差: tensor([0.0005])

2. Q&A

  • 如何安装d2l模块?
    pip install d2l
    
  • 为什么模型参数初始化时要将偏置设为0?
    在机器学习中,我们通常会使用非常小的学习率来进行梯度下降,以防止在更新参数时发生剧烈的波动。如果偏置项的初始值不为0,那么在开始训练模型时,可能会因为偏置的影响,导致权重更新的方向出现偏差。
  • 为什么选择均方损失作为线性回归模型的损失函数?
    因为在高斯噪声的假设下,最小化均方误差 ⇔ \Leftrightarrow 对给定 x \mathbf{x} x观测 y y y的极大似然估计 ⇔ w , b \Leftrightarrow \mathbf{w},b w,b取最优值。详细推导。均方误差函数如下: M S E = 1 n ∑ i = 1 n ( y i − y i ^ ) 2 \mathbf{MSE}=\frac{1}{n}\sum_{i=1}^{n}{(y_i-\hat{y_i})^2} MSE=n1i=1n(yiyi^)2其中 y i y_i yi是真实数据, y i ^ \hat{y_i} yi^是拟合数据。
  • 如何理解模型训练过程?
    for epoch in range(num_epochs):for X, y in dataloader:X, y = X.to(device), y.to(device)# 计算网络输出结果与预期结果的误差loss = criterion(net(X) ,y)# 清空参数梯度缓存值,否则梯度会与上一个batch的数据相关optimizer.zero_grad()# 误差反向传播计算参数梯度值loss.backward()# 更新模型参数optimizer.step()# 计算在整个训练集上的误差loss = criterion(net(features), labels)print(f'epoch {epoch + 1}, loss {loss:f}')
    

相关文章:

【PyTorch】线性回归

文章目录 1. 模型与代码实现2. Q&A 1. 模型与代码实现 模型 y ^ w 1 x 1 . . . w d x d b w ⊤ x b . \hat{y} w_1 x_1 ... w_d x_d b \mathbf{w}^\top \mathbf{x} b. y^​w1​x1​...wd​xd​bw⊤xb. 代码实现 import torch from torch import nn from to…...

硝烟弥漫的科技战场——GPT之战

没想到2023年的双11之后,还能看到如此多的科技圈大佬针对GPT提出火药味十足的讨论和极具戏剧性的表演。 历史回顾: 11月6日,OpenAI发布会:GPT-4 Turbo模型、GPT应用商店、开源Whisper-large-v3等;11月17日&#xff0…...

re:Invent 构建未来:云计算生成式 AI 诞生科技新局面

文章目录 前言什么是云计算云计算类型亚马逊云科技云计算最多的功能最大的客户和合作伙伴社区最安全最快的创新速度最成熟的运营专业能力 什么是生成式 AI如何使用生成式 AI后记 前言 在科技发展的滚滚浪潮中,我们见证了云计算的崛起和生成式 AI 的突破&#xff0c…...

oneApi实现并⾏排序算法

零、OneApi简介 oneAPI是由英特尔推出的一个开放、统一的编程模型和工具集合,旨在简化跨不同硬件架构的并行计算。oneAPI的目标是提供一个统一的编程模型,使开发人员能够使用相同的代码在不同类型的硬件上进行并行计算,包括CPU、GPU、FPGA和…...

语音芯片的BUSY状态指示功能特征:提升用户体验与系统稳定性的关键

在电子产品的音频系统中,语音芯片扮演着至关重要的角色。为了保证音频的流畅播放和功能的正常运行,语音芯片的各种状态指示功能变得尤为重要。其中,BUSY状态指示功能是语音芯片中的一项关键特征,它对于提升用户体验和系统稳定性具…...

Leetcode2661. 找出叠涂元素

Every day a Leetcode 题目来源:2661. 找出叠涂元素 解法1:哈希 题目很绕,理解题意后就很简单。 由于矩阵 mat 中每一个元素都不同,并且都在数组 arr 中,所以首先我们用一个哈希表 hash 来存储 mat 中每一个元素的…...

免费最新6款热门SEO优化排名工具

网站的存在感对于业务和品牌的成功至关重要。在众多网站推广方法中,搜索引擎优化(SEO)是提高网站可见性的关键。而SEO的核心之一就是关键词排名。为了更好地帮助您优化网站。 SEO关键词排名工具 在如今信息过载的互联网时代,用户…...

绝地求生在steam叫什么?

绝地求生在Steam的全名是《PlayerUnknowns Battlegrounds》,简称为PUBG。作为一款风靡全球的多人在线游戏,PUBG于2017年3月23日正式上线Steam平台,并迅速成为一部热门游戏。 PUBG以生存竞技为核心玩法,玩家将被投放到一个辽阔的荒…...

Elasticsearch:什么是大语言模型(LLM)?

大语言模型定义 大语言模型 (LLM) 是一种深度学习算法,可以执行各种自然语言处理 (natural language processing - NLP) 任务。 大型语言模型使用 Transformer 模型,并使用大量数据集进行训练 —— 因此规模很大。 这使他们能够识别、翻译、预测或生成文…...

Kubernetes1.27容器化部署Prometheus

Kubernetes1.27容器化部署Prometheus GitHub链接根据自己的k8s版本选择对应的版本修改镜像地址部署命令对Etcd集群进行监控(云原生监控)创建Etcd Service创建Etcd证书的Secret创建Etcd ServiceMonitorgrafana导入模板成功截图 对MySQL进行监控&#xff0…...

fasterxml 注解组装实体

使用 FasterXML Jackson 的注解 JsonTypeInfo 和 JsonSubTypes 可以实现多态类型的处理。在你的 User 类上,你可以添加这些注解来指示 Jackson 如何处理多态类型。 以下是使用 JsonTypeInfo 和 JsonSubTypes 注解的 User 类的修改: import com.fasterx…...

自写一个函数将js对象转为Ts的Interface接口

如今的前端开发typescript 已经成为一项必不可以少的技能了,但是频繁的定义Interface接口会给我带来许多工作量,我想了想如何来减少这些非必要且费时的工作量呢,于是决定写一个函数,将对象放进它自动帮我们转换成Interface接口&am…...

【数据结构】拆分详解 - 二叉树的链式存储结构

文章目录 一、前置说明二、二叉树的遍历  1. 前序、中序以及后序遍历   1.1 前序遍历   1.2 中序遍历   1.3 后序遍历 2. 层序遍历 三、常见接口实现  0. 递归中的分治思想  1. 查找与节点个数   1.1 节点个数   1.2 叶子节点个数   1.3 第k层节…...

Laravel修改默认的auth模块为md5(password+salt)验证

首先声明:这里只是作为一个记录,实行拿来主义,懒得去记录那些分析源码的过程,不喜勿喷,可直接划走。 第一步:创建文件夹:app/Helpers/Hasher; 第二步:创建文件: app/Help…...

OpenStack-train版安装之安装Keystone(认证服务)、Glance(镜像服务)、Placement

安装Keystone(认证服务)、Glance(镜像服务)、Placement 安装Keystone(认证服务)安装Glance(镜像服务)安装Placement 安装Keystone(认证服务) 数据库创建、创…...

【九日集训】第九天:简单递归

递归就是自己调用自己,例如斐波那契数列就是可以用简单递归来实现。 第一题 172. 阶乘后的零 https://leetcode.cn/problems/factorial-trailing-zeroes/description/ 这一题纯粹考数学推理能力,我这种菜鸡看了好久都没有懂。 大概是这样的思路&#x…...

Prime 1.0

信息收集 存活主机探测 arp-scan -l 或者利用nmap nmap -sT --min-rate 10000 192.168.217.133 -oA ./hosts 可以看到存活主机IP地址为:192.168.217.134 端口探测 nmap -sT -p- 192.168.217.134 -oA ./ports UDP端口探测 详细服务等信息探测 开放端口22&#x…...

Java 如何正确比较两个浮点数

看下面这段代码,将 d1 和 d2 两个浮点数进行比较,输出的结果会是什么? double d1 .1 * 3; double d2 .3; System.out.println(d1 d2);按照正常逻辑来看,d1 经过计算之后的结果应该是 0.3,最后打印的结果应该是 tru…...

Qt 如何操作SQLite3数据库?数据库创建和表格的增删改查?

# 前言 项目源码下载 https://gitcode.com/m0_45463480/QSQLite3/tree/main # 第一步 项目配置 平台:windows10 Qt版本:Qt 5.14.2 在.pro添加 QT += sql 需要的头文件 #include <QSqlDatabase>#include <QSqlError>#include <QSqlQuery>#include &…...

【Hadoop】分布式文件系统 HDFS

目录 一、介绍二、HDFS设计原理2.1 HDFS 架构2.2 数据复制复制的实现原理 三、HDFS的特点四、图解HDFS存储原理1. 写过程2. 读过程3. HDFS故障类型和其检测方法故障类型和其检测方法读写故障的处理DataNode 故障处理副本布局策略 一、介绍 HDFS &#xff08;Hadoop Distribute…...

Linux应用开发之网络套接字编程(实例篇)

服务端与客户端单连接 服务端代码 #include <sys/socket.h> #include <sys/types.h> #include <netinet/in.h> #include <stdio.h> #include <stdlib.h> #include <string.h> #include <arpa/inet.h> #include <pthread.h> …...

基于距离变化能量开销动态调整的WSN低功耗拓扑控制开销算法matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.算法仿真参数 5.算法理论概述 6.参考文献 7.完整程序 1.程序功能描述 通过动态调整节点通信的能量开销&#xff0c;平衡网络负载&#xff0c;延长WSN生命周期。具体通过建立基于距离的能量消耗模型&am…...

HTML前端开发:JavaScript 常用事件详解

作为前端开发的核心&#xff0c;JavaScript 事件是用户与网页交互的基础。以下是常见事件的详细说明和用法示例&#xff1a; 1. onclick - 点击事件 当元素被单击时触发&#xff08;左键点击&#xff09; button.onclick function() {alert("按钮被点击了&#xff01;&…...

【论文阅读28】-CNN-BiLSTM-Attention-(2024)

本文把滑坡位移序列拆开、筛优质因子&#xff0c;再用 CNN-BiLSTM-Attention 来动态预测每个子序列&#xff0c;最后重构出总位移&#xff0c;预测效果超越传统模型。 文章目录 1 引言2 方法2.1 位移时间序列加性模型2.2 变分模态分解 (VMD) 具体步骤2.3.1 样本熵&#xff08;S…...

群晖NAS如何在虚拟机创建飞牛NAS

套件中心下载安装Virtual Machine Manager 创建虚拟机 配置虚拟机 飞牛官网下载 https://iso.liveupdate.fnnas.com/x86_64/trim/fnos-0.9.2-863.iso 群晖NAS如何在虚拟机创建飞牛NAS - 个人信息分享...

嵌入式学习之系统编程(九)OSI模型、TCP/IP模型、UDP协议网络相关编程(6.3)

目录 一、网络编程--OSI模型 二、网络编程--TCP/IP模型 三、网络接口 四、UDP网络相关编程及主要函数 ​编辑​编辑 UDP的特征 socke函数 bind函数 recvfrom函数&#xff08;接收函数&#xff09; sendto函数&#xff08;发送函数&#xff09; 五、网络编程之 UDP 用…...

解析两阶段提交与三阶段提交的核心差异及MySQL实现方案

引言 在分布式系统的事务处理中&#xff0c;如何保障跨节点数据操作的一致性始终是核心挑战。经典的两阶段提交协议&#xff08;2PC&#xff09;通过准备阶段与提交阶段的协调机制&#xff0c;以同步决策模式确保事务原子性。其改进版本三阶段提交协议&#xff08;3PC&#xf…...

Java并发编程实战 Day 11:并发设计模式

【Java并发编程实战 Day 11】并发设计模式 开篇 这是"Java并发编程实战"系列的第11天&#xff0c;今天我们聚焦于并发设计模式。并发设计模式是解决多线程环境下常见问题的经典解决方案&#xff0c;它们不仅提供了优雅的设计思路&#xff0c;还能显著提升系统的性能…...

echarts使用graphic强行给图增加一个边框(边框根据自己的图形大小设置)- 适用于无法使用dom的样式

pdf-lib https://blog.csdn.net/Shi_haoliu/article/details/148157624?spm1001.2014.3001.5501 为了完成在pdf中导出echarts图&#xff0c;如果边框加在dom上面&#xff0c;pdf-lib导出svg的时候并不会导出边框&#xff0c;所以只能在echarts图上面加边框 grid的边框是在图里…...

游戏开发中常见的战斗数值英文缩写对照表

游戏开发中常见的战斗数值英文缩写对照表 基础属性&#xff08;Basic Attributes&#xff09; 缩写英文全称中文释义常见使用场景HPHit Points / Health Points生命值角色生存状态MPMana Points / Magic Points魔法值技能释放资源SPStamina Points体力值动作消耗资源APAction…...