用Pytorch实现线性回归模型
目录
- 回顾
- Pytorch实现
- 步骤
- 1. 准备数据
- 2. 设计模型
- class LinearModel
- 代码
- 3. 构造损失函数和优化器
- 4. 训练过程
- 5. 输出和测试
- 完整代码
- 练习
回顾
前面已经学习过线性模型相关的内容,实现线性模型的过程并没有使用到Pytorch。
这节课主要是利用Pytorch实现线性模型。
学习器训练:
- 确定模型(函数)
- 定义损失函数
- 优化器优化(SGD)
之前用过Pytorch的Tensor进行Forward、Backward计算。
现在利用Pytorch框架来实现。
Pytorch实现
步骤
- 准备数据集
- 设计模型(计算预测值y_hat):从nn.Module模块继承
- 构造损失函数和优化器:使用PytorchAPI
- 训练过程:Forward、Backward、update
1. 准备数据
在PyTorch中计算图是通过mini-batch形式进行,所以X、Y都是多维的Tensor。

import torch
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])
2. 设计模型
在之前讲解梯度下降算法时,我们需要自己计算出梯度,然后更新权重。

而使用Pytorch构造模型,重点时在构建计算图和损失函数上。

class LinearModel
通过构造一个 class LinearModel类来实现,所有的模型类都需要继承nn.Module,这是所有神经网络模块的基础类。
class LinearModel这种定义的模型类必须包含两个部分:
- init():构造函数,进行初始化。
def __init__(self):super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。# torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元self.linear = torch.nn.Linear(1, 1)

- forward():进行前馈计算
(backward没有被写,是因为在这种模型类里面会自动实现)
Class nn.Linear 实现了magic method call():它使类的实例可以像函数一样被调用。通常会调用forward()。
def forward(self, x):y_pred = self.linear(x)#调用linear对象,输入x进行预测return y_pred
代码
class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。# torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = self.linear(x)#调用linear对象,输入x进行预测return y_predmodel = LinearModel()#实例化LinearModel()
3. 构造损失函数和优化器
采用MSE作为损失函数
torch.nn.MSELoss(size_average,reduce)
- size_average:是否求mini-batch的平均loss。
- reduce:降维,不用管。
SGD作为优化器torch.optim.SGD(params, lr):
- params:参数
- lr:学习率

criterion = torch.nn.MSELoss(size_average=False)#size_average:the losses are averaged over each loss element in the batch.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)#params:model.parameters(): w、b
4. 训练过程
- 预测
- 计算loss
- 梯度清零
- Backward
- 参数更新
简化:Forward–>Backward–>更新
#4. Training Cycle
for epoch in range(100):y_pred = model(x_data)#Forward:预测loss = criterion(y_pred, y_data)#Forward:计算lossprint(epoch, loss)optimizer.zero_grad()#梯度清零loss.backward()#backward:计算梯度optimizer.step()#通过step()函数进行参数更新
5. 输出和测试
# Output weight and bias
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)
完整代码
import torch
#1. Prepare dataset
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])#2. Design Model
class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__()#调用父类构造函数,不用管,照着写。# torch.nn.Linear(in_featuers, in_featuers)构造Linear类的对象,其实就是实现了一个线性单元self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = self.linear(x)#调用linear对象,输入x进行预测return y_predmodel = LinearModel()#实例化LinearModel()# 3. Construct Loss and Optimize
criterion = torch.nn.MSELoss(size_average=False)#size_average:the losses are averaged over each loss element in the batch.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)#params:model.parameters(): w、b#4. Training Cycle
for epoch in range(100):y_pred = model(x_data)#Forward:预测loss = criterion(y_pred, y_data)#Forward:计算lossprint(epoch, loss)optimizer.zero_grad()#梯度清零loss.backward()#backward:计算梯度optimizer.step()#通过step()函数进行参数更新# Output weight and bias
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())# Test Model
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)
输出结果:
85 tensor(0.2294, grad_fn=)
86 tensor(0.2261, grad_fn=)
87 tensor(0.2228, grad_fn=)
88 tensor(0.2196, grad_fn=)
89 tensor(0.2165, grad_fn=)
90 tensor(0.2134, grad_fn=)
91 tensor(0.2103, grad_fn=)
92 tensor(0.2073, grad_fn=)
93 tensor(0.2043, grad_fn=)
94 tensor(0.2014, grad_fn=)
95 tensor(0.1985, grad_fn=)
96 tensor(0.1956, grad_fn=)
97 tensor(0.1928, grad_fn=)
98 tensor(0.1900, grad_fn=)
99 tensor(0.1873, grad_fn=)
w = 1.711882472038269
b = 0.654958963394165
y_pred = tensor([[7.5025]])
可以看到误差还比较大,可以增加训练轮次,训练1000次后的结果:
980 tensor(2.1981e-07, grad_fn=)
981 tensor(2.1671e-07, grad_fn=)
982 tensor(2.1329e-07, grad_fn=)
983 tensor(2.1032e-07, grad_fn=)
984 tensor(2.0737e-07, grad_fn=)
985 tensor(2.0420e-07, grad_fn=)
986 tensor(2.0143e-07, grad_fn=)
987 tensor(1.9854e-07, grad_fn=)
988 tensor(1.9565e-07, grad_fn=)
989 tensor(1.9260e-07, grad_fn=)
990 tensor(1.8995e-07, grad_fn=)
991 tensor(1.8728e-07, grad_fn=)
992 tensor(1.8464e-07, grad_fn=)
993 tensor(1.8188e-07, grad_fn=)
994 tensor(1.7924e-07, grad_fn=)
995 tensor(1.7669e-07, grad_fn=)
996 tensor(1.7435e-07, grad_fn=)
997 tensor(1.7181e-07, grad_fn=)
998 tensor(1.6931e-07, grad_fn=)
999 tensor(1.6700e-07, grad_fn=)
w = 1.9997280836105347
b = 0.0006181497010402381
y_pred = tensor([[7.9995]])
练习
用以下这些优化器替换SGD,得到训练结果并画出损失曲线图。

比如说:Adam的loss图:

相关文章:
用Pytorch实现线性回归模型
目录 回顾Pytorch实现步骤1. 准备数据2. 设计模型class LinearModel代码 3. 构造损失函数和优化器4. 训练过程5. 输出和测试完整代码 练习 回顾 前面已经学习过线性模型相关的内容,实现线性模型的过程并没有使用到Pytorch。 这节课主要是利用Pytorch实现线性模型。…...
WordPress模板层次与常用模板函数
首页: home.php index.php 文章页: single-{post_type}.php – 如果文章类型是videos(即视频),WordPress就会去查找single-videos.php(WordPress 3.0及以上版本支持) single.php index.php 页面: 自定义模板 – 在WordPre…...
HarmonyOS应用开发者高级认证试题库(鸿蒙)
目录 考试链接: 流程: 选择: 判断 单选 多选 考试链接: 华为开发者学堂华为开发者学堂https://developer.huawei.com/consumer/cn/training/dev-certification/a617e0d3bc144624864a04edb951f6c4 流程: 先进行…...
系分备考计算机网络传输介质、通信方式和交换方式
文章目录 1、概述2、传输介质3、网络通信4、网络交换5、总结 1、概述 计算机网路是系统分析师考试的常考知识点,本篇主要记录了知识点:网络传输介质、网络通信和数据交换方式等。 2、传输介质 网络的传输最常见的就是网线,也就是双绞线&…...
js原生面试总结
冒泡循环 var arr[2,1,3,4,9,7,6,8] // 外层循环代表循环次数 内层循环时每次的两两对比 少一次循环 for (let i 0; i < arr.length-1; i) {// 如果进入判断代表当前值大于下一个是需要进行冒泡排序的let booltruefor (let j 0; j < arr.length-1-i; j) {// 虽然…...
接口自动化测试框架设计
文章目录 接口测试的定义接口测试的意义接口测试的测试用例设计接口测试的测试用例设计方法postman主要功能请求体分类JSON数据类型postman内置参数postman变量全局变量环境变量 postman断言JSON提取器正则表达式提取器Cookie提取器postman加密接口签名 接口自动化测试基础getp…...
详解ISIS动态路由协议
华子目录 前言应用场景历史起源ISIS路由计算过程ISIS的地址结构ISIS路由器分类ISIS邻居关系的建立P2PMA ISIS中的DIS与OSPF中DR的对比链路状态信息的交互ISIS的最短路径优先算法(SPF)ISIS区域划分ISIS区域间路由访问原理ISIS与OSPF的不同ISIS与OSPF的术语…...
Linux操作系统----gdb调试工具(配实操图)
绪论 “不用滞留采花保存,只管往前走去,一路上百花自会盛开。 ——泰戈尔”。本章是Linux工具篇的最后一章。gdb调试工具是我们日常工作中需要掌握的一项重要技能我们需要基本的掌握release和debug的区别以及gdb的调试方法的指令。下一章我们将进入真正…...
去除GIT某个时间之前的提交日志
背景 有时git提交了太多有些较早之前的提交日志,不想在git log看到,想把他删除掉。 方法 大概思路是通过 git clone --depth 来克隆到指定提交的代码,此时再早之前的日志是没有的 然后提交到新仓库 #!/bin/bash ori_git"gityour.gi…...
4 python快速上手
计算机常识知识 1.Python代码运行方式2.进制2.1 进制转换 3. 计算机中的单位4.编码4.1 ascii编码4.2 gb-2312编码4.3 unicode4.4 utf-8编码4.5 Python相关的编码 总结 各位小伙伴想要博客相关资料的话关注公众号:chuanyeTry即可领取相关资料! 1.Python代…...
单元测试-spring-boot-starter-test+junit5
前言: 开发过程中经常需要写单元测试,记录一下单元测试spring-boot-starter-testjunit5的使用 引入内容: 引用jar包 <!-- SpringBoot测试类依赖 --> <dependency><groupId>org.springframework.boot</groupId><…...
CentOS 7上安装Anaconda 详细教程
目录 1. 下载Anaconda安装脚本2. 校验数据完整性(可选)3. 运行安装脚本4. 遵循安装指南5. 选择安装位置6. 初始化Anaconda7. 激活安装8. 测试安装9. 更新Anaconda10. 使用Anaconda 1. 下载Anaconda安装脚本 首先需要从Anaconda的官方网站下载最新的Anac…...
2023年全球软件架构师峰会(ArchSummit深圳站):核心内容与学习收获(附大会核心PPT下载)
本次峰会是一次重要的技术盛会,旨在为全球软件架构师提供一个交流和学习的平台。本次峰会聚焦于软件架构的最新趋势、最佳实践和技术创新,吸引了来自世界各地的软件架构师、技术专家和企业领袖。 在峰会中,与会者可以了解到数字化、AIGC、To…...
RT-Thread Studio学习(十六)定时器计数
RT-Thread Studio学习(十六)定时器计数 一、简介二、新建RT-Thread项目并使用外部时钟三、启用PWM输入捕获功能四、测试 一、简介 本文将基于STM32F407VET芯片介绍如何在RT-Thread Studio开发环境下使用定时器对输入脉冲进行计数。 硬件及开发环境如下…...
【linux进程间通信(一)】匿名管道和命名管道
💓博主CSDN主页:杭电码农-NEO💓 ⏩专栏分类:Linux从入门到精通⏪ 🚚代码仓库:NEO的学习日记🚚 🌹关注我🫵带你学更多操作系统知识 🔝🔝 进程间通信 1. 前言2. 进程间…...
第11章 jQuery
学习目标 了解什么是jQuery,能够说出jQuery的特点 掌握jQuery的下载和引入,能够下载jQuery并且能够使用两种方式引入jQuery 掌握jQuery的简单使用,能够使用jQuery实现简单的页面效果 熟悉什么是jQuery对象,能够说出jQuery对象与DOM对象的区别 掌握利用选择器获取元素的方法…...
leetcode:1736. 替换隐藏数字得到的最晚时间(python3解法)
难度:简单 给你一个字符串 time ,格式为 hh:mm(小时:分钟),其中某几位数字被隐藏(用 ? 表示)。 有效的时间为 00:00 到 23:59 之间的所有时间,包括 00:00 和 23:59 。 …...
MySQL存储函数与存储过程习题
创建表并插入数据: 字段名 数据类型 主键 外键 非空 唯一 自增 id INT 是 否 是 是 否 name VARCHAR(50) 否 否 是 否 否 glass VARCHAR(50) 否 否 是 否 否 sch 表内容 id name glass 1 xiaommg glass 1 2 xiaojun glass 2 1、创建一个可以统计表格内记录…...
基于 Hologres+Flink 的曹操出行实时数仓建设
本文整理自曹操出行实时计算负责人林震基于 HologresFlink 的曹操出行实时数仓建设的分享,内容主要分为以下六部分: 曹操出行业务背景介绍曹操出行业务痛点分析HologresFlink 构建企业级实时数仓曹操出行实时数仓实践曹操出行业务成果分析未来展望 一、曹…...
【Docker】实战多阶段构建 Laravel 镜像
作者主页: 正函数的个人主页 文章收录专栏: Docker 欢迎大家点赞 👍 收藏 ⭐ 加关注哦! 本节适用于 PHP 开发者阅读。Laravel 基于 8.x 版本,各个版本的文件结构可能会有差异,请根据实际自行修改。 准备 新…...
浏览器访问 AWS ECS 上部署的 Docker 容器(监听 80 端口)
✅ 一、ECS 服务配置 Dockerfile 确保监听 80 端口 EXPOSE 80 CMD ["nginx", "-g", "daemon off;"]或 EXPOSE 80 CMD ["python3", "-m", "http.server", "80"]任务定义(Task Definition&…...
【根据当天日期输出明天的日期(需对闰年做判定)。】2022-5-15
缘由根据当天日期输出明天的日期(需对闰年做判定)。日期类型结构体如下: struct data{ int year; int month; int day;};-编程语言-CSDN问答 struct mdata{ int year; int month; int day; }mdata; int 天数(int year, int month) {switch (month){case 1: case 3:…...
【kafka】Golang实现分布式Masscan任务调度系统
要求: 输出两个程序,一个命令行程序(命令行参数用flag)和一个服务端程序。 命令行程序支持通过命令行参数配置下发IP或IP段、端口、扫描带宽,然后将消息推送到kafka里面。 服务端程序: 从kafka消费者接收…...
8k长序列建模,蛋白质语言模型Prot42仅利用目标蛋白序列即可生成高亲和力结合剂
蛋白质结合剂(如抗体、抑制肽)在疾病诊断、成像分析及靶向药物递送等关键场景中发挥着不可替代的作用。传统上,高特异性蛋白质结合剂的开发高度依赖噬菌体展示、定向进化等实验技术,但这类方法普遍面临资源消耗巨大、研发周期冗长…...
UE5 学习系列(三)创建和移动物体
这篇博客是该系列的第三篇,是在之前两篇博客的基础上展开,主要介绍如何在操作界面中创建和拖动物体,这篇博客跟随的视频链接如下: B 站视频:s03-创建和移动物体 如果你不打算开之前的博客并且对UE5 比较熟的话按照以…...
什么是库存周转?如何用进销存系统提高库存周转率?
你可能听说过这样一句话: “利润不是赚出来的,是管出来的。” 尤其是在制造业、批发零售、电商这类“货堆成山”的行业,很多企业看着销售不错,账上却没钱、利润也不见了,一翻库存才发现: 一堆卖不动的旧货…...
postgresql|数据库|只读用户的创建和删除(备忘)
CREATE USER read_only WITH PASSWORD 密码 -- 连接到xxx数据库 \c xxx -- 授予对xxx数据库的只读权限 GRANT CONNECT ON DATABASE xxx TO read_only; GRANT USAGE ON SCHEMA public TO read_only; GRANT SELECT ON ALL TABLES IN SCHEMA public TO read_only; GRANT EXECUTE O…...
C# SqlSugar:依赖注入与仓储模式实践
C# SqlSugar:依赖注入与仓储模式实践 在 C# 的应用开发中,数据库操作是必不可少的环节。为了让数据访问层更加简洁、高效且易于维护,许多开发者会选择成熟的 ORM(对象关系映射)框架,SqlSugar 就是其中备受…...
CMake 从 GitHub 下载第三方库并使用
有时我们希望直接使用 GitHub 上的开源库,而不想手动下载、编译和安装。 可以利用 CMake 提供的 FetchContent 模块来实现自动下载、构建和链接第三方库。 FetchContent 命令官方文档✅ 示例代码 我们将以 fmt 这个流行的格式化库为例,演示如何: 使用 FetchContent 从 GitH…...
【学习笔记】深入理解Java虚拟机学习笔记——第4章 虚拟机性能监控,故障处理工具
第2章 虚拟机性能监控,故障处理工具 4.1 概述 略 4.2 基础故障处理工具 4.2.1 jps:虚拟机进程状况工具 命令:jps [options] [hostid] 功能:本地虚拟机进程显示进程ID(与ps相同),可同时显示主类&#x…...
