【PyTorch】线性回归
文章目录
- 1. 模型与代码实现
- 2. Q&A
1. 模型与代码实现
-
模型
y ^ = w 1 x 1 + . . . + w d x d + b = w ⊤ x + b . \hat{y} = w_1 x_1 + ... + w_d x_d + b = \mathbf{w}^\top \mathbf{x} + b. y^=w1x1+...+wdxd+b=w⊤x+b. -
代码实现
import torch
from torch import nn
from torch.utils import data
from d2l import torch as d2l# 全局参数设置
batch_size = 10
num_epochs = 3
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 生成数据集
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)
features, labels = features.to(device), labels.to(device)# 加载数据集
dataset = data.TensorDataset(features, labels)
dataloader = data.DataLoader(dataset, batch_size, shuffle=True)# 创建神经网络
net = nn.Linear(2, 1).to(device)# 初始化模型参数
nn.init.normal_(net.weight, mean=0, std=0.01)
nn.init.constant_(net.bias, val=0)# 设置损失函数
criterion = nn.MSELoss()# 设置优化器
optimizer = torch.optim.SGD(net.parameters(), lr=0.03)# 训练模型
for epoch in range(num_epochs):for X, y in dataloader:X, y = X.to(device), y.to(device)loss = criterion(net(X) ,y)optimizer.zero_grad()loss.backward()optimizer.step()loss = criterion(net(features), labels)print(f'epoch {epoch + 1}, loss {loss:f}')# 评估训练结果
w = net.weight.data.cpu()
print('w的估计误差:', true_w - w.reshape(true_w.shape))
b = net.bias.data.cpu()
print('b的估计误差:', true_b - b)
输出结果
epoch 1, loss 0.000211
epoch 2, loss 0.000099
epoch 3, loss 0.000099
w的估计误差: tensor([ 4.2558e-04, -5.3167e-05])
b的估计误差: tensor([0.0005])
2. Q&A
- 如何安装d2l模块?
pip install d2l - 为什么模型参数初始化时要将偏置设为0?
在机器学习中,我们通常会使用非常小的学习率来进行梯度下降,以防止在更新参数时发生剧烈的波动。如果偏置项的初始值不为0,那么在开始训练模型时,可能会因为偏置的影响,导致权重更新的方向出现偏差。 - 为什么选择均方损失作为线性回归模型的损失函数?
因为在高斯噪声的假设下,最小化均方误差 ⇔ \Leftrightarrow ⇔对给定 x \mathbf{x} x观测 y y y的极大似然估计 ⇔ w , b \Leftrightarrow \mathbf{w},b ⇔w,b取最优值。详细推导。均方误差函数如下: M S E = 1 n ∑ i = 1 n ( y i − y i ^ ) 2 \mathbf{MSE}=\frac{1}{n}\sum_{i=1}^{n}{(y_i-\hat{y_i})^2} MSE=n1i=1∑n(yi−yi^)2其中 y i y_i yi是真实数据, y i ^ \hat{y_i} yi^是拟合数据。 - 如何理解模型训练过程?
for epoch in range(num_epochs):for X, y in dataloader:X, y = X.to(device), y.to(device)# 计算网络输出结果与预期结果的误差loss = criterion(net(X) ,y)# 清空参数梯度缓存值,否则梯度会与上一个batch的数据相关optimizer.zero_grad()# 误差反向传播计算参数梯度值loss.backward()# 更新模型参数optimizer.step()# 计算在整个训练集上的误差loss = criterion(net(features), labels)print(f'epoch {epoch + 1}, loss {loss:f}')
相关文章:
【PyTorch】线性回归
文章目录 1. 模型与代码实现2. Q&A 1. 模型与代码实现 模型 y ^ w 1 x 1 . . . w d x d b w ⊤ x b . \hat{y} w_1 x_1 ... w_d x_d b \mathbf{w}^\top \mathbf{x} b. y^w1x1...wdxdbw⊤xb. 代码实现 import torch from torch import nn from to…...
硝烟弥漫的科技战场——GPT之战
没想到2023年的双11之后,还能看到如此多的科技圈大佬针对GPT提出火药味十足的讨论和极具戏剧性的表演。 历史回顾: 11月6日,OpenAI发布会:GPT-4 Turbo模型、GPT应用商店、开源Whisper-large-v3等;11月17日࿰…...
re:Invent 构建未来:云计算生成式 AI 诞生科技新局面
文章目录 前言什么是云计算云计算类型亚马逊云科技云计算最多的功能最大的客户和合作伙伴社区最安全最快的创新速度最成熟的运营专业能力 什么是生成式 AI如何使用生成式 AI后记 前言 在科技发展的滚滚浪潮中,我们见证了云计算的崛起和生成式 AI 的突破,…...
oneApi实现并⾏排序算法
零、OneApi简介 oneAPI是由英特尔推出的一个开放、统一的编程模型和工具集合,旨在简化跨不同硬件架构的并行计算。oneAPI的目标是提供一个统一的编程模型,使开发人员能够使用相同的代码在不同类型的硬件上进行并行计算,包括CPU、GPU、FPGA和…...
语音芯片的BUSY状态指示功能特征:提升用户体验与系统稳定性的关键
在电子产品的音频系统中,语音芯片扮演着至关重要的角色。为了保证音频的流畅播放和功能的正常运行,语音芯片的各种状态指示功能变得尤为重要。其中,BUSY状态指示功能是语音芯片中的一项关键特征,它对于提升用户体验和系统稳定性具…...
Leetcode2661. 找出叠涂元素
Every day a Leetcode 题目来源:2661. 找出叠涂元素 解法1:哈希 题目很绕,理解题意后就很简单。 由于矩阵 mat 中每一个元素都不同,并且都在数组 arr 中,所以首先我们用一个哈希表 hash 来存储 mat 中每一个元素的…...
免费最新6款热门SEO优化排名工具
网站的存在感对于业务和品牌的成功至关重要。在众多网站推广方法中,搜索引擎优化(SEO)是提高网站可见性的关键。而SEO的核心之一就是关键词排名。为了更好地帮助您优化网站。 SEO关键词排名工具 在如今信息过载的互联网时代,用户…...
绝地求生在steam叫什么?
绝地求生在Steam的全名是《PlayerUnknowns Battlegrounds》,简称为PUBG。作为一款风靡全球的多人在线游戏,PUBG于2017年3月23日正式上线Steam平台,并迅速成为一部热门游戏。 PUBG以生存竞技为核心玩法,玩家将被投放到一个辽阔的荒…...
Elasticsearch:什么是大语言模型(LLM)?
大语言模型定义 大语言模型 (LLM) 是一种深度学习算法,可以执行各种自然语言处理 (natural language processing - NLP) 任务。 大型语言模型使用 Transformer 模型,并使用大量数据集进行训练 —— 因此规模很大。 这使他们能够识别、翻译、预测或生成文…...
Kubernetes1.27容器化部署Prometheus
Kubernetes1.27容器化部署Prometheus GitHub链接根据自己的k8s版本选择对应的版本修改镜像地址部署命令对Etcd集群进行监控(云原生监控)创建Etcd Service创建Etcd证书的Secret创建Etcd ServiceMonitorgrafana导入模板成功截图 对MySQL进行监控࿰…...
fasterxml 注解组装实体
使用 FasterXML Jackson 的注解 JsonTypeInfo 和 JsonSubTypes 可以实现多态类型的处理。在你的 User 类上,你可以添加这些注解来指示 Jackson 如何处理多态类型。 以下是使用 JsonTypeInfo 和 JsonSubTypes 注解的 User 类的修改: import com.fasterx…...
自写一个函数将js对象转为Ts的Interface接口
如今的前端开发typescript 已经成为一项必不可以少的技能了,但是频繁的定义Interface接口会给我带来许多工作量,我想了想如何来减少这些非必要且费时的工作量呢,于是决定写一个函数,将对象放进它自动帮我们转换成Interface接口&am…...
【数据结构】拆分详解 - 二叉树的链式存储结构
文章目录 一、前置说明二、二叉树的遍历 1. 前序、中序以及后序遍历 1.1 前序遍历 1.2 中序遍历 1.3 后序遍历 2. 层序遍历 三、常见接口实现 0. 递归中的分治思想 1. 查找与节点个数 1.1 节点个数 1.2 叶子节点个数 1.3 第k层节…...
Laravel修改默认的auth模块为md5(password+salt)验证
首先声明:这里只是作为一个记录,实行拿来主义,懒得去记录那些分析源码的过程,不喜勿喷,可直接划走。 第一步:创建文件夹:app/Helpers/Hasher; 第二步:创建文件: app/Help…...
OpenStack-train版安装之安装Keystone(认证服务)、Glance(镜像服务)、Placement
安装Keystone(认证服务)、Glance(镜像服务)、Placement 安装Keystone(认证服务)安装Glance(镜像服务)安装Placement 安装Keystone(认证服务) 数据库创建、创…...
【九日集训】第九天:简单递归
递归就是自己调用自己,例如斐波那契数列就是可以用简单递归来实现。 第一题 172. 阶乘后的零 https://leetcode.cn/problems/factorial-trailing-zeroes/description/ 这一题纯粹考数学推理能力,我这种菜鸡看了好久都没有懂。 大概是这样的思路&#x…...
Prime 1.0
信息收集 存活主机探测 arp-scan -l 或者利用nmap nmap -sT --min-rate 10000 192.168.217.133 -oA ./hosts 可以看到存活主机IP地址为:192.168.217.134 端口探测 nmap -sT -p- 192.168.217.134 -oA ./ports UDP端口探测 详细服务等信息探测 开放端口22&#x…...
Java 如何正确比较两个浮点数
看下面这段代码,将 d1 和 d2 两个浮点数进行比较,输出的结果会是什么? double d1 .1 * 3; double d2 .3; System.out.println(d1 d2);按照正常逻辑来看,d1 经过计算之后的结果应该是 0.3,最后打印的结果应该是 tru…...
Qt 如何操作SQLite3数据库?数据库创建和表格的增删改查?
# 前言 项目源码下载 https://gitcode.com/m0_45463480/QSQLite3/tree/main # 第一步 项目配置 平台:windows10 Qt版本:Qt 5.14.2 在.pro添加 QT += sql 需要的头文件 #include <QSqlDatabase>#include <QSqlError>#include <QSqlQuery>#include &…...
【Hadoop】分布式文件系统 HDFS
目录 一、介绍二、HDFS设计原理2.1 HDFS 架构2.2 数据复制复制的实现原理 三、HDFS的特点四、图解HDFS存储原理1. 写过程2. 读过程3. HDFS故障类型和其检测方法故障类型和其检测方法读写故障的处理DataNode 故障处理副本布局策略 一、介绍 HDFS (Hadoop Distribute…...
华为云AI开发平台ModelArts
华为云ModelArts:重塑AI开发流程的“智能引擎”与“创新加速器”! 在人工智能浪潮席卷全球的2025年,企业拥抱AI的意愿空前高涨,但技术门槛高、流程复杂、资源投入巨大的现实,却让许多创新构想止步于实验室。数据科学家…...
conda相比python好处
Conda 作为 Python 的环境和包管理工具,相比原生 Python 生态(如 pip 虚拟环境)有许多独特优势,尤其在多项目管理、依赖处理和跨平台兼容性等方面表现更优。以下是 Conda 的核心好处: 一、一站式环境管理:…...
手游刚开服就被攻击怎么办?如何防御DDoS?
开服初期是手游最脆弱的阶段,极易成为DDoS攻击的目标。一旦遭遇攻击,可能导致服务器瘫痪、玩家流失,甚至造成巨大经济损失。本文为开发者提供一套简洁有效的应急与防御方案,帮助快速应对并构建长期防护体系。 一、遭遇攻击的紧急应…...
Unity3D中Gfx.WaitForPresent优化方案
前言 在Unity中,Gfx.WaitForPresent占用CPU过高通常表示主线程在等待GPU完成渲染(即CPU被阻塞),这表明存在GPU瓶颈或垂直同步/帧率设置问题。以下是系统的优化方案: 对惹,这里有一个游戏开发交流小组&…...
Leetcode 3577. Count the Number of Computer Unlocking Permutations
Leetcode 3577. Count the Number of Computer Unlocking Permutations 1. 解题思路2. 代码实现 题目链接:3577. Count the Number of Computer Unlocking Permutations 1. 解题思路 这一题其实就是一个脑筋急转弯,要想要能够将所有的电脑解锁&#x…...
在 Nginx Stream 层“改写”MQTT ngx_stream_mqtt_filter_module
1、为什么要修改 CONNECT 报文? 多租户隔离:自动为接入设备追加租户前缀,后端按 ClientID 拆分队列。零代码鉴权:将入站用户名替换为 OAuth Access-Token,后端 Broker 统一校验。灰度发布:根据 IP/地理位写…...
使用van-uploader 的UI组件,结合vue2如何实现图片上传组件的封装
以下是基于 vant-ui(适配 Vue2 版本 )实现截图中照片上传预览、删除功能,并封装成可复用组件的完整代码,包含样式和逻辑实现,可直接在 Vue2 项目中使用: 1. 封装的图片上传组件 ImageUploader.vue <te…...
【HarmonyOS 5 开发速记】如何获取用户信息(头像/昵称/手机号)
1.获取 authorizationCode: 2.利用 authorizationCode 获取 accessToken:文档中心 3.获取手机:文档中心 4.获取昵称头像:文档中心 首先创建 request 若要获取手机号,scope必填 phone,permissions 必填 …...
以光量子为例,详解量子获取方式
光量子技术获取量子比特可在室温下进行。该方式有望通过与名为硅光子学(silicon photonics)的光波导(optical waveguide)芯片制造技术和光纤等光通信技术相结合来实现量子计算机。量子力学中,光既是波又是粒子。光子本…...
python报错No module named ‘tensorflow.keras‘
是由于不同版本的tensorflow下的keras所在的路径不同,结合所安装的tensorflow的目录结构修改from语句即可。 原语句: from tensorflow.keras.layers import Conv1D, MaxPooling1D, LSTM, Dense 修改后: from tensorflow.python.keras.lay…...
