当前位置: 首页 > news >正文

Pytorch笔记之回归

文章目录

  • 前言
  • 一、导入库
  • 二、数据处理
  • 三、构建模型
  • 四、迭代训练
  • 五、结果预测
  • 总结


前言

以线性回归为例,记录Pytorch的基本使用方法。


一、导入库

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable # 定义求导变量
from torch import nn, optim # 定义网络模型和优化器

二、数据处理

将数据类型转为tensor,第一维度变为batch_size

# 构建数据
x = np.random.rand(100)
noise = np.random.normal(0, 0.01, x.shape)
y = 0.1 * x + 0.2 + noise
# 数据处理
x_data = torch.FloatTensor(x.reshape(-1, 1))
y_data = torch.FloatTensor(y.reshape(-1, 1))
inputs = Variable(x_data)
target = Variable(y_data)

三、构建模型

1、继承nn.Module,定义一个线性回归模型。在__init__中定义连接层,定义前向传播的方法
2、实例化模型,定义损失函数与优化器

# 继承模型
class LinearRegression(nn.Module):def __init__(self):super().__init__()self.fc = nn.Linear(1, 1)def forward(self, x):out = self.fc(x)return out
# 定义模型
print('模型参数')
model = LinearRegression()
mse_loss = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
for name, param in model.named_parameters():print('{}:{}'.format(name, param))

四、迭代训练

1、梯度清零:optimizer.zero_grad()
2、反向传播计算梯度值:loss.backward()
3、执行参数更新:optimizer.step()
循环迭代,定期输出损失值

print('损失值')
for i in range(1001):out = model.forward(inputs)loss = mse_loss(out, target)optimizer.zero_grad()loss.backward()optimizer.step()if i % 200 == 0:print(i, loss.item())

五、结果预测

绘制样本的散点图与预测值的折线图

print('结果预测')
y_pred = model(x_data)
plt.plot(x, y, 'b.')
plt.plot(x, y_pred.data.numpy(), 'r-')
plt.show()


总结

使用Pytorch进行训练主要的三步:
(1)数据处理:将数据维度转换为(batch, *),数据类型转换为可训练的tensor;
(2)构建模型:继承nn.Module,定义连接层与运算方法,实例化,定义损失函数与优化器;
(3)迭代训练:循环迭代,依次执行梯度清零、梯度计算、参数更新。

相关文章:

Pytorch笔记之回归

文章目录 前言一、导入库二、数据处理三、构建模型四、迭代训练五、结果预测总结 前言 以线性回归为例,记录Pytorch的基本使用方法。 一、导入库 import numpy as np import matplotlib.pyplot as plt import torch from torch.autograd import Variable # 定义求…...

哪个证券公司可以加杠杆,淘配网是您的杠杆综合网站!

在证券市场中,投资者经常寻求提高资金杠杆以获得更高的回报。杠杆交易可以让您在不必拥有等额本金的情况下,参与更多的交易活动。然而,为了进行杠杆交易,您需要找到一家证券公司或平台,可以为您提供这种服务。本文将介…...

万字解读|怎样激活 TDengine 最高性价比?

不知不觉间,TDengine 已经 6 岁多了。在这 6 年多的时间里,我们从零开始,在一行又一行代码的淬炼下,TDengine 从 1.6 走过 2.0,终于走到如今的 3.0 时代。 自 2022 年下旬发布以来,经过我们不断地打磨优化…...

【目标检测】大图包括标签切分,并转换成txt格式

前言 遥感图像比较大,通常需要切分成小块再进行训练,之前写过一篇关于大图裁切和拼接的文章【目标检测】图像裁剪/标签可视化/图像拼接处理脚本,不过当时的工作流是先将大图切分成小图,再在小图上进行标注,于是就不考…...

gitlab登录出现的Invalid login or password问题

前提 我是在一个项目里创建的gitlab账号,想在别的项目里登录或者官网登录发现怎么都登陆不上 原因 在GitLab中,有两种不同的账号类型:项目账号和个人账号(官网账号)。 项目账号:项目账号是在特定GitLab…...

git本地创建分支并推送到远程

1. 创建本地分支并切换到该分支 比如我创建dev分支。git checkout -b相当于把两条命令git branch 分支名、git checkout分支名合成一条,来实现一条命令新建分支切换分支。 git checkout -b dev 2. 将dev分支推送到远程 -u参数与--set-upstream这一串是一个意思&am…...

手机待办事项app哪个好?

手机是日常很多人随身携带的设备,手机除了拥有通讯功能外,还能帮助大家高效管理日常工作,借助手机上的待办事项提醒APP可以快速地帮助大家规划日常事务,提高工作的效率。 过去,我也曾经在寻找一款能够将工作任务清晰罗…...

容器运行elasticsearch安装ik分词非root权限安装报错问题

有些应用默认不允许root用户运行,来确保应用的安全性,这也会导致我们使用docker run后一些操作问题,用es安装ik分词器举例(es版本8.9.0,analysis-ik版本8.9.0) 1. 容器启动elasticsearch 如挂载方式&…...

UE4游戏客户端开发进阶学习指南

前言 两年多前写过一篇入门指南,教大家在短时间内快速入门UE4的使用,在知乎被很多人收藏了。如今鸡佬使用UE快三年了,是时候更新一下进阶版本的学习指南。本文对于读者的要求: 有一定的C基础已经入门UE,能够用蓝图和…...

javaee SpringMVC 乱码问题解决

方法一 在web.xml文件中注册过滤器 <!-- 注册过滤器 设置编码 --><filter><filter-name>CharacterEncodingFilter</filter-name><filter-class>org.springframework.web.filter.CharacterEncodingFilter</filter-class><init-param&…...

用ChatGPT做数据分析,提升10倍工作效率

目录 写报告分析框架报告框架指标体系设计 Excel 写报告 分析框架 拿到一个专题不知道怎么做&#xff1f;没关系&#xff0c;用ChatGPT列一下框架。 以上分析框架挺像那么回事&#xff0c;如果没思路的话&#xff0c;问问ChatGPT能起到找灵感的作用。 报告框架 报告的框架…...

【Pytorch笔记】4.梯度计算

深度之眼官方账号 - 01-04-mp4-计算图与动态图机制 前置知识&#xff1a;计算图 可以参考我的笔记&#xff1a; 【学习笔记】计算机视觉与深度学习(2.全连接神经网络) 计算图 以这棵计算图为例。这个计算图中&#xff0c;叶子节点为x和w。 import torchw torch.tensor([1.]…...

浏览器安装vue调试工具

下载扩展程序文件 下载链接&#xff1a;链接: 下载连接网盘地址&#xff0c; 提取码: 0u46&#xff0c;里面有两个crx,一个适用于vue2&#xff0c;一个适用于vue3&#xff0c;可根据vue版本选择不同的调试工具 crx安装扩展程序不成功&#xff0c;将文件改为rar文件然后解压 安装…...

C/C++学习 -- RSA算法

概述 RSA算法是一种广泛应用于数据加密与解密的非对称加密算法。它由三位数学家&#xff08;Rivest、Shamir和Adleman&#xff09;在1977年提出&#xff0c;因此得名。RSA算法的核心原理是基于大素数的数学问题的难解性&#xff0c;利用两个密钥来完成加密和解密操作。 特点 …...

基于若依ruoyi-nbcio支持flowable流程增加自定义业务表单(一)

因为需要支持自定义业务表单的相关流程&#xff0c;所以需要建立相应的关联表 1、首先先建表wf_custom_form -- ---------------------------- -- Table structure for wf_custom_form -- ---------------------------- DROP TABLE IF EXISTS wf_custom_form; CREATE TABLE wf…...

面试经典 150 题 1 —(数组 / 字符串)— 88. 合并两个有序数组

88. 合并两个有序数组 方法一&#xff1a; class Solution { public:void merge(vector<int>& nums1, int m, vector<int>& nums2, int n) {for(int i 0; i<n;i){nums1[mi] nums2[i];}sort(nums1.begin(),nums1.end());} };方法二&#xff1a; clas…...

【大数据 | 综合实践】大数据技术基础综合项目 - 基于GitHub API的数据采集与分析平台

&#x1f935;‍♂️ 个人主页: AI_magician &#x1f4e1;主页地址&#xff1a; 作者简介&#xff1a;CSDN内容合伙人&#xff0c;全栈领域优质创作者。 &#x1f468;‍&#x1f4bb;景愿&#xff1a;旨在于能和更多的热爱计算机的伙伴一起成长&#xff01;&#xff01;&…...

超高频RFID模具精细化生产管理方案

近二十年来&#xff0c;我国的模具行业经历了快速发展的阶段&#xff0c;然而&#xff0c;模具行业作为一个传统、复杂且竞争激烈的行业&#xff0c;企业往往以订单为导向&#xff0c;每个订单都需要进行新产品的开发&#xff0c;从客户需求分析、结构确定、报价、设计、物料准…...

FP-Growth算法全解析:理论基础与实战指导

目录 一、简介什么是频繁项集&#xff1f;什么是关联规则挖掘&#xff1f;FP-Growth算法与传统方法的对比Apriori算法Eclat算法 FP树&#xff1a;心脏部分 二、算法原理FP树的结构构建FP树第一步&#xff1a;扫描数据库并排序第二步&#xff1a;构建树 挖掘频繁项集优化&#x…...

Jmeter 分布式压测,你的系统能否承受高负载?

‍你可以使用 JMeter 来模拟高并发秒杀场景下的压力测试。这里有一个例子&#xff0c;它模拟了同时有 5000 个用户&#xff0c;循环 10 次的情况‍。 请求默认配置 token 配置 秒杀接口 ​结果分析 ​但是&#xff0c;实际企业中&#xff0c;这种压测方式根本不满足实际需求。下…...

【杂谈】-递归进化:人工智能的自我改进与监管挑战

递归进化&#xff1a;人工智能的自我改进与监管挑战 文章目录 递归进化&#xff1a;人工智能的自我改进与监管挑战1、自我改进型人工智能的崛起2、人工智能如何挑战人类监管&#xff1f;3、确保人工智能受控的策略4、人类在人工智能发展中的角色5、平衡自主性与控制力6、总结与…...

进程地址空间(比特课总结)

一、进程地址空间 1. 环境变量 1 &#xff09;⽤户级环境变量与系统级环境变量 全局属性&#xff1a;环境变量具有全局属性&#xff0c;会被⼦进程继承。例如当bash启动⼦进程时&#xff0c;环 境变量会⾃动传递给⼦进程。 本地变量限制&#xff1a;本地变量只在当前进程(ba…...

云启出海,智联未来|阿里云网络「企业出海」系列客户沙龙上海站圆满落地

借阿里云中企出海大会的东风&#xff0c;以**「云启出海&#xff0c;智联未来&#xff5c;打造安全可靠的出海云网络引擎」为主题的阿里云企业出海客户沙龙云网络&安全专场于5.28日下午在上海顺利举办&#xff0c;现场吸引了来自携程、小红书、米哈游、哔哩哔哩、波克城市、…...

对WWDC 2025 Keynote 内容的预测

借助我们以往对苹果公司发展路径的深入研究经验&#xff0c;以及大语言模型的分析能力&#xff0c;我们系统梳理了多年来苹果 WWDC 主题演讲的规律。在 WWDC 2025 即将揭幕之际&#xff0c;我们让 ChatGPT 对今年的 Keynote 内容进行了一个初步预测&#xff0c;聊作存档。等到明…...

【SQL学习笔记1】增删改查+多表连接全解析(内附SQL免费在线练习工具)

可以使用Sqliteviz这个网站免费编写sql语句&#xff0c;它能够让用户直接在浏览器内练习SQL的语法&#xff0c;不需要安装任何软件。 链接如下&#xff1a; sqliteviz 注意&#xff1a; 在转写SQL语法时&#xff0c;关键字之间有一个特定的顺序&#xff0c;这个顺序会影响到…...

linux 错误码总结

1,错误码的概念与作用 在Linux系统中,错误码是系统调用或库函数在执行失败时返回的特定数值,用于指示具体的错误类型。这些错误码通过全局变量errno来存储和传递,errno由操作系统维护,保存最近一次发生的错误信息。值得注意的是,errno的值在每次系统调用或函数调用失败时…...

相机从app启动流程

一、流程框架图 二、具体流程分析 1、得到cameralist和对应的静态信息 目录如下: 重点代码分析: 启动相机前,先要通过getCameraIdList获取camera的个数以及id,然后可以通过getCameraCharacteristics获取对应id camera的capabilities(静态信息)进行一些openCamera前的…...

3403. 从盒子中找出字典序最大的字符串 I

3403. 从盒子中找出字典序最大的字符串 I 题目链接&#xff1a;3403. 从盒子中找出字典序最大的字符串 I 代码如下&#xff1a; class Solution { public:string answerString(string word, int numFriends) {if (numFriends 1) {return word;}string res;for (int i 0;i &…...

ArcGIS Pro制作水平横向图例+多级标注

今天介绍下载ArcGIS Pro中如何设置水平横向图例。 之前我们介绍了ArcGIS的横向图例制作&#xff1a;ArcGIS横向、多列图例、顺序重排、符号居中、批量更改图例符号等等&#xff08;ArcGIS出图图例8大技巧&#xff09;&#xff0c;那这次我们看看ArcGIS Pro如何更加快捷的操作。…...

OPENCV形态学基础之二腐蚀

一.腐蚀的原理 (图1) 数学表达式&#xff1a;dst(x,y) erode(src(x,y)) min(x,y)src(xx,yy) 腐蚀也是图像形态学的基本功能之一&#xff0c;腐蚀跟膨胀属于反向操作&#xff0c;膨胀是把图像图像变大&#xff0c;而腐蚀就是把图像变小。腐蚀后的图像变小变暗淡。 腐蚀…...