当前位置: 首页 > news >正文

pytorch实现单层线性回归模型

文章目录

    • 简述
      • 代码重构要点
    • 数学模型、运行结果
    • 数据构建与分批
    • 模型封装
    • 运行测试

简述

python使用 数值微分法 求梯度,实现单层线性回归-CSDN博客
python使用 计算图(forward与backward) 求梯度,实现单层线性回归-CSDN博客
数值微分求梯度、计算图求梯度,实现单层线性回归 模型速度差异及损失率比对-CSDN博客

上述文章都是使用python来实现求梯度的,是为了学习原理,实际使用上,pytorch实现了自动求导,原理也是(基于计算图的)链式求导,本文还就 “单层线性回归” 问题用pytorch实现。

代码重构要点

1.nn.Moudle

torch.nn.Module的继承、nn.Sequentialnn.Linear
torch.nn — PyTorch 2.4 documentation

对于nn.Sequential的理解可以看python使用 计算图(forward与backward) 求梯度,实现单层线性回归-CSDN博客一文代码的模型初始化与计算部分,如图:

在这里插入图片描述

nn.Sequential可以说是把图中标注的代码封装起来了,并且可以放多层。

2.torch.optim优化器

本例中使用随机梯度下降torch.optim.SGD()
torch.optim — PyTorch 2.4 documentation
SGD — PyTorch 2.4 documentation

3.数据构建与数据加载

data.TensorDatasetdata.DataLoader,之前为了实现数据分批,手动实现了data_iter,现在可以直接调用pytorch的data.DataLoader

对于data.DataLoader的参数num_workers,默认值为0,即在主线程中处理,但设置其它值时存在反而速度变慢的情况,以后再讨论。

数学模型、运行结果

y = X W + b y = XW + b y=XW+b

y为标量,X列数为2. 损失函数使用均方误差。

运行结果:

在这里插入图片描述

在这里插入图片描述

数据构建与分批

def build_data(weights, bias, num_examples):  x = torch.randn(num_examples, len(weights))  y = x.matmul(weights) + bias  # 给y加个噪声  y += torch.randn(1)  return x, y  def load_array(data_arrays, batch_size, num_workers=0, is_train=True):  """构造一个PyTorch数据迭代器"""  dataset = data.TensorDataset(*data_arrays)  return data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=is_train)

模型封装

class TorchLinearNet(torch.nn.Module):  def __init__(self):  super(TorchLinearNet, self).__init__()  model = nn.Sequential(Linear(in_features=2, out_features=1))  self.model = model  self.criterion = nn.MSELoss()  def predict(self, x):  return self.model(x)  def loss(self, y_predict, y):  return self.criterion(y_predict, y)

运行测试

if __name__ == '__main__':  start = time.perf_counter()  true_w1 = torch.rand(2, 1)  true_b1 = torch.rand(1)  x_train, y_train = build_data(true_w1, true_b1, 5000)  net = TorchLinearNet()  print(net)  init_loss = net.loss(net.predict(x_train), y_train)  loss_history = list()  loss_history.append(init_loss.item())  num_epochs = 3  batch_size = 50  learning_rate = 0.01  dataloader_workers = 6  data_loader = load_array((x_train, y_train), batch_size=batch_size, is_train=True)  optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)  for epoch in range(num_epochs):  # running_loss = 0.0  for x, y in data_loader:  y_pred = net.predict(x)  loss = net.loss(y_pred, y)  optimizer.zero_grad()  loss.backward()  optimizer.step()  # running_loss = running_loss + loss.item()  loss_history.append(loss.item())  end = time.perf_counter()  print(f"运行时间(不含绘图时间):{(end - start) * 1000}毫秒\n")  plt.title("pytorch实现单层线性回归模型", fontproperties="STSong")  plt.xlabel("epoch")  plt.ylabel("loss")  plt.plot(loss_history, linestyle='dotted')  plt.show()  print(f'初始损失值:{init_loss}')  print(f'最后一次损失值:{loss_history[-1]}\n')  print(f'正确参数: true_w1={true_w1}, true_b1={true_b1}')  print(f'预测参数:{net.model.state_dict()}')

相关文章:

pytorch实现单层线性回归模型

文章目录 简述代码重构要点 数学模型、运行结果数据构建与分批模型封装运行测试 简述 python使用 数值微分法 求梯度,实现单层线性回归-CSDN博客 python使用 计算图(forward与backward) 求梯度,实现单层线性回归-CSDN博客 数值微分…...

智能小家电能否利用亚马逊VC搭上跨境快车?——WAYLI威利跨境助力商家

智能小家电行业在全球化背景下,正迎来前所未有的发展机遇。亚马逊为品牌商和制造商提供的一站式服务平台,为智能小家电企业提供了搭乘跨境快车、拓展国际市场的绝佳机会。 首先,亚马逊VC平台能够帮助智能小家电企业简化与亚马逊的合作流程&am…...

顺丰科技25届秋季校园招聘常见问题答疑及校招网申测评笔试题型分析SHL题库Verify测评

Q:顺丰科技2025届校园招聘面向对象是? A:2025届应届毕业生,毕业时间段为2024年10月1日至2025年9月30日(不满足以上毕业时间的同学可以关注顺丰科技社会招聘或实习生招聘)。 Q:我可以投递几个岗…...

深入理解 Kibana 配置文件:一份详尽的指南

Kibana 是一个强大的数据可视化平台,它允许用户通过 Elasticsearch 轻松地探索和分析数据。Kibana 的配置文件 kibana.yml 是定制和优化 Kibana 行为的关键。在这篇博客中,我们将深入探讨 kibana.yml 文件中的各个配置项,并提供示例说明。 服…...

算法的学习笔记—链表中倒数第 K 个结点(牛客JZ22)

😀前言 在编程过程中,链表是一种常见的数据结构,它能够高效地进行插入和删除操作。然而,遍历链表并找到特定节点是一个典型的挑战,尤其是当我们需要找到链表中倒数第 K 个节点时。本文将详细介绍如何使用双指针技术来解…...

聊聊场景及场景测试

在我们进行测试过程中,有一种黑盒测试叫场景测试,我们完全是从用户的角度去理解系统,从而可以挖掘用户的隐含需求。 场景是指用户会使用这个系统来完成预定目标的所有情况的集合。 场景本身也代表了用户的需求,所以我们可以认为…...

Spring Web MVC入门(中)

1. 请求 访问不同的路径, 就是发送不同的请求. 在发送请求时, 可能会带⼀些参数, 所以学习Spring的请求, 主要 是学习如何传递参数到后端以及后端如何接收. 传递参数, 咱们主要是使⽤浏览器和Postman来模拟; 1.1 传递单个参数 接收单个参数,在Spring MV…...

Django后端架构开发:后台管理与会话技术详解

🌟 Django后端架构开发:后台管理与会话技术详解 🔹 后台管理:自定义模型类 Django的后台管理系统提供了强大的模型管理功能,你可以通过自定义模型类来控制模型在后台管理界面的显示和操作。自定义模型类通过继承admin…...

挑战Infiniband, 爆改Ethernet(2)

挑战Infiniband, 爆改Ethernet之物理层 前面说过UE为了挑战Infiniband在AI集群和HPC领域的优势地位,计划爆改以太网技术,以适应AI和HPC集群对高性能、可扩展网络的需求。正如UE联盟关于愿景的说明中宣称的:”提供一个完整的架构,通…...

Postman文件上传接口测试

接口介绍 返回示例 测试步骤 1.添加一个新请求,修改请求名,填写URL,选择请求方式 2.将剩下的media参数放在请求body里,选择form-data,选择key右边的类型为file类型,就会出现选择文件的按钮Select Files&a…...

stm32入门学习14-电源控制

有时候我们的程序中有些触发执行条件,有时这些触发频率很少,我们的程序就一直在循环,这样就很浪费电,我们可以通过PWR电源控制来实现低功耗模式,即只有在触发时才执行程序,其余时间可以关闭一些没必要的设备…...

[C++][opencv]基于opencv实现photoshop算法色相和饱和度调整

【测试环境】 vs2019 opencv4.8.0 【效果演示】 【核心实现代码】 HSL.hpp #ifndef OPENCV2_PS_HSL_HPP_ #define OPENCV2_PS_HSL_HPP_#include "opencv2/core.hpp" using namespace cv;namespace cv {enum HSL_COLOR {HSL_ALL,HSL_RED,HSL_YELLOW,HSL_GREEN,HS…...

Github 2024-08-16Java开源项目日报 Top10

根据Github Trendings的统计,今日(2024-08-16统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Java项目10TypeScript项目1Ruby项目1Apache Dubbo: 高性能的Java开源RPC框架 创建周期:4441 天开发语言:Java协议类型:Apache License 2.0St…...

AI学习记录 - torch 的 matmul和dot的关联,也就是点乘和点积的联系

有用大佬们点点赞 1、两个一维向量点积 ,求 词A 与 词A 之间的关联度 2、两个词向量之间求关联度,求 : 词A 与 词A 的关联度 5 词A 与 词B 的关联度 11 词B 与 词A 的关联度 11 词B 与 词B 的关联度 25 刚刚好和矩阵乘法符合: 3、什么是…...

leetcode 885. Spiral Matrix III

题目链接 You start at the cell (rStart, cStart) of an rows x cols grid facing east. The northwest corner is at the first row and column in the grid, and the southeast corner is at the last row and column. You will walk in a clockwise spiral shape to visi…...

mysql windows安装与远程连接配置

安装包在主页资源中 一、安装(此安装教程为“mysql-installer-community-5.7.41.0.msi”安装教程,安装到win10环境) 保持默认选项,点击”Next“。 点开第一行加号展开一路展开找到“MySQL Server 5,7,41 - X64”点击选中点击一下中间只想右侧的箭头看到…...

子网掩码是什么以及子网掩码相关计算

子网掩码 (Subnet Mask) 又称网络掩码 (Netmask),告知主机或路由设备,地址的哪一部分是网络号,包括子网的网络号部分,哪一部分是主机号部分。 子网掩码使用与IP地址相同的编址格式,即32 bit—4个8位组的32位长格式。…...

仿RabbitMQ实现消息队列

前言:本项目是仿照RabbitMQ并基于SpringBoot Mybatis SQLite3实现的消息队列,该项目实现了MQ的核心功能:生产者、消费者、中间人、发布、订阅等。 源码链接:仿Rabbit MQ实现消息队列 目录 前言:本项目是仿照Rabbi…...

SpringBoot教程(二十三) | SpringBoot实现分布式定时任务之xxl-job

SpringBoot教程(二十三) | SpringBoot实现分布式定时任务之xxl-job 简介一、前置条件:需要搭建调度中心1、先下载调度中心源码2、修改配置文件3、启动项目4、进行访问5、打包部署(上正式) 二、SpringBoot集成Xxl-Job1.…...

微前端架构的数据持久化策略与实践

微前端架构通过将一个大型前端应用拆分成多个小型、自治的子应用,提升了开发效率和应用的可维护性。然而,数据持久化作为应用的基础需求,在微前端架构中实现起来面临着一些挑战。本文将详细介绍在微前端架构下实现数据持久化的策略、技术和最…...

接口测试中缓存处理策略

在接口测试中,缓存处理策略是一个关键环节,直接影响测试结果的准确性和可靠性。合理的缓存处理策略能够确保测试环境的一致性,避免因缓存数据导致的测试偏差。以下是接口测试中常见的缓存处理策略及其详细说明: 一、缓存处理的核…...

JavaSec-RCE

简介 RCE(Remote Code Execution),可以分为:命令注入(Command Injection)、代码注入(Code Injection) 代码注入 1.漏洞场景:Groovy代码注入 Groovy是一种基于JVM的动态语言,语法简洁,支持闭包、动态类型和Java互操作性&#xff0c…...

FFmpeg 低延迟同屏方案

引言 在实时互动需求激增的当下,无论是在线教育中的师生同屏演示、远程办公的屏幕共享协作,还是游戏直播的画面实时传输,低延迟同屏已成为保障用户体验的核心指标。FFmpeg 作为一款功能强大的多媒体框架,凭借其灵活的编解码、数据…...

IGP(Interior Gateway Protocol,内部网关协议)

IGP(Interior Gateway Protocol,内部网关协议) 是一种用于在一个自治系统(AS)内部传递路由信息的路由协议,主要用于在一个组织或机构的内部网络中决定数据包的最佳路径。与用于自治系统之间通信的 EGP&…...

微信小程序 - 手机震动

一、界面 <button type"primary" bindtap"shortVibrate">短震动</button> <button type"primary" bindtap"longVibrate">长震动</button> 二、js逻辑代码 注&#xff1a;文档 https://developers.weixin.qq…...

【算法训练营Day07】字符串part1

文章目录 反转字符串反转字符串II替换数字 反转字符串 题目链接&#xff1a;344. 反转字符串 双指针法&#xff0c;两个指针的元素直接调转即可 class Solution {public void reverseString(char[] s) {int head 0;int end s.length - 1;while(head < end) {char temp …...

成都鼎讯硬核科技!雷达目标与干扰模拟器,以卓越性能制胜电磁频谱战

在现代战争中&#xff0c;电磁频谱已成为继陆、海、空、天之后的 “第五维战场”&#xff0c;雷达作为电磁频谱领域的关键装备&#xff0c;其干扰与抗干扰能力的较量&#xff0c;直接影响着战争的胜负走向。由成都鼎讯科技匠心打造的雷达目标与干扰模拟器&#xff0c;凭借数字射…...

GC1808高性能24位立体声音频ADC芯片解析

1. 芯片概述 GC1808是一款24位立体声音频模数转换器&#xff08;ADC&#xff09;&#xff0c;支持8kHz~96kHz采样率&#xff0c;集成Δ-Σ调制器、数字抗混叠滤波器和高通滤波器&#xff0c;适用于高保真音频采集场景。 2. 核心特性 高精度&#xff1a;24位分辨率&#xff0c…...

并发编程 - go版

1.并发编程基础概念 进程和线程 A. 进程是程序在操作系统中的一次执行过程&#xff0c;系统进行资源分配和调度的一个独立单位。B. 线程是进程的一个执行实体,是CPU调度和分派的基本单位,它是比进程更小的能独立运行的基本单位。C.一个进程可以创建和撤销多个线程;同一个进程中…...

uniapp 字符包含的相关方法

在uniapp中&#xff0c;如果你想检查一个字符串是否包含另一个子字符串&#xff0c;你可以使用JavaScript中的includes()方法或者indexOf()方法。这两种方法都可以达到目的&#xff0c;但它们在处理方式和返回值上有所不同。 使用includes()方法 includes()方法用于判断一个字…...