当前位置: 首页 > news >正文

pytorch学习---实现线性回归初体验

假设我们的基础模型就是y = wx + b,其中w和b均为参数,我们使用y = 3x+0.8来构造数据x、y,所以最后通过模型应该能够得出w和b应该分别接近3和0.8。

步骤如下:

  1. 准备数据
  2. 计算预测值
  3. 计算损失,把参数的梯度置为0,进行反向传播
  4. 更新参数

方式一

该方式没有用pytorch的模型api,手动实现

import torch,numpy
import matplotlib.pyplot as plt# 1、准备数据
learning_rate = 0.01
#y=3x + 0.8
x = torch.rand([500,1])
y_true= x*3 + 0.8# 2、通过模型计算y_predict
w = torch.rand([1,1],requires_grad=True)
b = torch.tensor(0,requires_grad=True,dtype=torch.float32)# 3、通过循环,反向传播,更新参数
for i in range(500):# 4、计算lossy_predict = torch.matmul(x,w) + bloss = (y_true-y_predict).pow(2).mean()# 每次循环判断是否存在梯度,防止累加if w.grad is not None:w.grad.data.zero_()if b.grad is not None:b.grad.data.zero_()# 反向传播loss.backward()w.data = w.data - learning_rate*w.gradb.data = b.data - learning_rate*b.grad# 每50次输出一下结果if i%50==0:print("w,b,loss",w.item(),b.item(),loss.item())#可视化显示
plt.figure(figsize=(20,8))
plt.scatter(x.numpy().reshape(-1),y_true.numpy().reshape(-1))
y_predict = torch.matmul(x,w) + b
plt.plot(x.numpy().reshape(-1),y_predict.detach().numpy().reshape(-1),c="r")
plt.show()

循环500次的效果
在这里插入图片描述
循环2000次的结果
在这里插入图片描述

方式二

方式一的方式虽然已经购简便了,但是还是有些许繁琐,所以我们可以采用pytorchapi来实现。
nn.Moduletorch.nn提供的一个类,是pytorch中我们自定义网络的一个基类,在这个类中定义了很多有用的方法,让我们在继承这个类定义网络的时候非常简单。
当我们自定义网络的时候,有两个方法需要特别注意:
1.__init__需要调用super方法,继承父类的属性和方法
2. forward方法必须实现,用来定义我们的网络的向前计算的过程用前面的y = wx+b的模型举例如下:

#定义模型
from torch import nn
class Lr(nn.Module): #继承nn.Moduledef __init__(self):super(Lr, self).__init__()self.linear = nn.Linear(1,1)def forward(self,x):out = self.linear(x)return out

全部代码如下:

#!/usr/bin/env python 
# -*- coding:utf-8 -*-
import torch
from torch import nn
from torch import optim
import numpy as np
from matplotlib import pyplot as plt#1、定义数据
x = torch.rand([50,1])
y = x*3 + 0.8#定义模型
class Lr(nn.Module): #继承nn.Moduledef __init__(self):super(Lr, self).__init__()self.linear = nn.Linear(1,1)def forward(self,x):out = self.linear(x)return out
#2、实例化模型、loss函数以及优化器
model = Lr()
criterion = nn.MSELoss()   #损失函数
optimizer = optim.SGD(model.parameters(),lr=1e-3) #优化器#3、训练模型
for i in range(3000):out = model(x)# 获取预测值loss = criterion(y,out) #计算损失optimizer.zero_grad() #梯度归零loss.backward() #计算梯度optimizer.step() #更新梯度if(i+1) % 20 ==0:print('Epoch[{}/{}],loss:{:.6f}'.format(i,500,loss.data))#4、模型评估
model.eval() #设置模型为评估模式,即预测模式
predict = model(x)
predict = predict.data.numpy()
plt.scatter(x.data.numpy(),y.data.numpy(),c="r")
plt.plot(x.data.numpy(),predict)
plt.show()

注意:

model.eval()表示设置模型为评估模式,即预测模式

model.train(mode=True) 表示设置模型为训练模式

在当前的线性回归中,上述并无区别

但是在其他的一些模型中,训练的参数和预测的参数会不相同,到时候就需要具体告诉程序我们是在进行训练还是预测,比如模型中存在DropoutBatchNorm的时候

循环2000次的结果:
在这里插入图片描述
循环30000次的结果:
在这里插入图片描述

相关文章:

pytorch学习---实现线性回归初体验

假设我们的基础模型就是y wx b,其中w和b均为参数,我们使用y 3x0.8来构造数据x、y,所以最后通过模型应该能够得出w和b应该分别接近3和0.8。 步骤如下: 准备数据计算预测值计算损失,把参数的梯度置为0,进行反向传播…...

别再乱写git commit了

B站|公众号:啥都会一点的研究生 写在前面 在很长的一段时间中,使用git commit都是随心所欲,log肥肠简洁,随着代码的迭代,当时有多偷懒,返过头查看git日志就有多懊悔,就和写代码不写doc string…...

八大排序(一)冒泡排序,选择排序,插入排序,希尔排序

一、冒泡排序 冒泡排序的原理是:从左到右,相邻元素进行比较。每次比较一轮,就会找到序列中最大的一个或最小的一个。这个数就会从序列的最右边冒出来。 以从小到大排序为例,第一轮比较后,所有数中最大的那个数就会浮…...

泊松分布简要介绍

泊松分布是一种常见的离散概率分布,它用于描述某个时间段或区域内随机事件发生的次数。它得名于法国数学家西蒙丹尼泊松。 泊松分布的概率质量函数表示某个时间段或区域内事件发生次数的概率。如果随机变量 X 服从泊松分布,记作 X ~ Poisson(λ)&#x…...

C语言每日一题(10):无人生还

文章主题:无人生还🔥所属专栏:C语言每日一题📗作者简介:每天不定时更新C语言的小白一枚,记录分享自己每天的所思所想😄🎶个人主页:[₽]的个人主页🏄&#x1f…...

VSCode开发go手记

断点调试: 安装delve(windows): go get -u github.com/go-delve/delve/cmd/dlv 设置 launch.json 配置文件: ctrlshiftp 输入 Debug: Open launch.json 打开 launch.json 文件,如果第一次打开,会新建一…...

怎么选择AI伪原创工具-AI伪原创工具有哪些

在数字时代,创作和发布内容已经成为了一种不可或缺的活动。不论您是个人博主、企业家还是网站管理员,都会面临一个共同的挑战:如何在互联网上脱颖而出,吸引更多的读者和访客。而正是在这个背景下,AI伪原创工具逐渐崭露…...

【块状链表C++】文本编辑器(指针中 引用 的使用)

》》》算法竞赛 /*** file * author jUicE_g2R(qq:3406291309)————彬(bin-必应)* 一个某双流一大学通信与信息专业大二在读 * * brief 一直在竞赛算法学习的路上* * copyright 2023.9* COPYRIGHT 原创技术笔记:转载…...

echarts的Y轴设置为整数

场景:使用echarts,设置Y轴为整数。通过判断Y轴的数值为整数才显示即可 yAxis: [{name: ,type: value,min: 0, // 最小值// max: 200, // 最大值// splitNumber: 5, // 坐标轴的分割段数// interval: 100 / 5, // 强制设置坐标轴分割间隔度(取本Y轴的最大…...

恢复删除文件?不得不掌握的4个方法!

“删除了的文件还可以恢复吗?有个文件我本来以为不重要了,就把它删除了,没想到现在还需要用到!这可怎么办?有没有办法找回来呢?” 重要的文件一旦丢失或误删可能都会对我们的工作和学习造成比较大的影响。怎…...

GitLab CI/CD:.gitlab-ci.yml 文件常用参数小结

文章目录 一、.gitlab-ci.yml 文件作用二、一个简单的.gitlab-ci.yml 文件示例参考 一、.gitlab-ci.yml 文件作用 可以定义跑CI时想要运行的命令或脚本 可以定义job之间的依赖和缓存 可以执行程序部署并定义部署位置 可以定义想要包含的其他配置文件和模版 二、一个简单的.gi…...

MySQL学习笔记9

MySQL数据表中的数据类型: 在考虑数据类型、长度、标度和精度时,一定要仔细地进行短期和长远的规划,另外,公司制度和希望用户用什么方式访问数据也是要考虑的因素。开发人员应该了解数据的本质,以及数据在数据库里是如…...

从零学习开发一个RISC-V操作系统(三)丨嵌入式操作系统开发的常用概念和工具

本篇文章的内容 一、嵌入式操作习系统开发的常用概念和工具1.1 本地编译和交叉编译1.2 调试器GDB(The GNU Project Debugger)1.3 QEMU模拟器1.4 项目构造工具Make 本系列是博主参考B站课程学习开发一个RISC-V的操作系统的学习笔记,计划从RISC…...

小米机型解锁bl 跳“168小时”限制 操作步骤分析

写到前面的安全提示 了解解锁bl后的风险: 解锁设备后将允许修改系统重要组件,并有可能在一定程度上导致设备受损;解锁后设备安全性将失去保证,易受恶意软件攻击,从而导致个人隐私数据泄露;解锁后部分对系…...

基础练习 回文数

问题描述 1221是一个非常特殊的数&#xff0c;它从左边读和从右边读是一样的&#xff0c;编程求所有这样的四位十进制数。 输出格式 按从小到大的顺序输出满足条件的四位十进制数。 solution1 #include <stdio.h> int main(){int n 1000, n1, n2, n3, n4;while(n &…...

解决Spring Boot 2.7.16 在服务器显示启动成功无法访问问题:从本地到服务器的部署坑

&#x1f337;&#x1f341; 博主猫头虎 带您 Go to New World.✨&#x1f341; &#x1f984; 博客首页——猫头虎的博客&#x1f390; &#x1f433;《面试题大全专栏》 文章图文并茂&#x1f995;生动形象&#x1f996;简单易学&#xff01;欢迎大家来踩踩~&#x1f33a; &a…...

洛谷P5661:公交换乘 ← CSP-J 2019 复赛第2题

【题目来源】https://www.luogu.com.cn/problem/P5661https://www.acwing.com/problem/content/1164/【题目描述】 著名旅游城市 B 市为了鼓励大家采用公共交通方式出行&#xff0c;推出了一种地铁换乘公交车的优惠方案&#xff1a; 1.在搭乘一次地铁后可以获得一张优惠票&…...

mysql优化之索引

索引官方定义&#xff1a;索引是帮助mysql高效获取数据的数据结构。 索引的目的在于提高查询效率&#xff0c;可以类比字典。 可以简单理解为&#xff1a;排好序的快速查找数据结构 在数据之外&#xff0c;数据库系统还维护着满足特定查找算法的数据结构&#xff0c;这种数据…...

文件系统详解

目录 文件系统&#xff08;1&#xff09; 第一节文件系统的基本概念 一、文件系统的任务 二、文件的存储介质及存储方式 三、文件的分类 第二节 文件的逻辑结构和物理结构 一、文件的逻辑结构 二、文件的物理结构 文件系统&#xff08;2&#xff09; 第三节 文件目…...

有名管道及其应用

创建FIFO文件 1.通过命令&#xff1a; mkfifo 文件名 2.通过函数: mkfifo #include <sys/types.h> #include <sys/stat.h> int mkfifo(const char *pathname, mode_t mode); 参数&#xff1a; -pathname&#xff1a;管道名称的路径 -mode&#xff1a;文件的权限&a…...

基于STM32LXXX的数字电位器(MAX5402EUA+T)驱动应用程序设计

一、简介: MAX5402EUA+T 是Maxim Integrated(现Analog Devices)推出的一款256抽头、低漂移数字电位器,采用MAX-8封装。 二、主要技术特性: 参数 值 抽头数 256个 端到端电阻 10kΩ 每级步进电阻 39.2Ω (10kΩ/255) 接口类型 SPI兼容,3线串行 电源电压 2.7V ~ 5.5V 温度…...

从体素到三维模型:解析Volumetric Method在复杂场景重建中的核心算法

1. 什么是Volumetric Method&#xff1f;从体素到三维世界的魔法 第一次接触三维重建时&#xff0c;我被那些从照片变成立体模型的演示惊呆了。后来才知道&#xff0c;这背后藏着一种叫Volumetric Method的技术&#xff0c;它就像用乐高积木搭建世界——把空间切成无数小方块&a…...

避坑指南:ABB机器人PC SDK开发中,网络扫描(NetworkScanner)为何总为空?

ABB机器人PC SDK网络扫描故障深度排查指南 当你在C#项目中调用NetworkScanner.Scan()方法时&#xff0c;那个本该充满控制器信息的ControllerInfoCollection却固执地保持空白——这种挫败感每个ABB机器人开发者都深有体会。本文将从协议栈底层到网络拓扑&#xff0c;系统性地拆…...

开源项目合规性警示:从PyWxDump案例看技术工具的法律边界

开源项目合规性警示&#xff1a;从PyWxDump案例看技术工具的法律边界 【免费下载链接】PyWxDump 删库 项目地址: https://gitcode.com/GitHub_Trending/py/PyWxDump 在开源技术快速发展的今天&#xff0c;开发者常常面临技术实现与法律合规的平衡难题。近期&#xff0c;…...

ROS新手必看:5分钟搞定usb_cam相机标定(附棋盘格下载)

ROS实战&#xff1a;从零完成USB摄像头标定的完整指南 在机器人视觉系统中&#xff0c;相机标定是确保测量精度的基础步骤。许多ROS初学者往往在第一步就遇到障碍——要么找不到合适的标定工具&#xff0c;要么被复杂的参数配置搞得晕头转向。本文将带你用最直接的方式完成整个…...

Windows HEIC缩略图终极指南:3分钟免费解决iPhone照片预览问题

Windows HEIC缩略图终极指南&#xff1a;3分钟免费解决iPhone照片预览问题 【免费下载链接】windows-heic-thumbnails Enable Windows Explorer to display thumbnails for HEIC/HEIF files 项目地址: https://gitcode.com/gh_mirrors/wi/windows-heic-thumbnails 还在为…...

特征选择实战:用F检验、互信息法搞定Kaggle高维数据,附完整Python代码与避坑指南

特征选择实战&#xff1a;用F检验与互信息法构建高维数据黄金特征集 在Kaggle竞赛和真实业务场景中&#xff0c;我们常常面对成百上千个特征的高维数据集。如何从中筛选出最具预测力的特征子集&#xff1f;本文将带你构建完整的特征选择流水线&#xff0c;从方差过滤到相关性筛…...

3分钟解决魔兽争霸3卡顿难题:WarcraftHelper优化工具全攻略

3分钟解决魔兽争霸3卡顿难题&#xff1a;WarcraftHelper优化工具全攻略 【免费下载链接】WarcraftHelper Warcraft III Helper , support 1.20e, 1.24e, 1.26a, 1.27a, 1.27b 项目地址: https://gitcode.com/gh_mirrors/wa/WarcraftHelper 您是否也曾在重温《魔兽争霸3》…...

Unity路径有中文就报错?手把手教你解决Autoware高精地图插件导入的坑

Unity路径中文报错&#xff1f;Autoware高精地图插件导入全攻略 刚接触Autoware高精地图制作的新手们&#xff0c;十有八九会在第一步就栽跟头——当你兴冲冲下载好vector_map插件&#xff0c;准备在Unity中大展拳脚时&#xff0c;却发现插件死活无法正常导入。这种挫败感我太熟…...

MKDV4GCL-ABB嵌入式存储芯片在智能物联网设备中的关键应用解析

1. 为什么物联网设备需要专用存储芯片&#xff1f; 第一次拆解智能家居设备时&#xff0c;我发现很多厂商都在用TF卡扩展存储。但实际使用三个月后&#xff0c;问题就来了——频繁读写导致卡片损坏&#xff0c;设备不断报存储错误。这就是典型选错存储方案的后果。物联网设备对…...