pytorch实现单层线性回归模型
文章目录
- 简述
- 代码重构要点
- 数学模型、运行结果
- 数据构建与分批
- 模型封装
- 运行测试
简述
python使用 数值微分法 求梯度,实现单层线性回归-CSDN博客
python使用 计算图(forward与backward) 求梯度,实现单层线性回归-CSDN博客
数值微分求梯度、计算图求梯度,实现单层线性回归 模型速度差异及损失率比对-CSDN博客上述文章都是使用python来实现求梯度的,是为了学习原理,实际使用上,pytorch实现了自动求导,原理也是(基于计算图的)链式求导,本文还就 “单层线性回归” 问题用pytorch实现。
代码重构要点
1.nn.Moudle
torch.nn.Module
的继承、nn.Sequential
、nn.Linear
:
torch.nn — PyTorch 2.4 documentation
对于nn.Sequential
的理解可以看python使用 计算图(forward与backward) 求梯度,实现单层线性回归-CSDN博客一文代码的模型初始化与计算部分,如图:
nn.Sequential
可以说是把图中标注的代码封装起来了,并且可以放多层。
2.torch.optim
优化器
本例中使用随机梯度下降torch.optim.SGD()
。
torch.optim — PyTorch 2.4 documentation
SGD — PyTorch 2.4 documentation
3.数据构建与数据加载
data.TensorDataset
、data.DataLoader
,之前为了实现数据分批,手动实现了data_iter
,现在可以直接调用pytorch的data.DataLoader
。
对于data.DataLoader
的参数num_workers
,默认值为0,即在主线程中处理,但设置其它值时存在反而速度变慢的情况,以后再讨论。
数学模型、运行结果
y = X W + b y = XW + b y=XW+b
y为标量,X列数为2. 损失函数使用均方误差。
运行结果:
数据构建与分批
def build_data(weights, bias, num_examples): x = torch.randn(num_examples, len(weights)) y = x.matmul(weights) + bias # 给y加个噪声 y += torch.randn(1) return x, y def load_array(data_arrays, batch_size, num_workers=0, is_train=True): """构造一个PyTorch数据迭代器""" dataset = data.TensorDataset(*data_arrays) return data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=is_train)
模型封装
class TorchLinearNet(torch.nn.Module): def __init__(self): super(TorchLinearNet, self).__init__() model = nn.Sequential(Linear(in_features=2, out_features=1)) self.model = model self.criterion = nn.MSELoss() def predict(self, x): return self.model(x) def loss(self, y_predict, y): return self.criterion(y_predict, y)
运行测试
if __name__ == '__main__': start = time.perf_counter() true_w1 = torch.rand(2, 1) true_b1 = torch.rand(1) x_train, y_train = build_data(true_w1, true_b1, 5000) net = TorchLinearNet() print(net) init_loss = net.loss(net.predict(x_train), y_train) loss_history = list() loss_history.append(init_loss.item()) num_epochs = 3 batch_size = 50 learning_rate = 0.01 dataloader_workers = 6 data_loader = load_array((x_train, y_train), batch_size=batch_size, is_train=True) optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate) for epoch in range(num_epochs): # running_loss = 0.0 for x, y in data_loader: y_pred = net.predict(x) loss = net.loss(y_pred, y) optimizer.zero_grad() loss.backward() optimizer.step() # running_loss = running_loss + loss.item() loss_history.append(loss.item()) end = time.perf_counter() print(f"运行时间(不含绘图时间):{(end - start) * 1000}毫秒\n") plt.title("pytorch实现单层线性回归模型", fontproperties="STSong") plt.xlabel("epoch") plt.ylabel("loss") plt.plot(loss_history, linestyle='dotted') plt.show() print(f'初始损失值:{init_loss}') print(f'最后一次损失值:{loss_history[-1]}\n') print(f'正确参数: true_w1={true_w1}, true_b1={true_b1}') print(f'预测参数:{net.model.state_dict()}')
相关文章:

pytorch实现单层线性回归模型
文章目录 简述代码重构要点 数学模型、运行结果数据构建与分批模型封装运行测试 简述 python使用 数值微分法 求梯度,实现单层线性回归-CSDN博客 python使用 计算图(forward与backward) 求梯度,实现单层线性回归-CSDN博客 数值微分…...

智能小家电能否利用亚马逊VC搭上跨境快车?——WAYLI威利跨境助力商家
智能小家电行业在全球化背景下,正迎来前所未有的发展机遇。亚马逊为品牌商和制造商提供的一站式服务平台,为智能小家电企业提供了搭乘跨境快车、拓展国际市场的绝佳机会。 首先,亚马逊VC平台能够帮助智能小家电企业简化与亚马逊的合作流程&am…...

顺丰科技25届秋季校园招聘常见问题答疑及校招网申测评笔试题型分析SHL题库Verify测评
Q:顺丰科技2025届校园招聘面向对象是? A:2025届应届毕业生,毕业时间段为2024年10月1日至2025年9月30日(不满足以上毕业时间的同学可以关注顺丰科技社会招聘或实习生招聘)。 Q:我可以投递几个岗…...
深入理解 Kibana 配置文件:一份详尽的指南
Kibana 是一个强大的数据可视化平台,它允许用户通过 Elasticsearch 轻松地探索和分析数据。Kibana 的配置文件 kibana.yml 是定制和优化 Kibana 行为的关键。在这篇博客中,我们将深入探讨 kibana.yml 文件中的各个配置项,并提供示例说明。 服…...

算法的学习笔记—链表中倒数第 K 个结点(牛客JZ22)
😀前言 在编程过程中,链表是一种常见的数据结构,它能够高效地进行插入和删除操作。然而,遍历链表并找到特定节点是一个典型的挑战,尤其是当我们需要找到链表中倒数第 K 个节点时。本文将详细介绍如何使用双指针技术来解…...

聊聊场景及场景测试
在我们进行测试过程中,有一种黑盒测试叫场景测试,我们完全是从用户的角度去理解系统,从而可以挖掘用户的隐含需求。 场景是指用户会使用这个系统来完成预定目标的所有情况的集合。 场景本身也代表了用户的需求,所以我们可以认为…...

Spring Web MVC入门(中)
1. 请求 访问不同的路径, 就是发送不同的请求. 在发送请求时, 可能会带⼀些参数, 所以学习Spring的请求, 主要 是学习如何传递参数到后端以及后端如何接收. 传递参数, 咱们主要是使⽤浏览器和Postman来模拟; 1.1 传递单个参数 接收单个参数,在Spring MV…...
Django后端架构开发:后台管理与会话技术详解
🌟 Django后端架构开发:后台管理与会话技术详解 🔹 后台管理:自定义模型类 Django的后台管理系统提供了强大的模型管理功能,你可以通过自定义模型类来控制模型在后台管理界面的显示和操作。自定义模型类通过继承admin…...
挑战Infiniband, 爆改Ethernet(2)
挑战Infiniband, 爆改Ethernet之物理层 前面说过UE为了挑战Infiniband在AI集群和HPC领域的优势地位,计划爆改以太网技术,以适应AI和HPC集群对高性能、可扩展网络的需求。正如UE联盟关于愿景的说明中宣称的:”提供一个完整的架构,通…...

Postman文件上传接口测试
接口介绍 返回示例 测试步骤 1.添加一个新请求,修改请求名,填写URL,选择请求方式 2.将剩下的media参数放在请求body里,选择form-data,选择key右边的类型为file类型,就会出现选择文件的按钮Select Files&a…...

stm32入门学习14-电源控制
有时候我们的程序中有些触发执行条件,有时这些触发频率很少,我们的程序就一直在循环,这样就很浪费电,我们可以通过PWR电源控制来实现低功耗模式,即只有在触发时才执行程序,其余时间可以关闭一些没必要的设备…...

[C++][opencv]基于opencv实现photoshop算法色相和饱和度调整
【测试环境】 vs2019 opencv4.8.0 【效果演示】 【核心实现代码】 HSL.hpp #ifndef OPENCV2_PS_HSL_HPP_ #define OPENCV2_PS_HSL_HPP_#include "opencv2/core.hpp" using namespace cv;namespace cv {enum HSL_COLOR {HSL_ALL,HSL_RED,HSL_YELLOW,HSL_GREEN,HS…...
Github 2024-08-16Java开源项目日报 Top10
根据Github Trendings的统计,今日(2024-08-16统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Java项目10TypeScript项目1Ruby项目1Apache Dubbo: 高性能的Java开源RPC框架 创建周期:4441 天开发语言:Java协议类型:Apache License 2.0St…...

AI学习记录 - torch 的 matmul和dot的关联,也就是点乘和点积的联系
有用大佬们点点赞 1、两个一维向量点积 ,求 词A 与 词A 之间的关联度 2、两个词向量之间求关联度,求 : 词A 与 词A 的关联度 5 词A 与 词B 的关联度 11 词B 与 词A 的关联度 11 词B 与 词B 的关联度 25 刚刚好和矩阵乘法符合: 3、什么是…...

leetcode 885. Spiral Matrix III
题目链接 You start at the cell (rStart, cStart) of an rows x cols grid facing east. The northwest corner is at the first row and column in the grid, and the southeast corner is at the last row and column. You will walk in a clockwise spiral shape to visi…...

mysql windows安装与远程连接配置
安装包在主页资源中 一、安装(此安装教程为“mysql-installer-community-5.7.41.0.msi”安装教程,安装到win10环境) 保持默认选项,点击”Next“。 点开第一行加号展开一路展开找到“MySQL Server 5,7,41 - X64”点击选中点击一下中间只想右侧的箭头看到…...

子网掩码是什么以及子网掩码相关计算
子网掩码 (Subnet Mask) 又称网络掩码 (Netmask),告知主机或路由设备,地址的哪一部分是网络号,包括子网的网络号部分,哪一部分是主机号部分。 子网掩码使用与IP地址相同的编址格式,即32 bit—4个8位组的32位长格式。…...

仿RabbitMQ实现消息队列
前言:本项目是仿照RabbitMQ并基于SpringBoot Mybatis SQLite3实现的消息队列,该项目实现了MQ的核心功能:生产者、消费者、中间人、发布、订阅等。 源码链接:仿Rabbit MQ实现消息队列 目录 前言:本项目是仿照Rabbi…...

SpringBoot教程(二十三) | SpringBoot实现分布式定时任务之xxl-job
SpringBoot教程(二十三) | SpringBoot实现分布式定时任务之xxl-job 简介一、前置条件:需要搭建调度中心1、先下载调度中心源码2、修改配置文件3、启动项目4、进行访问5、打包部署(上正式) 二、SpringBoot集成Xxl-Job1.…...
微前端架构的数据持久化策略与实践
微前端架构通过将一个大型前端应用拆分成多个小型、自治的子应用,提升了开发效率和应用的可维护性。然而,数据持久化作为应用的基础需求,在微前端架构中实现起来面临着一些挑战。本文将详细介绍在微前端架构下实现数据持久化的策略、技术和最…...
KubeSphere 容器平台高可用:环境搭建与可视化操作指南
Linux_k8s篇 欢迎来到Linux的世界,看笔记好好学多敲多打,每个人都是大神! 题目:KubeSphere 容器平台高可用:环境搭建与可视化操作指南 版本号: 1.0,0 作者: 老王要学习 日期: 2025.06.05 适用环境: Ubuntu22 文档说…...

【大模型RAG】拍照搜题技术架构速览:三层管道、两级检索、兜底大模型
摘要 拍照搜题系统采用“三层管道(多模态 OCR → 语义检索 → 答案渲染)、两级检索(倒排 BM25 向量 HNSW)并以大语言模型兜底”的整体框架: 多模态 OCR 层 将题目图片经过超分、去噪、倾斜校正后,分别用…...

【Python】 -- 趣味代码 - 小恐龙游戏
文章目录 文章目录 00 小恐龙游戏程序设计框架代码结构和功能游戏流程总结01 小恐龙游戏程序设计02 百度网盘地址00 小恐龙游戏程序设计框架 这段代码是一个基于 Pygame 的简易跑酷游戏的完整实现,玩家控制一个角色(龙)躲避障碍物(仙人掌和乌鸦)。以下是代码的详细介绍:…...

STM32标准库-DMA直接存储器存取
文章目录 一、DMA1.1简介1.2存储器映像1.3DMA框图1.4DMA基本结构1.5DMA请求1.6数据宽度与对齐1.7数据转运DMA1.8ADC扫描模式DMA 二、数据转运DMA2.1接线图2.2代码2.3相关API 一、DMA 1.1简介 DMA(Direct Memory Access)直接存储器存取 DMA可以提供外设…...

从零开始打造 OpenSTLinux 6.6 Yocto 系统(基于STM32CubeMX)(九)
设备树移植 和uboot设备树修改的内容同步到kernel将设备树stm32mp157d-stm32mp157daa1-mx.dts复制到内核源码目录下 源码修改及编译 修改arch/arm/boot/dts/st/Makefile,新增设备树编译 stm32mp157f-ev1-m4-examples.dtb \stm32mp157d-stm32mp157daa1-mx.dtb修改…...

(转)什么是DockerCompose?它有什么作用?
一、什么是DockerCompose? DockerCompose可以基于Compose文件帮我们快速的部署分布式应用,而无需手动一个个创建和运行容器。 Compose文件是一个文本文件,通过指令定义集群中的每个容器如何运行。 DockerCompose就是把DockerFile转换成指令去运行。 …...
Rapidio门铃消息FIFO溢出机制
关于RapidIO门铃消息FIFO的溢出机制及其与中断抖动的关系,以下是深入解析: 门铃FIFO溢出的本质 在RapidIO系统中,门铃消息FIFO是硬件控制器内部的缓冲区,用于临时存储接收到的门铃消息(Doorbell Message)。…...
Java多线程实现之Thread类深度解析
Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...
libfmt: 现代C++的格式化工具库介绍与酷炫功能
libfmt: 现代C的格式化工具库介绍与酷炫功能 libfmt 是一个开源的C格式化库,提供了高效、安全的文本格式化功能,是C20中引入的std::format的基础实现。它比传统的printf和iostream更安全、更灵活、性能更好。 基本介绍 主要特点 类型安全:…...
Python 高效图像帧提取与视频编码:实战指南
Python 高效图像帧提取与视频编码:实战指南 在音视频处理领域,图像帧提取与视频编码是基础但极具挑战性的任务。Python 结合强大的第三方库(如 OpenCV、FFmpeg、PyAV),可以高效处理视频流,实现快速帧提取、压缩编码等关键功能。本文将深入介绍如何优化这些流程,提高处理…...