当前位置: 首页 > news >正文

用Pytorch实现线性回归模型

目录

  • 回顾
  • Pytorch实现
    • 步骤
    • 1. 准备数据
    • 2. 设计模型
      • class LinearModel
      • 代码
    • 3. 构造损失函数和优化器
    • 4. 训练过程
    • 5. 输出和测试
    • 完整代码
  • 练习

回顾

前面已经学习过线性模型相关的内容,实现线性模型的过程并没有使用到Pytorch。
这节课主要是利用Pytorch实现线性模型。
学习器训练:

  • 确定模型(函数)
  • 定义损失函数
  • 优化器优化(SGD)

之前用过Pytorch的Tensor进行Forward、Backward计算。
现在利用Pytorch框架来实现。

Pytorch实现

步骤

  1. 准备数据集
  2. 设计模型(计算预测值y_hat):从nn.Module模块继承
  3. 构造损失函数和优化器:使用PytorchAPI
  4. 训练过程:Forward、Backward、update

1. 准备数据

在PyTorch中计算图是通过mini-batch形式进行,所以X、Y都是多维的Tensor。
在这里插入图片描述

import torch
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

2. 设计模型

在之前讲解梯度下降算法时,我们需要自己计算出梯度,然后更新权重。
在这里插入图片描述
而使用Pytorch构造模型,重点时在构建计算图和损失函数上。
在这里插入图片描述

class LinearModel

通过构造一个 class LinearModel类来实现,所有的模型类都需要继承nn.Module,这是所有神经网络模块的基础类。
class LinearModel这种定义的模型类必须包含两个部分:

  • init():构造函数,进行初始化。
    def __init__(self):super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。# torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元self.linear = torch.nn.Linear(1, 1)

在这里插入图片描述

  • forward():进行前馈计算
    (backward没有被写,是因为在这种模型类里面会自动实现)

Class nn.Linear 实现了magic method call():它使类的实例可以像函数一样被调用。通常会调用forward()。

    def forward(self, x):y_pred = self.linear(x)#调用linear对象,输入x进行预测return y_pred

代码

class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。# torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = self.linear(x)#调用linear对象,输入x进行预测return y_predmodel = LinearModel()#实例化LinearModel()

3. 构造损失函数和优化器

采用MSE作为损失函数

torch.nn.MSELoss(size_average,reduce)

  • size_average:是否求mini-batch的平均loss。
  • reduce:降维,不用管。

在这里插入图片描述SGD作为优化器torch.optim.SGD(params, lr):

  • params:参数
  • lr:学习率

在这里插入图片描述

criterion = torch.nn.MSELoss(size_average=False)#size_average:the losses are averaged over each loss element in the batch.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)#params:model.parameters(): w、b

4. 训练过程

  1. 预测
  2. 计算loss
  3. 梯度清零
  4. Backward
  5. 参数更新
    简化:Forward–>Backward–>更新
#4. Training Cycle
for epoch in range(100):y_pred = model(x_data)#Forward:预测loss = criterion(y_pred, y_data)#Forward:计算lossprint(epoch, loss)optimizer.zero_grad()#梯度清零loss.backward()#backward:计算梯度optimizer.step()#通过step()函数进行参数更新

5. 输出和测试

# Output weight and bias
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

完整代码

import torch
#1. Prepare dataset
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])#2. Design Model
class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。# torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = self.linear(x)#调用linear对象,输入x进行预测return y_predmodel = LinearModel()#实例化LinearModel()# 3. Construct Loss and Optimize
criterion = torch.nn.MSELoss(size_average=False)#size_average:the losses are averaged over each loss element in the batch.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)#params:model.parameters(): w、b#4. Training Cycle
for epoch in range(100):y_pred = model(x_data)#Forward:预测loss = criterion(y_pred, y_data)#Forward:计算lossprint(epoch, loss)optimizer.zero_grad()#梯度清零loss.backward()#backward:计算梯度optimizer.step()#通过step()函数进行参数更新# Output weight and bias
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

输出结果:

85 tensor(0.2294, grad_fn=)
86 tensor(0.2261, grad_fn=)
87 tensor(0.2228, grad_fn=)
88 tensor(0.2196, grad_fn=)
89 tensor(0.2165, grad_fn=)
90 tensor(0.2134, grad_fn=)
91 tensor(0.2103, grad_fn=)
92 tensor(0.2073, grad_fn=)
93 tensor(0.2043, grad_fn=)
94 tensor(0.2014, grad_fn=)
95 tensor(0.1985, grad_fn=)
96 tensor(0.1956, grad_fn=)
97 tensor(0.1928, grad_fn=)
98 tensor(0.1900, grad_fn=)
99 tensor(0.1873, grad_fn=)
w = 1.711882472038269
b = 0.654958963394165
y_pred = tensor([[7.5025]])

可以看到误差还比较大,可以增加训练轮次,训练1000次后的结果:

980 tensor(2.1981e-07, grad_fn=)
981 tensor(2.1671e-07, grad_fn=)
982 tensor(2.1329e-07, grad_fn=)
983 tensor(2.1032e-07, grad_fn=)
984 tensor(2.0737e-07, grad_fn=)
985 tensor(2.0420e-07, grad_fn=)
986 tensor(2.0143e-07, grad_fn=)
987 tensor(1.9854e-07, grad_fn=)
988 tensor(1.9565e-07, grad_fn=)
989 tensor(1.9260e-07, grad_fn=)
990 tensor(1.8995e-07, grad_fn=)
991 tensor(1.8728e-07, grad_fn=)
992 tensor(1.8464e-07, grad_fn=)
993 tensor(1.8188e-07, grad_fn=)
994 tensor(1.7924e-07, grad_fn=)
995 tensor(1.7669e-07, grad_fn=)
996 tensor(1.7435e-07, grad_fn=)
997 tensor(1.7181e-07, grad_fn=)
998 tensor(1.6931e-07, grad_fn=)
999 tensor(1.6700e-07, grad_fn=)
w = 1.9997280836105347
b = 0.0006181497010402381
y_pred = tensor([[7.9995]])

练习

用以下这些优化器替换SGD,得到训练结果并画出损失曲线图。
在这里插入图片描述
比如说:Adam的loss图:
在这里插入图片描述

相关文章:

用Pytorch实现线性回归模型

目录 回顾Pytorch实现步骤1. 准备数据2. 设计模型class LinearModel代码 3. 构造损失函数和优化器4. 训练过程5. 输出和测试完整代码 练习 回顾 前面已经学习过线性模型相关的内容,实现线性模型的过程并没有使用到Pytorch。 这节课主要是利用Pytorch实现线性模型。…...

WordPress模板层次与常用模板函数

首页: home.php index.php 文章页: single-{post_type}.php – 如果文章类型是videos(即视频),WordPress就会去查找single-videos.php(WordPress 3.0及以上版本支持) single.php index.php 页面: 自定义模板 – 在WordPre…...

HarmonyOS应用开发者高级认证试题库(鸿蒙)

目录 考试链接: 流程: 选择: 判断 单选 多选 考试链接: 华为开发者学堂华为开发者学堂https://developer.huawei.com/consumer/cn/training/dev-certification/a617e0d3bc144624864a04edb951f6c4 流程: 先进行…...

系分备考计算机网络传输介质、通信方式和交换方式

文章目录 1、概述2、传输介质3、网络通信4、网络交换5、总结 1、概述 计算机网路是系统分析师考试的常考知识点,本篇主要记录了知识点:网络传输介质、网络通信和数据交换方式等。 2、传输介质 网络的传输最常见的就是网线,也就是双绞线&…...

js原生面试总结

冒泡循环 var arr[2,1,3,4,9,7,6,8] // 外层循环代表循环次数 内层循环时每次的两两对比 少一次循环 for (let i 0; i < arr.length-1; i) {// 如果进入判断代表当前值大于下一个是需要进行冒泡排序的let booltruefor (let j 0; j < arr.length-1-i; j) {// 虽然…...

接口自动化测试框架设计

文章目录 接口测试的定义接口测试的意义接口测试的测试用例设计接口测试的测试用例设计方法postman主要功能请求体分类JSON数据类型postman内置参数postman变量全局变量环境变量 postman断言JSON提取器正则表达式提取器Cookie提取器postman加密接口签名 接口自动化测试基础getp…...

详解ISIS动态路由协议

华子目录 前言应用场景历史起源ISIS路由计算过程ISIS的地址结构ISIS路由器分类ISIS邻居关系的建立P2PMA ISIS中的DIS与OSPF中DR的对比链路状态信息的交互ISIS的最短路径优先算法&#xff08;SPF&#xff09;ISIS区域划分ISIS区域间路由访问原理ISIS与OSPF的不同ISIS与OSPF的术语…...

Linux操作系统----gdb调试工具(配实操图)

绪论​ “不用滞留采花保存&#xff0c;只管往前走去&#xff0c;一路上百花自会盛开。 ——泰戈尔”。本章是Linux工具篇的最后一章。gdb调试工具是我们日常工作中需要掌握的一项重要技能我们需要基本的掌握release和debug的区别以及gdb的调试方法的指令。下一章我们将进入真正…...

去除GIT某个时间之前的提交日志

背景 有时git提交了太多有些较早之前的提交日志&#xff0c;不想在git log看到&#xff0c;想把他删除掉。 方法 大概思路是通过 git clone --depth 来克隆到指定提交的代码&#xff0c;此时再早之前的日志是没有的 然后提交到新仓库 #!/bin/bash ori_git"gityour.gi…...

4 python快速上手

计算机常识知识 1.Python代码运行方式2.进制2.1 进制转换 3. 计算机中的单位4.编码4.1 ascii编码4.2 gb-2312编码4.3 unicode4.4 utf-8编码4.5 Python相关的编码 总结 各位小伙伴想要博客相关资料的话关注公众号&#xff1a;chuanyeTry即可领取相关资料&#xff01; 1.Python代…...

单元测试-spring-boot-starter-test+junit5

前言&#xff1a; 开发过程中经常需要写单元测试&#xff0c;记录一下单元测试spring-boot-starter-testjunit5的使用 引入内容&#xff1a; 引用jar包 <!-- SpringBoot测试类依赖 --> <dependency><groupId>org.springframework.boot</groupId><…...

CentOS 7上安装Anaconda 详细教程

目录 1. 下载Anaconda安装脚本2. 校验数据完整性&#xff08;可选&#xff09;3. 运行安装脚本4. 遵循安装指南5. 选择安装位置6. 初始化Anaconda7. 激活安装8. 测试安装9. 更新Anaconda10. 使用Anaconda 1. 下载Anaconda安装脚本 首先需要从Anaconda的官方网站下载最新的Anac…...

2023年全球软件架构师峰会(ArchSummit深圳站):核心内容与学习收获(附大会核心PPT下载)

本次峰会是一次重要的技术盛会&#xff0c;旨在为全球软件架构师提供一个交流和学习的平台。本次峰会聚焦于软件架构的最新趋势、最佳实践和技术创新&#xff0c;吸引了来自世界各地的软件架构师、技术专家和企业领袖。 在峰会中&#xff0c;与会者可以了解到数字化、AIGC、To…...

RT-Thread Studio学习(十六)定时器计数

RT-Thread Studio学习&#xff08;十六&#xff09;定时器计数 一、简介二、新建RT-Thread项目并使用外部时钟三、启用PWM输入捕获功能四、测试 一、简介 本文将基于STM32F407VET芯片介绍如何在RT-Thread Studio开发环境下使用定时器对输入脉冲进行计数。 硬件及开发环境如下…...

【linux进程间通信(一)】匿名管道和命名管道

&#x1f493;博主CSDN主页:杭电码农-NEO&#x1f493;   ⏩专栏分类:Linux从入门到精通⏪   &#x1f69a;代码仓库:NEO的学习日记&#x1f69a;   &#x1f339;关注我&#x1faf5;带你学更多操作系统知识   &#x1f51d;&#x1f51d; 进程间通信 1. 前言2. 进程间…...

第11章 jQuery

学习目标 了解什么是jQuery,能够说出jQuery的特点 掌握jQuery的下载和引入,能够下载jQuery并且能够使用两种方式引入jQuery 掌握jQuery的简单使用,能够使用jQuery实现简单的页面效果 熟悉什么是jQuery对象,能够说出jQuery对象与DOM对象的区别 掌握利用选择器获取元素的方法…...

leetcode:1736. 替换隐藏数字得到的最晚时间(python3解法)

难度&#xff1a;简单 给你一个字符串 time &#xff0c;格式为 hh:mm&#xff08;小时&#xff1a;分钟&#xff09;&#xff0c;其中某几位数字被隐藏&#xff08;用 ? 表示&#xff09;。 有效的时间为 00:00 到 23:59 之间的所有时间&#xff0c;包括 00:00 和 23:59 。 …...

MySQL存储函数与存储过程习题

创建表并插入数据&#xff1a; 字段名 数据类型 主键 外键 非空 唯一 自增 id INT 是 否 是 是 否 name VARCHAR(50) 否 否 是 否 否 glass VARCHAR(50) 否 否 是 否 否 ​ ​ sch 表内容 id name glass 1 xiaommg glass 1 2 xiaojun glass 2 1、创建一个可以统计表格内记录…...

基于 Hologres+Flink 的曹操出行实时数仓建设

本文整理自曹操出行实时计算负责人林震基于 HologresFlink 的曹操出行实时数仓建设的分享&#xff0c;内容主要分为以下六部分&#xff1a; 曹操出行业务背景介绍曹操出行业务痛点分析HologresFlink 构建企业级实时数仓曹操出行实时数仓实践曹操出行业务成果分析未来展望 一、曹…...

【Docker】实战多阶段构建 Laravel 镜像

作者主页&#xff1a; 正函数的个人主页 文章收录专栏&#xff1a; Docker 欢迎大家点赞 &#x1f44d; 收藏 ⭐ 加关注哦&#xff01; 本节适用于 PHP 开发者阅读。Laravel 基于 8.x 版本&#xff0c;各个版本的文件结构可能会有差异&#xff0c;请根据实际自行修改。 准备 新…...

JavaSec-RCE

简介 RCE(Remote Code Execution)&#xff0c;可以分为:命令注入(Command Injection)、代码注入(Code Injection) 代码注入 1.漏洞场景&#xff1a;Groovy代码注入 Groovy是一种基于JVM的动态语言&#xff0c;语法简洁&#xff0c;支持闭包、动态类型和Java互操作性&#xff0c…...

【JavaEE】-- HTTP

1. HTTP是什么&#xff1f; HTTP&#xff08;全称为"超文本传输协议"&#xff09;是一种应用非常广泛的应用层协议&#xff0c;HTTP是基于TCP协议的一种应用层协议。 应用层协议&#xff1a;是计算机网络协议栈中最高层的协议&#xff0c;它定义了运行在不同主机上…...

Spring Boot 实现流式响应(兼容 2.7.x)

在实际开发中&#xff0c;我们可能会遇到一些流式数据处理的场景&#xff0c;比如接收来自上游接口的 Server-Sent Events&#xff08;SSE&#xff09; 或 流式 JSON 内容&#xff0c;并将其原样中转给前端页面或客户端。这种情况下&#xff0c;传统的 RestTemplate 缓存机制会…...

Day131 | 灵神 | 回溯算法 | 子集型 子集

Day131 | 灵神 | 回溯算法 | 子集型 子集 78.子集 78. 子集 - 力扣&#xff08;LeetCode&#xff09; 思路&#xff1a; 笔者写过很多次这道题了&#xff0c;不想写题解了&#xff0c;大家看灵神讲解吧 回溯算法套路①子集型回溯【基础算法精讲 14】_哔哩哔哩_bilibili 完…...

CMake 从 GitHub 下载第三方库并使用

有时我们希望直接使用 GitHub 上的开源库,而不想手动下载、编译和安装。 可以利用 CMake 提供的 FetchContent 模块来实现自动下载、构建和链接第三方库。 FetchContent 命令官方文档✅ 示例代码 我们将以 fmt 这个流行的格式化库为例,演示如何: 使用 FetchContent 从 GitH…...

Python基于历史模拟方法实现投资组合风险管理的VaR与ES模型项目实战

说明&#xff1a;这是一个机器学习实战项目&#xff08;附带数据代码文档&#xff09;&#xff0c;如需数据代码文档可以直接到文章最后关注获取。 1.项目背景 在金融市场日益复杂和波动加剧的背景下&#xff0c;风险管理成为金融机构和个人投资者关注的核心议题之一。VaR&…...

Go 并发编程基础:通道(Channel)的使用

在 Go 中&#xff0c;Channel 是 Goroutine 之间通信的核心机制。它提供了一个线程安全的通信方式&#xff0c;用于在多个 Goroutine 之间传递数据&#xff0c;从而实现高效的并发编程。 本章将介绍 Channel 的基本概念、用法、缓冲、关闭机制以及 select 的使用。 一、Channel…...

[免费]微信小程序问卷调查系统(SpringBoot后端+Vue管理端)【论文+源码+SQL脚本】

大家好&#xff0c;我是java1234_小锋老师&#xff0c;看到一个不错的微信小程序问卷调查系统(SpringBoot后端Vue管理端)【论文源码SQL脚本】&#xff0c;分享下哈。 项目视频演示 【免费】微信小程序问卷调查系统(SpringBoot后端Vue管理端) Java毕业设计_哔哩哔哩_bilibili 项…...

破解路内监管盲区:免布线低位视频桩重塑停车管理新标准

城市路内停车管理常因行道树遮挡、高位设备盲区等问题&#xff0c;导致车牌识别率低、逃费率高&#xff0c;传统模式在复杂路段束手无策。免布线低位视频桩凭借超低视角部署与智能算法&#xff0c;正成为破局关键。该设备安装于车位侧方0.5-0.7米高度&#xff0c;直接规避树枝遮…...

【SpringBoot自动化部署】

SpringBoot自动化部署方法 使用Jenkins进行持续集成与部署 Jenkins是最常用的自动化部署工具之一&#xff0c;能够实现代码拉取、构建、测试和部署的全流程自动化。 配置Jenkins任务时&#xff0c;需要添加Git仓库地址和凭证&#xff0c;设置构建触发器&#xff08;如GitHub…...