当前位置: 首页 > news >正文

用Pytorch实现线性回归模型

目录

  • 回顾
  • Pytorch实现
    • 步骤
    • 1. 准备数据
    • 2. 设计模型
      • class LinearModel
      • 代码
    • 3. 构造损失函数和优化器
    • 4. 训练过程
    • 5. 输出和测试
    • 完整代码
  • 练习

回顾

前面已经学习过线性模型相关的内容,实现线性模型的过程并没有使用到Pytorch。
这节课主要是利用Pytorch实现线性模型。
学习器训练:

  • 确定模型(函数)
  • 定义损失函数
  • 优化器优化(SGD)

之前用过Pytorch的Tensor进行Forward、Backward计算。
现在利用Pytorch框架来实现。

Pytorch实现

步骤

  1. 准备数据集
  2. 设计模型(计算预测值y_hat):从nn.Module模块继承
  3. 构造损失函数和优化器:使用PytorchAPI
  4. 训练过程:Forward、Backward、update

1. 准备数据

在PyTorch中计算图是通过mini-batch形式进行,所以X、Y都是多维的Tensor。
在这里插入图片描述

import torch
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

2. 设计模型

在之前讲解梯度下降算法时,我们需要自己计算出梯度,然后更新权重。
在这里插入图片描述
而使用Pytorch构造模型,重点时在构建计算图和损失函数上。
在这里插入图片描述

class LinearModel

通过构造一个 class LinearModel类来实现,所有的模型类都需要继承nn.Module,这是所有神经网络模块的基础类。
class LinearModel这种定义的模型类必须包含两个部分:

  • init():构造函数,进行初始化。
    def __init__(self):super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。# torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元self.linear = torch.nn.Linear(1, 1)

在这里插入图片描述

  • forward():进行前馈计算
    (backward没有被写,是因为在这种模型类里面会自动实现)

Class nn.Linear 实现了magic method call():它使类的实例可以像函数一样被调用。通常会调用forward()。

    def forward(self, x):y_pred = self.linear(x)#调用linear对象,输入x进行预测return y_pred

代码

class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。# torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = self.linear(x)#调用linear对象,输入x进行预测return y_predmodel = LinearModel()#实例化LinearModel()

3. 构造损失函数和优化器

采用MSE作为损失函数

torch.nn.MSELoss(size_average,reduce)

  • size_average:是否求mini-batch的平均loss。
  • reduce:降维,不用管。

在这里插入图片描述SGD作为优化器torch.optim.SGD(params, lr):

  • params:参数
  • lr:学习率

在这里插入图片描述

criterion = torch.nn.MSELoss(size_average=False)#size_average:the losses are averaged over each loss element in the batch.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)#params:model.parameters(): w、b

4. 训练过程

  1. 预测
  2. 计算loss
  3. 梯度清零
  4. Backward
  5. 参数更新
    简化:Forward–>Backward–>更新
#4. Training Cycle
for epoch in range(100):y_pred = model(x_data)#Forward:预测loss = criterion(y_pred, y_data)#Forward:计算lossprint(epoch, loss)optimizer.zero_grad()#梯度清零loss.backward()#backward:计算梯度optimizer.step()#通过step()函数进行参数更新

5. 输出和测试

# Output weight and bias
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

完整代码

import torch
#1. Prepare dataset
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])#2. Design Model
class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。# torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = self.linear(x)#调用linear对象,输入x进行预测return y_predmodel = LinearModel()#实例化LinearModel()# 3. Construct Loss and Optimize
criterion = torch.nn.MSELoss(size_average=False)#size_average:the losses are averaged over each loss element in the batch.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)#params:model.parameters(): w、b#4. Training Cycle
for epoch in range(100):y_pred = model(x_data)#Forward:预测loss = criterion(y_pred, y_data)#Forward:计算lossprint(epoch, loss)optimizer.zero_grad()#梯度清零loss.backward()#backward:计算梯度optimizer.step()#通过step()函数进行参数更新# Output weight and bias
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

输出结果:

85 tensor(0.2294, grad_fn=)
86 tensor(0.2261, grad_fn=)
87 tensor(0.2228, grad_fn=)
88 tensor(0.2196, grad_fn=)
89 tensor(0.2165, grad_fn=)
90 tensor(0.2134, grad_fn=)
91 tensor(0.2103, grad_fn=)
92 tensor(0.2073, grad_fn=)
93 tensor(0.2043, grad_fn=)
94 tensor(0.2014, grad_fn=)
95 tensor(0.1985, grad_fn=)
96 tensor(0.1956, grad_fn=)
97 tensor(0.1928, grad_fn=)
98 tensor(0.1900, grad_fn=)
99 tensor(0.1873, grad_fn=)
w = 1.711882472038269
b = 0.654958963394165
y_pred = tensor([[7.5025]])

可以看到误差还比较大,可以增加训练轮次,训练1000次后的结果:

980 tensor(2.1981e-07, grad_fn=)
981 tensor(2.1671e-07, grad_fn=)
982 tensor(2.1329e-07, grad_fn=)
983 tensor(2.1032e-07, grad_fn=)
984 tensor(2.0737e-07, grad_fn=)
985 tensor(2.0420e-07, grad_fn=)
986 tensor(2.0143e-07, grad_fn=)
987 tensor(1.9854e-07, grad_fn=)
988 tensor(1.9565e-07, grad_fn=)
989 tensor(1.9260e-07, grad_fn=)
990 tensor(1.8995e-07, grad_fn=)
991 tensor(1.8728e-07, grad_fn=)
992 tensor(1.8464e-07, grad_fn=)
993 tensor(1.8188e-07, grad_fn=)
994 tensor(1.7924e-07, grad_fn=)
995 tensor(1.7669e-07, grad_fn=)
996 tensor(1.7435e-07, grad_fn=)
997 tensor(1.7181e-07, grad_fn=)
998 tensor(1.6931e-07, grad_fn=)
999 tensor(1.6700e-07, grad_fn=)
w = 1.9997280836105347
b = 0.0006181497010402381
y_pred = tensor([[7.9995]])

练习

用以下这些优化器替换SGD,得到训练结果并画出损失曲线图。
在这里插入图片描述
比如说:Adam的loss图:
在这里插入图片描述

相关文章:

用Pytorch实现线性回归模型

目录 回顾Pytorch实现步骤1. 准备数据2. 设计模型class LinearModel代码 3. 构造损失函数和优化器4. 训练过程5. 输出和测试完整代码 练习 回顾 前面已经学习过线性模型相关的内容,实现线性模型的过程并没有使用到Pytorch。 这节课主要是利用Pytorch实现线性模型。…...

WordPress模板层次与常用模板函数

首页: home.php index.php 文章页: single-{post_type}.php – 如果文章类型是videos(即视频),WordPress就会去查找single-videos.php(WordPress 3.0及以上版本支持) single.php index.php 页面: 自定义模板 – 在WordPre…...

HarmonyOS应用开发者高级认证试题库(鸿蒙)

目录 考试链接: 流程: 选择: 判断 单选 多选 考试链接: 华为开发者学堂华为开发者学堂https://developer.huawei.com/consumer/cn/training/dev-certification/a617e0d3bc144624864a04edb951f6c4 流程: 先进行…...

系分备考计算机网络传输介质、通信方式和交换方式

文章目录 1、概述2、传输介质3、网络通信4、网络交换5、总结 1、概述 计算机网路是系统分析师考试的常考知识点,本篇主要记录了知识点:网络传输介质、网络通信和数据交换方式等。 2、传输介质 网络的传输最常见的就是网线,也就是双绞线&…...

js原生面试总结

冒泡循环 var arr[2,1,3,4,9,7,6,8] // 外层循环代表循环次数 内层循环时每次的两两对比 少一次循环 for (let i 0; i < arr.length-1; i) {// 如果进入判断代表当前值大于下一个是需要进行冒泡排序的let booltruefor (let j 0; j < arr.length-1-i; j) {// 虽然…...

接口自动化测试框架设计

文章目录 接口测试的定义接口测试的意义接口测试的测试用例设计接口测试的测试用例设计方法postman主要功能请求体分类JSON数据类型postman内置参数postman变量全局变量环境变量 postman断言JSON提取器正则表达式提取器Cookie提取器postman加密接口签名 接口自动化测试基础getp…...

详解ISIS动态路由协议

华子目录 前言应用场景历史起源ISIS路由计算过程ISIS的地址结构ISIS路由器分类ISIS邻居关系的建立P2PMA ISIS中的DIS与OSPF中DR的对比链路状态信息的交互ISIS的最短路径优先算法&#xff08;SPF&#xff09;ISIS区域划分ISIS区域间路由访问原理ISIS与OSPF的不同ISIS与OSPF的术语…...

Linux操作系统----gdb调试工具(配实操图)

绪论​ “不用滞留采花保存&#xff0c;只管往前走去&#xff0c;一路上百花自会盛开。 ——泰戈尔”。本章是Linux工具篇的最后一章。gdb调试工具是我们日常工作中需要掌握的一项重要技能我们需要基本的掌握release和debug的区别以及gdb的调试方法的指令。下一章我们将进入真正…...

去除GIT某个时间之前的提交日志

背景 有时git提交了太多有些较早之前的提交日志&#xff0c;不想在git log看到&#xff0c;想把他删除掉。 方法 大概思路是通过 git clone --depth 来克隆到指定提交的代码&#xff0c;此时再早之前的日志是没有的 然后提交到新仓库 #!/bin/bash ori_git"gityour.gi…...

4 python快速上手

计算机常识知识 1.Python代码运行方式2.进制2.1 进制转换 3. 计算机中的单位4.编码4.1 ascii编码4.2 gb-2312编码4.3 unicode4.4 utf-8编码4.5 Python相关的编码 总结 各位小伙伴想要博客相关资料的话关注公众号&#xff1a;chuanyeTry即可领取相关资料&#xff01; 1.Python代…...

单元测试-spring-boot-starter-test+junit5

前言&#xff1a; 开发过程中经常需要写单元测试&#xff0c;记录一下单元测试spring-boot-starter-testjunit5的使用 引入内容&#xff1a; 引用jar包 <!-- SpringBoot测试类依赖 --> <dependency><groupId>org.springframework.boot</groupId><…...

CentOS 7上安装Anaconda 详细教程

目录 1. 下载Anaconda安装脚本2. 校验数据完整性&#xff08;可选&#xff09;3. 运行安装脚本4. 遵循安装指南5. 选择安装位置6. 初始化Anaconda7. 激活安装8. 测试安装9. 更新Anaconda10. 使用Anaconda 1. 下载Anaconda安装脚本 首先需要从Anaconda的官方网站下载最新的Anac…...

2023年全球软件架构师峰会(ArchSummit深圳站):核心内容与学习收获(附大会核心PPT下载)

本次峰会是一次重要的技术盛会&#xff0c;旨在为全球软件架构师提供一个交流和学习的平台。本次峰会聚焦于软件架构的最新趋势、最佳实践和技术创新&#xff0c;吸引了来自世界各地的软件架构师、技术专家和企业领袖。 在峰会中&#xff0c;与会者可以了解到数字化、AIGC、To…...

RT-Thread Studio学习(十六)定时器计数

RT-Thread Studio学习&#xff08;十六&#xff09;定时器计数 一、简介二、新建RT-Thread项目并使用外部时钟三、启用PWM输入捕获功能四、测试 一、简介 本文将基于STM32F407VET芯片介绍如何在RT-Thread Studio开发环境下使用定时器对输入脉冲进行计数。 硬件及开发环境如下…...

【linux进程间通信(一)】匿名管道和命名管道

&#x1f493;博主CSDN主页:杭电码农-NEO&#x1f493;   ⏩专栏分类:Linux从入门到精通⏪   &#x1f69a;代码仓库:NEO的学习日记&#x1f69a;   &#x1f339;关注我&#x1faf5;带你学更多操作系统知识   &#x1f51d;&#x1f51d; 进程间通信 1. 前言2. 进程间…...

第11章 jQuery

学习目标 了解什么是jQuery,能够说出jQuery的特点 掌握jQuery的下载和引入,能够下载jQuery并且能够使用两种方式引入jQuery 掌握jQuery的简单使用,能够使用jQuery实现简单的页面效果 熟悉什么是jQuery对象,能够说出jQuery对象与DOM对象的区别 掌握利用选择器获取元素的方法…...

leetcode:1736. 替换隐藏数字得到的最晚时间(python3解法)

难度&#xff1a;简单 给你一个字符串 time &#xff0c;格式为 hh:mm&#xff08;小时&#xff1a;分钟&#xff09;&#xff0c;其中某几位数字被隐藏&#xff08;用 ? 表示&#xff09;。 有效的时间为 00:00 到 23:59 之间的所有时间&#xff0c;包括 00:00 和 23:59 。 …...

MySQL存储函数与存储过程习题

创建表并插入数据&#xff1a; 字段名 数据类型 主键 外键 非空 唯一 自增 id INT 是 否 是 是 否 name VARCHAR(50) 否 否 是 否 否 glass VARCHAR(50) 否 否 是 否 否 ​ ​ sch 表内容 id name glass 1 xiaommg glass 1 2 xiaojun glass 2 1、创建一个可以统计表格内记录…...

基于 Hologres+Flink 的曹操出行实时数仓建设

本文整理自曹操出行实时计算负责人林震基于 HologresFlink 的曹操出行实时数仓建设的分享&#xff0c;内容主要分为以下六部分&#xff1a; 曹操出行业务背景介绍曹操出行业务痛点分析HologresFlink 构建企业级实时数仓曹操出行实时数仓实践曹操出行业务成果分析未来展望 一、曹…...

【Docker】实战多阶段构建 Laravel 镜像

作者主页&#xff1a; 正函数的个人主页 文章收录专栏&#xff1a; Docker 欢迎大家点赞 &#x1f44d; 收藏 ⭐ 加关注哦&#xff01; 本节适用于 PHP 开发者阅读。Laravel 基于 8.x 版本&#xff0c;各个版本的文件结构可能会有差异&#xff0c;请根据实际自行修改。 准备 新…...

JavaSec-RCE

简介 RCE(Remote Code Execution)&#xff0c;可以分为:命令注入(Command Injection)、代码注入(Code Injection) 代码注入 1.漏洞场景&#xff1a;Groovy代码注入 Groovy是一种基于JVM的动态语言&#xff0c;语法简洁&#xff0c;支持闭包、动态类型和Java互操作性&#xff0c…...

pam_env.so模块配置解析

在PAM&#xff08;Pluggable Authentication Modules&#xff09;配置中&#xff0c; /etc/pam.d/su 文件相关配置含义如下&#xff1a; 配置解析 auth required pam_env.so1. 字段分解 字段值说明模块类型auth认证类模块&#xff0c;负责验证用户身份&am…...

Python实现prophet 理论及参数优化

文章目录 Prophet理论及模型参数介绍Python代码完整实现prophet 添加外部数据进行模型优化 之前初步学习prophet的时候&#xff0c;写过一篇简单实现&#xff0c;后期随着对该模型的深入研究&#xff0c;本次记录涉及到prophet 的公式以及参数调优&#xff0c;从公式可以更直观…...

相机Camera日志分析之三十一:高通Camx HAL十种流程基础分析关键字汇总(后续持续更新中)

【关注我,后续持续新增专题博文,谢谢!!!】 上一篇我们讲了:有对最普通的场景进行各个日志注释讲解,但相机场景太多,日志差异也巨大。后面将展示各种场景下的日志。 通过notepad++打开场景下的日志,通过下列分类关键字搜索,即可清晰的分析不同场景的相机运行流程差异…...

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决 问题背景 在一个基于 Spring Cloud Gateway WebFlux 构建的微服务项目中&#xff0c;新增了一个本地验证码接口 /code&#xff0c;使用函数式路由&#xff08;RouterFunction&#xff09;和 Hutool 的 Circle…...

laravel8+vue3.0+element-plus搭建方法

创建 laravel8 项目 composer create-project --prefer-dist laravel/laravel laravel8 8.* 安装 laravel/ui composer require laravel/ui 修改 package.json 文件 "devDependencies": {"vue/compiler-sfc": "^3.0.7","axios": …...

JS设计模式(4):观察者模式

JS设计模式(4):观察者模式 一、引入 在开发中&#xff0c;我们经常会遇到这样的场景&#xff1a;一个对象的状态变化需要自动通知其他对象&#xff0c;比如&#xff1a; 电商平台中&#xff0c;商品库存变化时需要通知所有订阅该商品的用户&#xff1b;新闻网站中&#xff0…...

Python竞赛环境搭建全攻略

Python环境搭建竞赛技术文章大纲 竞赛背景与意义 竞赛的目的与价值Python在竞赛中的应用场景环境搭建对竞赛效率的影响 竞赛环境需求分析 常见竞赛类型&#xff08;算法、数据分析、机器学习等&#xff09;不同竞赛对Python版本及库的要求硬件与操作系统的兼容性问题 Pyth…...

LangChain 中的文档加载器(Loader)与文本切分器(Splitter)详解《二》

&#x1f9e0; LangChain 中 TextSplitter 的使用详解&#xff1a;从基础到进阶&#xff08;附代码&#xff09; 一、前言 在处理大规模文本数据时&#xff0c;特别是在构建知识库或进行大模型训练与推理时&#xff0c;文本切分&#xff08;Text Splitting&#xff09; 是一个…...

生信服务器 | 做生信为什么推荐使用Linux服务器?

原文链接&#xff1a;生信服务器 | 做生信为什么推荐使用Linux服务器&#xff1f; 一、 做生信为什么推荐使用服务器&#xff1f; 大家好&#xff0c;我是小杜。在做生信分析的同学&#xff0c;或是将接触学习生信分析的同学&#xff0c;<font style"color:rgb(53, 1…...