当前位置: 首页 > news >正文

用Pytorch实现线性回归模型

目录

  • 回顾
  • Pytorch实现
    • 步骤
    • 1. 准备数据
    • 2. 设计模型
      • class LinearModel
      • 代码
    • 3. 构造损失函数和优化器
    • 4. 训练过程
    • 5. 输出和测试
    • 完整代码
  • 练习

回顾

前面已经学习过线性模型相关的内容,实现线性模型的过程并没有使用到Pytorch。
这节课主要是利用Pytorch实现线性模型。
学习器训练:

  • 确定模型(函数)
  • 定义损失函数
  • 优化器优化(SGD)

之前用过Pytorch的Tensor进行Forward、Backward计算。
现在利用Pytorch框架来实现。

Pytorch实现

步骤

  1. 准备数据集
  2. 设计模型(计算预测值y_hat):从nn.Module模块继承
  3. 构造损失函数和优化器:使用PytorchAPI
  4. 训练过程:Forward、Backward、update

1. 准备数据

在PyTorch中计算图是通过mini-batch形式进行,所以X、Y都是多维的Tensor。
在这里插入图片描述

import torch
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

2. 设计模型

在之前讲解梯度下降算法时,我们需要自己计算出梯度,然后更新权重。
在这里插入图片描述
而使用Pytorch构造模型,重点时在构建计算图和损失函数上。
在这里插入图片描述

class LinearModel

通过构造一个 class LinearModel类来实现,所有的模型类都需要继承nn.Module,这是所有神经网络模块的基础类。
class LinearModel这种定义的模型类必须包含两个部分:

  • init():构造函数,进行初始化。
    def __init__(self):super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。# torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元self.linear = torch.nn.Linear(1, 1)

在这里插入图片描述

  • forward():进行前馈计算
    (backward没有被写,是因为在这种模型类里面会自动实现)

Class nn.Linear 实现了magic method call():它使类的实例可以像函数一样被调用。通常会调用forward()。

    def forward(self, x):y_pred = self.linear(x)#调用linear对象,输入x进行预测return y_pred

代码

class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。# torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = self.linear(x)#调用linear对象,输入x进行预测return y_predmodel = LinearModel()#实例化LinearModel()

3. 构造损失函数和优化器

采用MSE作为损失函数

torch.nn.MSELoss(size_average,reduce)

  • size_average:是否求mini-batch的平均loss。
  • reduce:降维,不用管。

在这里插入图片描述SGD作为优化器torch.optim.SGD(params, lr):

  • params:参数
  • lr:学习率

在这里插入图片描述

criterion = torch.nn.MSELoss(size_average=False)#size_average:the losses are averaged over each loss element in the batch.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)#params:model.parameters(): w、b

4. 训练过程

  1. 预测
  2. 计算loss
  3. 梯度清零
  4. Backward
  5. 参数更新
    简化:Forward–>Backward–>更新
#4. Training Cycle
for epoch in range(100):y_pred = model(x_data)#Forward:预测loss = criterion(y_pred, y_data)#Forward:计算lossprint(epoch, loss)optimizer.zero_grad()#梯度清零loss.backward()#backward:计算梯度optimizer.step()#通过step()函数进行参数更新

5. 输出和测试

# Output weight and bias
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

完整代码

import torch
#1. Prepare dataset
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])#2. Design Model
class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。# torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = self.linear(x)#调用linear对象,输入x进行预测return y_predmodel = LinearModel()#实例化LinearModel()# 3. Construct Loss and Optimize
criterion = torch.nn.MSELoss(size_average=False)#size_average:the losses are averaged over each loss element in the batch.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)#params:model.parameters(): w、b#4. Training Cycle
for epoch in range(100):y_pred = model(x_data)#Forward:预测loss = criterion(y_pred, y_data)#Forward:计算lossprint(epoch, loss)optimizer.zero_grad()#梯度清零loss.backward()#backward:计算梯度optimizer.step()#通过step()函数进行参数更新# Output weight and bias
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

输出结果:

85 tensor(0.2294, grad_fn=)
86 tensor(0.2261, grad_fn=)
87 tensor(0.2228, grad_fn=)
88 tensor(0.2196, grad_fn=)
89 tensor(0.2165, grad_fn=)
90 tensor(0.2134, grad_fn=)
91 tensor(0.2103, grad_fn=)
92 tensor(0.2073, grad_fn=)
93 tensor(0.2043, grad_fn=)
94 tensor(0.2014, grad_fn=)
95 tensor(0.1985, grad_fn=)
96 tensor(0.1956, grad_fn=)
97 tensor(0.1928, grad_fn=)
98 tensor(0.1900, grad_fn=)
99 tensor(0.1873, grad_fn=)
w = 1.711882472038269
b = 0.654958963394165
y_pred = tensor([[7.5025]])

可以看到误差还比较大,可以增加训练轮次,训练1000次后的结果:

980 tensor(2.1981e-07, grad_fn=)
981 tensor(2.1671e-07, grad_fn=)
982 tensor(2.1329e-07, grad_fn=)
983 tensor(2.1032e-07, grad_fn=)
984 tensor(2.0737e-07, grad_fn=)
985 tensor(2.0420e-07, grad_fn=)
986 tensor(2.0143e-07, grad_fn=)
987 tensor(1.9854e-07, grad_fn=)
988 tensor(1.9565e-07, grad_fn=)
989 tensor(1.9260e-07, grad_fn=)
990 tensor(1.8995e-07, grad_fn=)
991 tensor(1.8728e-07, grad_fn=)
992 tensor(1.8464e-07, grad_fn=)
993 tensor(1.8188e-07, grad_fn=)
994 tensor(1.7924e-07, grad_fn=)
995 tensor(1.7669e-07, grad_fn=)
996 tensor(1.7435e-07, grad_fn=)
997 tensor(1.7181e-07, grad_fn=)
998 tensor(1.6931e-07, grad_fn=)
999 tensor(1.6700e-07, grad_fn=)
w = 1.9997280836105347
b = 0.0006181497010402381
y_pred = tensor([[7.9995]])

练习

用以下这些优化器替换SGD,得到训练结果并画出损失曲线图。
在这里插入图片描述
比如说:Adam的loss图:
在这里插入图片描述

相关文章:

用Pytorch实现线性回归模型

目录 回顾Pytorch实现步骤1. 准备数据2. 设计模型class LinearModel代码 3. 构造损失函数和优化器4. 训练过程5. 输出和测试完整代码 练习 回顾 前面已经学习过线性模型相关的内容,实现线性模型的过程并没有使用到Pytorch。 这节课主要是利用Pytorch实现线性模型。…...

WordPress模板层次与常用模板函数

首页: home.php index.php 文章页: single-{post_type}.php – 如果文章类型是videos(即视频),WordPress就会去查找single-videos.php(WordPress 3.0及以上版本支持) single.php index.php 页面: 自定义模板 – 在WordPre…...

HarmonyOS应用开发者高级认证试题库(鸿蒙)

目录 考试链接: 流程: 选择: 判断 单选 多选 考试链接: 华为开发者学堂华为开发者学堂https://developer.huawei.com/consumer/cn/training/dev-certification/a617e0d3bc144624864a04edb951f6c4 流程: 先进行…...

系分备考计算机网络传输介质、通信方式和交换方式

文章目录 1、概述2、传输介质3、网络通信4、网络交换5、总结 1、概述 计算机网路是系统分析师考试的常考知识点,本篇主要记录了知识点:网络传输介质、网络通信和数据交换方式等。 2、传输介质 网络的传输最常见的就是网线,也就是双绞线&…...

js原生面试总结

冒泡循环 var arr[2,1,3,4,9,7,6,8] // 外层循环代表循环次数 内层循环时每次的两两对比 少一次循环 for (let i 0; i < arr.length-1; i) {// 如果进入判断代表当前值大于下一个是需要进行冒泡排序的let booltruefor (let j 0; j < arr.length-1-i; j) {// 虽然…...

接口自动化测试框架设计

文章目录 接口测试的定义接口测试的意义接口测试的测试用例设计接口测试的测试用例设计方法postman主要功能请求体分类JSON数据类型postman内置参数postman变量全局变量环境变量 postman断言JSON提取器正则表达式提取器Cookie提取器postman加密接口签名 接口自动化测试基础getp…...

详解ISIS动态路由协议

华子目录 前言应用场景历史起源ISIS路由计算过程ISIS的地址结构ISIS路由器分类ISIS邻居关系的建立P2PMA ISIS中的DIS与OSPF中DR的对比链路状态信息的交互ISIS的最短路径优先算法&#xff08;SPF&#xff09;ISIS区域划分ISIS区域间路由访问原理ISIS与OSPF的不同ISIS与OSPF的术语…...

Linux操作系统----gdb调试工具(配实操图)

绪论​ “不用滞留采花保存&#xff0c;只管往前走去&#xff0c;一路上百花自会盛开。 ——泰戈尔”。本章是Linux工具篇的最后一章。gdb调试工具是我们日常工作中需要掌握的一项重要技能我们需要基本的掌握release和debug的区别以及gdb的调试方法的指令。下一章我们将进入真正…...

去除GIT某个时间之前的提交日志

背景 有时git提交了太多有些较早之前的提交日志&#xff0c;不想在git log看到&#xff0c;想把他删除掉。 方法 大概思路是通过 git clone --depth 来克隆到指定提交的代码&#xff0c;此时再早之前的日志是没有的 然后提交到新仓库 #!/bin/bash ori_git"gityour.gi…...

4 python快速上手

计算机常识知识 1.Python代码运行方式2.进制2.1 进制转换 3. 计算机中的单位4.编码4.1 ascii编码4.2 gb-2312编码4.3 unicode4.4 utf-8编码4.5 Python相关的编码 总结 各位小伙伴想要博客相关资料的话关注公众号&#xff1a;chuanyeTry即可领取相关资料&#xff01; 1.Python代…...

单元测试-spring-boot-starter-test+junit5

前言&#xff1a; 开发过程中经常需要写单元测试&#xff0c;记录一下单元测试spring-boot-starter-testjunit5的使用 引入内容&#xff1a; 引用jar包 <!-- SpringBoot测试类依赖 --> <dependency><groupId>org.springframework.boot</groupId><…...

CentOS 7上安装Anaconda 详细教程

目录 1. 下载Anaconda安装脚本2. 校验数据完整性&#xff08;可选&#xff09;3. 运行安装脚本4. 遵循安装指南5. 选择安装位置6. 初始化Anaconda7. 激活安装8. 测试安装9. 更新Anaconda10. 使用Anaconda 1. 下载Anaconda安装脚本 首先需要从Anaconda的官方网站下载最新的Anac…...

2023年全球软件架构师峰会(ArchSummit深圳站):核心内容与学习收获(附大会核心PPT下载)

本次峰会是一次重要的技术盛会&#xff0c;旨在为全球软件架构师提供一个交流和学习的平台。本次峰会聚焦于软件架构的最新趋势、最佳实践和技术创新&#xff0c;吸引了来自世界各地的软件架构师、技术专家和企业领袖。 在峰会中&#xff0c;与会者可以了解到数字化、AIGC、To…...

RT-Thread Studio学习(十六)定时器计数

RT-Thread Studio学习&#xff08;十六&#xff09;定时器计数 一、简介二、新建RT-Thread项目并使用外部时钟三、启用PWM输入捕获功能四、测试 一、简介 本文将基于STM32F407VET芯片介绍如何在RT-Thread Studio开发环境下使用定时器对输入脉冲进行计数。 硬件及开发环境如下…...

【linux进程间通信(一)】匿名管道和命名管道

&#x1f493;博主CSDN主页:杭电码农-NEO&#x1f493;   ⏩专栏分类:Linux从入门到精通⏪   &#x1f69a;代码仓库:NEO的学习日记&#x1f69a;   &#x1f339;关注我&#x1faf5;带你学更多操作系统知识   &#x1f51d;&#x1f51d; 进程间通信 1. 前言2. 进程间…...

第11章 jQuery

学习目标 了解什么是jQuery,能够说出jQuery的特点 掌握jQuery的下载和引入,能够下载jQuery并且能够使用两种方式引入jQuery 掌握jQuery的简单使用,能够使用jQuery实现简单的页面效果 熟悉什么是jQuery对象,能够说出jQuery对象与DOM对象的区别 掌握利用选择器获取元素的方法…...

leetcode:1736. 替换隐藏数字得到的最晚时间(python3解法)

难度&#xff1a;简单 给你一个字符串 time &#xff0c;格式为 hh:mm&#xff08;小时&#xff1a;分钟&#xff09;&#xff0c;其中某几位数字被隐藏&#xff08;用 ? 表示&#xff09;。 有效的时间为 00:00 到 23:59 之间的所有时间&#xff0c;包括 00:00 和 23:59 。 …...

MySQL存储函数与存储过程习题

创建表并插入数据&#xff1a; 字段名 数据类型 主键 外键 非空 唯一 自增 id INT 是 否 是 是 否 name VARCHAR(50) 否 否 是 否 否 glass VARCHAR(50) 否 否 是 否 否 ​ ​ sch 表内容 id name glass 1 xiaommg glass 1 2 xiaojun glass 2 1、创建一个可以统计表格内记录…...

基于 Hologres+Flink 的曹操出行实时数仓建设

本文整理自曹操出行实时计算负责人林震基于 HologresFlink 的曹操出行实时数仓建设的分享&#xff0c;内容主要分为以下六部分&#xff1a; 曹操出行业务背景介绍曹操出行业务痛点分析HologresFlink 构建企业级实时数仓曹操出行实时数仓实践曹操出行业务成果分析未来展望 一、曹…...

【Docker】实战多阶段构建 Laravel 镜像

作者主页&#xff1a; 正函数的个人主页 文章收录专栏&#xff1a; Docker 欢迎大家点赞 &#x1f44d; 收藏 ⭐ 加关注哦&#xff01; 本节适用于 PHP 开发者阅读。Laravel 基于 8.x 版本&#xff0c;各个版本的文件结构可能会有差异&#xff0c;请根据实际自行修改。 准备 新…...

KubeSphere 容器平台高可用:环境搭建与可视化操作指南

Linux_k8s篇 欢迎来到Linux的世界&#xff0c;看笔记好好学多敲多打&#xff0c;每个人都是大神&#xff01; 题目&#xff1a;KubeSphere 容器平台高可用&#xff1a;环境搭建与可视化操作指南 版本号: 1.0,0 作者: 老王要学习 日期: 2025.06.05 适用环境: Ubuntu22 文档说…...

Linux 文件类型,目录与路径,文件与目录管理

文件类型 后面的字符表示文件类型标志 普通文件&#xff1a;-&#xff08;纯文本文件&#xff0c;二进制文件&#xff0c;数据格式文件&#xff09; 如文本文件、图片、程序文件等。 目录文件&#xff1a;d&#xff08;directory&#xff09; 用来存放其他文件或子目录。 设备…...

在HarmonyOS ArkTS ArkUI-X 5.0及以上版本中,手势开发全攻略:

在 HarmonyOS 应用开发中&#xff0c;手势交互是连接用户与设备的核心纽带。ArkTS 框架提供了丰富的手势处理能力&#xff0c;既支持点击、长按、拖拽等基础单一手势的精细控制&#xff0c;也能通过多种绑定策略解决父子组件的手势竞争问题。本文将结合官方开发文档&#xff0c…...

苍穹外卖--缓存菜品

1.问题说明 用户端小程序展示的菜品数据都是通过查询数据库获得&#xff0c;如果用户端访问量比较大&#xff0c;数据库访问压力随之增大 2.实现思路 通过Redis来缓存菜品数据&#xff0c;减少数据库查询操作。 缓存逻辑分析&#xff1a; ①每个分类下的菜品保持一份缓存数据…...

Matlab | matlab常用命令总结

常用命令 一、 基础操作与环境二、 矩阵与数组操作(核心)三、 绘图与可视化四、 编程与控制流五、 符号计算 (Symbolic Math Toolbox)六、 文件与数据 I/O七、 常用函数类别重要提示这是一份 MATLAB 常用命令和功能的总结,涵盖了基础操作、矩阵运算、绘图、编程和文件处理等…...

网络编程(UDP编程)

思维导图 UDP基础编程&#xff08;单播&#xff09; 1.流程图 服务器&#xff1a;短信的接收方 创建套接字 (socket)-----------------------------------------》有手机指定网络信息-----------------------------------------------》有号码绑定套接字 (bind)--------------…...

C++八股 —— 单例模式

文章目录 1. 基本概念2. 设计要点3. 实现方式4. 详解懒汉模式 1. 基本概念 线程安全&#xff08;Thread Safety&#xff09; 线程安全是指在多线程环境下&#xff0c;某个函数、类或代码片段能够被多个线程同时调用时&#xff0c;仍能保证数据的一致性和逻辑的正确性&#xf…...

Golang——9、反射和文件操作

反射和文件操作 1、反射1.1、reflect.TypeOf()获取任意值的类型对象1.2、reflect.ValueOf()1.3、结构体反射 2、文件操作2.1、os.Open()打开文件2.2、方式一&#xff1a;使用Read()读取文件2.3、方式二&#xff1a;bufio读取文件2.4、方式三&#xff1a;os.ReadFile读取2.5、写…...

Qemu arm操作系统开发环境

使用qemu虚拟arm硬件比较合适。 步骤如下&#xff1a; 安装qemu apt install qemu-system安装aarch64-none-elf-gcc 需要手动下载&#xff0c;下载地址&#xff1a;https://developer.arm.com/-/media/Files/downloads/gnu/13.2.rel1/binrel/arm-gnu-toolchain-13.2.rel1-x…...

解析奥地利 XARION激光超声检测系统:无膜光学麦克风 + 无耦合剂的技术协同优势及多元应用

在工业制造领域&#xff0c;无损检测&#xff08;NDT)的精度与效率直接影响产品质量与生产安全。奥地利 XARION开发的激光超声精密检测系统&#xff0c;以非接触式光学麦克风技术为核心&#xff0c;打破传统检测瓶颈&#xff0c;为半导体、航空航天、汽车制造等行业提供了高灵敏…...