pytorch学习---实现线性回归初体验
假设我们的基础模型就是y = wx + b
,其中w和b均为参数,我们使用y = 3x+0.8
来构造数据x、y,所以最后通过模型应该能够得出w和b应该分别接近3和0.8。
步骤如下:
- 准备数据
- 计算预测值
- 计算损失,把参数的梯度置为0,进行反向传播
- 更新参数
方式一
该方式没有用pytorch
的模型api
,手动实现
import torch,numpy
import matplotlib.pyplot as plt# 1、准备数据
learning_rate = 0.01
#y=3x + 0.8
x = torch.rand([500,1])
y_true= x*3 + 0.8# 2、通过模型计算y_predict
w = torch.rand([1,1],requires_grad=True)
b = torch.tensor(0,requires_grad=True,dtype=torch.float32)# 3、通过循环,反向传播,更新参数
for i in range(500):# 4、计算lossy_predict = torch.matmul(x,w) + bloss = (y_true-y_predict).pow(2).mean()# 每次循环判断是否存在梯度,防止累加if w.grad is not None:w.grad.data.zero_()if b.grad is not None:b.grad.data.zero_()# 反向传播loss.backward()w.data = w.data - learning_rate*w.gradb.data = b.data - learning_rate*b.grad# 每50次输出一下结果if i%50==0:print("w,b,loss",w.item(),b.item(),loss.item())#可视化显示
plt.figure(figsize=(20,8))
plt.scatter(x.numpy().reshape(-1),y_true.numpy().reshape(-1))
y_predict = torch.matmul(x,w) + b
plt.plot(x.numpy().reshape(-1),y_predict.detach().numpy().reshape(-1),c="r")
plt.show()
循环500
次的效果
循环2000
次的结果
方式二
方式一的方式虽然已经购简便了,但是还是有些许繁琐,所以我们可以采用pytorch
的api
来实现。
nn.Module
是torch.nn
提供的一个类,是pytorch
中我们自定义网络的一个基类,在这个类中定义了很多有用的方法,让我们在继承这个类定义网络的时候非常简单。
当我们自定义网络的时候,有两个方法需要特别注意:
1.__init__
需要调用super
方法,继承父类的属性和方法
2. forward
方法必须实现,用来定义我们的网络的向前计算的过程用前面的y = wx+b
的模型举例如下:
#定义模型
from torch import nn
class Lr(nn.Module): #继承nn.Moduledef __init__(self):super(Lr, self).__init__()self.linear = nn.Linear(1,1)def forward(self,x):out = self.linear(x)return out
全部代码如下:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import torch
from torch import nn
from torch import optim
import numpy as np
from matplotlib import pyplot as plt#1、定义数据
x = torch.rand([50,1])
y = x*3 + 0.8#定义模型
class Lr(nn.Module): #继承nn.Moduledef __init__(self):super(Lr, self).__init__()self.linear = nn.Linear(1,1)def forward(self,x):out = self.linear(x)return out
#2、实例化模型、loss函数以及优化器
model = Lr()
criterion = nn.MSELoss() #损失函数
optimizer = optim.SGD(model.parameters(),lr=1e-3) #优化器#3、训练模型
for i in range(3000):out = model(x)# 获取预测值loss = criterion(y,out) #计算损失optimizer.zero_grad() #梯度归零loss.backward() #计算梯度optimizer.step() #更新梯度if(i+1) % 20 ==0:print('Epoch[{}/{}],loss:{:.6f}'.format(i,500,loss.data))#4、模型评估
model.eval() #设置模型为评估模式,即预测模式
predict = model(x)
predict = predict.data.numpy()
plt.scatter(x.data.numpy(),y.data.numpy(),c="r")
plt.plot(x.data.numpy(),predict)
plt.show()
注意:
model.eval()
表示设置模型为评估模式,即预测模式
model.train(mode=True)
表示设置模型为训练模式
在当前的线性回归中,上述并无区别
但是在其他的一些模型中,训练的参数和预测的参数会不相同,到时候就需要具体告诉程序我们是在进行训练还是预测,比如模型中存在Dropout,BatchNorm的时候
循环2000
次的结果:
循环30000
次的结果:
相关文章:

pytorch学习---实现线性回归初体验
假设我们的基础模型就是y wx b,其中w和b均为参数,我们使用y 3x0.8来构造数据x、y,所以最后通过模型应该能够得出w和b应该分别接近3和0.8。 步骤如下: 准备数据计算预测值计算损失,把参数的梯度置为0,进行反向传播…...

别再乱写git commit了
B站|公众号:啥都会一点的研究生 写在前面 在很长的一段时间中,使用git commit都是随心所欲,log肥肠简洁,随着代码的迭代,当时有多偷懒,返过头查看git日志就有多懊悔,就和写代码不写doc string…...

八大排序(一)冒泡排序,选择排序,插入排序,希尔排序
一、冒泡排序 冒泡排序的原理是:从左到右,相邻元素进行比较。每次比较一轮,就会找到序列中最大的一个或最小的一个。这个数就会从序列的最右边冒出来。 以从小到大排序为例,第一轮比较后,所有数中最大的那个数就会浮…...
泊松分布简要介绍
泊松分布是一种常见的离散概率分布,它用于描述某个时间段或区域内随机事件发生的次数。它得名于法国数学家西蒙丹尼泊松。 泊松分布的概率质量函数表示某个时间段或区域内事件发生次数的概率。如果随机变量 X 服从泊松分布,记作 X ~ Poisson(λ)&#x…...

C语言每日一题(10):无人生还
文章主题:无人生还🔥所属专栏:C语言每日一题📗作者简介:每天不定时更新C语言的小白一枚,记录分享自己每天的所思所想😄🎶个人主页:[₽]的个人主页🏄…...
VSCode开发go手记
断点调试: 安装delve(windows): go get -u github.com/go-delve/delve/cmd/dlv 设置 launch.json 配置文件: ctrlshiftp 输入 Debug: Open launch.json 打开 launch.json 文件,如果第一次打开,会新建一…...

怎么选择AI伪原创工具-AI伪原创工具有哪些
在数字时代,创作和发布内容已经成为了一种不可或缺的活动。不论您是个人博主、企业家还是网站管理员,都会面临一个共同的挑战:如何在互联网上脱颖而出,吸引更多的读者和访客。而正是在这个背景下,AI伪原创工具逐渐崭露…...

【块状链表C++】文本编辑器(指针中 引用 的使用)
》》》算法竞赛 /*** file * author jUicE_g2R(qq:3406291309)————彬(bin-必应)* 一个某双流一大学通信与信息专业大二在读 * * brief 一直在竞赛算法学习的路上* * copyright 2023.9* COPYRIGHT 原创技术笔记:转载…...

echarts的Y轴设置为整数
场景:使用echarts,设置Y轴为整数。通过判断Y轴的数值为整数才显示即可 yAxis: [{name: ,type: value,min: 0, // 最小值// max: 200, // 最大值// splitNumber: 5, // 坐标轴的分割段数// interval: 100 / 5, // 强制设置坐标轴分割间隔度(取本Y轴的最大…...

恢复删除文件?不得不掌握的4个方法!
“删除了的文件还可以恢复吗?有个文件我本来以为不重要了,就把它删除了,没想到现在还需要用到!这可怎么办?有没有办法找回来呢?” 重要的文件一旦丢失或误删可能都会对我们的工作和学习造成比较大的影响。怎…...
GitLab CI/CD:.gitlab-ci.yml 文件常用参数小结
文章目录 一、.gitlab-ci.yml 文件作用二、一个简单的.gitlab-ci.yml 文件示例参考 一、.gitlab-ci.yml 文件作用 可以定义跑CI时想要运行的命令或脚本 可以定义job之间的依赖和缓存 可以执行程序部署并定义部署位置 可以定义想要包含的其他配置文件和模版 二、一个简单的.gi…...

MySQL学习笔记9
MySQL数据表中的数据类型: 在考虑数据类型、长度、标度和精度时,一定要仔细地进行短期和长远的规划,另外,公司制度和希望用户用什么方式访问数据也是要考虑的因素。开发人员应该了解数据的本质,以及数据在数据库里是如…...

从零学习开发一个RISC-V操作系统(三)丨嵌入式操作系统开发的常用概念和工具
本篇文章的内容 一、嵌入式操作习系统开发的常用概念和工具1.1 本地编译和交叉编译1.2 调试器GDB(The GNU Project Debugger)1.3 QEMU模拟器1.4 项目构造工具Make 本系列是博主参考B站课程学习开发一个RISC-V的操作系统的学习笔记,计划从RISC…...

小米机型解锁bl 跳“168小时”限制 操作步骤分析
写到前面的安全提示 了解解锁bl后的风险: 解锁设备后将允许修改系统重要组件,并有可能在一定程度上导致设备受损;解锁后设备安全性将失去保证,易受恶意软件攻击,从而导致个人隐私数据泄露;解锁后部分对系…...
基础练习 回文数
问题描述 1221是一个非常特殊的数,它从左边读和从右边读是一样的,编程求所有这样的四位十进制数。 输出格式 按从小到大的顺序输出满足条件的四位十进制数。 solution1 #include <stdio.h> int main(){int n 1000, n1, n2, n3, n4;while(n &…...

解决Spring Boot 2.7.16 在服务器显示启动成功无法访问问题:从本地到服务器的部署坑
🌷🍁 博主猫头虎 带您 Go to New World.✨🍁 🦄 博客首页——猫头虎的博客🎐 🐳《面试题大全专栏》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~🌺 &a…...
洛谷P5661:公交换乘 ← CSP-J 2019 复赛第2题
【题目来源】https://www.luogu.com.cn/problem/P5661https://www.acwing.com/problem/content/1164/【题目描述】 著名旅游城市 B 市为了鼓励大家采用公共交通方式出行,推出了一种地铁换乘公交车的优惠方案: 1.在搭乘一次地铁后可以获得一张优惠票&…...

mysql优化之索引
索引官方定义:索引是帮助mysql高效获取数据的数据结构。 索引的目的在于提高查询效率,可以类比字典。 可以简单理解为:排好序的快速查找数据结构 在数据之外,数据库系统还维护着满足特定查找算法的数据结构,这种数据…...

文件系统详解
目录 文件系统(1) 第一节文件系统的基本概念 一、文件系统的任务 二、文件的存储介质及存储方式 三、文件的分类 第二节 文件的逻辑结构和物理结构 一、文件的逻辑结构 二、文件的物理结构 文件系统(2) 第三节 文件目…...

有名管道及其应用
创建FIFO文件 1.通过命令: mkfifo 文件名 2.通过函数: mkfifo #include <sys/types.h> #include <sys/stat.h> int mkfifo(const char *pathname, mode_t mode); 参数: -pathname:管道名称的路径 -mode:文件的权限&a…...

机器学习:决策树和剪枝
本文目录: 一、决策树基本知识(一)概念(二)决策树建立过程 二、决策树生成(一)ID3决策树:基于信息增益构建的决策树。(二)C4.5决策树(三ÿ…...

智绅科技 —— 智慧养老 + 数字健康,构筑银发时代安全防护网
在老龄化率突破 21.3% 的当下,智绅科技以 "科技适老" 为核心理念,构建 "监测 - 预警 - 干预 - 照护" 的智慧养老闭环。 其自主研发的七彩喜智慧康养平台,通过物联网、AI 和边缘计算技术,实现对老年人健康与安…...

数据分析实战2(Tableau)
1、Tableau功能 数据赋能(让业务一线也可以轻松使用最新数据) 分析师可以直接将数据看板发布到线上自动更新看板自由下载数据线上修改图表邮箱发送数据设置数据预警 数据探索(通过统计分析和数据可视化,从数据发现问题…...
PART 6 树莓派小车+QT (TCP控制)
1. 树莓派作为服务器的程序 (1)服务器tcp_server_socket程序 可以实现小车前进、后退、左转、右转、加减速(可能不行) carMoveControl.py import RPi.GPIO as GPIO import time import tty,sys,select,termios import socket…...

spring:实例化类过程中方法执行顺序。
如题。在实例化Bean时,会根据配置依次调用方法。在此测试代码如下: 在测试类中继承接口InitializingBean,接口InterfaceUserService(该接口为自定义,只是定义set方法)。 InterfaceUserService,…...

验证电机理论与性能:电机试验平板提升测试效率
电机试验平板提升测试效率是验证电机理论与性能的重要环节之一。通过在平板上进行电机试验,可以对电机的性能参数进行准确测量和分析,从而验证电机的理论设计是否符合实际表现。同时,提升测试效率可以加快试验过程,节约时间和成本…...
AI系统负载均衡与动态路由
载均衡与动态路由 在微服务架构中,负载均衡是实现服务高可用和性能优化的关键机制。传统负载均衡技术通常围绕请求数、连接数、CPU占用率等基础指标进行分发,而在AI系统中,特别是多模型、多异构算力(如CPU、GPU、TPU)共存的环境下,负载均衡不仅要考虑节点资源消耗,还需…...

微型导轨在手术机器人领域中有哪些关键操作?
在微创手术领域,手术机器人凭借其高精度、高稳定性和远程操控能力,正逐步成为现代外科手术的重要工具。微型导轨作为一种专为高精度运动设计的线性导向系统,凭借其亚微米级定位精度、低摩擦运动特性及紧凑结构设计,已成为手术机器…...
Linux网络——socket网络通信udp
文章目录 UDP通信基础UDP的特点 Linux下UDP通信核心步骤创建UDP套接字绑定本地地址(可选)发送数据函数:sendto()函数原型参数详解典型使用示例 接收数据函数:recvfrom()函数原型参数详解返回值典型使用示例 关键设计原因无连接特性…...
@Prometheus 监控-MySQL (Mysqld Exporter)
文章目录 **Prometheus 监控 MySQL ****1. 目标****2. 环境准备****2.1 所需组件****2.2 权限要求** **3. 部署 mysqld_exporter****3.1 下载与安装****3.2 创建配置文件****3.3 创建 Systemd 服务****3.4 验证 Exporter** **4. 配置 Prometheus****4.1 添加 Job 到 prometheus…...