当前位置: 首页 > news >正文

使用pytorch实现线性回归(很基础模型搭建详解)

使用pytorch实现线性回归

步骤:

        1.prepare dataset
        2.design model using Class 目的是为了前向传播forward,即计算y hat(预测值)
        3.Construct loss and optimizer (using pytorch API) 其中计算loss是为了进行反向传播,optimizer是为了更新梯度
        4.Train Cycle
import torch
x_data = torch.Tensor([[1.0],[2.0],[3.0]])
y_data = torch.Tensor([[2.0],[4.0],[6.0]])

1.nn.Module: 是所有神经网络单元(neural network modules)的基类

nn:neural network

2.构造函数__init__():是用来初始化对象的时候默认调用的函数

3.class torch.nn.Linear(in_features,out_features,bias=True) y = Ax+b

in_features   : size of each input sample
out_features : size of each output sample
# 实例化模型
class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__() # 调用父类的构造self.linear = torch.nn.Linear(1,1) # 实例化类,构造对象,包含了权重和偏置# Linear也是继承自module的,也能进行反向传播# nn:neural networkdef forward(self,x):y_pred = self.linear(x)return y_predmodel = LinearModel()

1.class torch.nn.MSELoss(size_average=True, reduce=True)

size_average:是否计算损失均值
reduce:

2.class torch.optim.SGD(model.parameters(),lr=<object object》,momentum=0,dampening=0,weight_decay=0,nesterov=False)

model.parameters(): 告诉优化器对哪些Tensor进行优化,检查所有成员
lr:learning rate 学习率

3.torch.nn.MSELoss也跟torch.nn.Module有关,参与计算图的构建,torch.optim.SGD与torch.nn.Module无关,不参与构建计算图。

# 定义损失函数
criterion = torch.nn.MSELoss(size_average=False)
# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)
# 训练
for epoch in range(100):y_pred = model(x_data) # 计算y hatloss = criterion(y_pred,y_data) # 计算lossprint(epoch,loss.item())# optimizer.zero_grad() # 梯度归零loss.backward() # 反向传播,计算梯度optimizer.step() # update 参数,即更新w和b的值optimizer.zero_grad() # 梯度归零print('w=',model.linear.weight.item())
print('b=',model.linear.bias.item())x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred',y_test.data)

1.第三行直接使用对象调用函数(可调用对象)y_pred = model(x_data):¶

Module实现了魔法函数__call__(),call()里面有一条语句是要调用forward(),因此新写类中需要重写forward()覆盖父类的forward()
call() 函数的作用四可以直接在对象后面加(),例如实例化的model对象,和实例化的linear对象
调用此方法:super(LinearModel, self).init() 实例化了父类中所有方法

2.每一次epoch的训练过程,总结就是

①前向传播,求y hat (输入的预测值)
②根据y_hat和y_label(y_data)计算loss
③反向传播 backward (计算梯度)
④根据梯度,更新参数
⑤梯度清零

3.optimizer.zero_grad()

因为grad在 反向传播 的过程中是累加的,也就是说上一次反向传播的结果会对下一次的反向传播的结果造成影响,则意味着每一次运行反向传播,梯度都会累加之前的梯度,所以一般在反向传播之前需要把梯度清零。

本次代码实现的是批量处理数据

相关文章:

使用pytorch实现线性回归(很基础模型搭建详解)

使用pytorch实现线性回归 步骤&#xff1a; 1.prepare dataset 2.design model using Class 目的是为了前向传播forward&#xff0c;即计算y hat&#xff08;预测值&#xff09; 3.Construct loss and optimizer (using pytorch API) 其中计算loss是为了进行反向传播&#xff0…...

【力扣100】【好题】322.零钱兑换 || 01背包完全背包

添加链接描述 思路&#xff1a; dp[j]数组表示的是在金额达到 j 的时候所需要的最小硬币数金额&#xff1a;背包容量&#xff0c;每个硬币的个数都为1&#xff1a;背包中物品的价值&#xff0c;硬币面额&#xff1a;物品重量dp[j]min(dp[j],dp[j-coin]1) class Solution:def …...

工单管理系统建设方案

1.1 系统概述 1.1.1 需求描述 1.1.2 需求分析 1.1.3 重难点分析 1.1.4 重难点解决措施 1.2 系统架构设计 1.2.1 系统架构图 1.2.2 关键技术 1.3 系统功能设计 1.3.1 工单创建 1.3.2 工单管理 1.3.3 工单处理 1.3.4 工单催办 1.3.5 工单归档 1.3.6 工单统计 软件项目全套资料获取…...

什么是农业四情监测设备?

【TH-Q2】智慧农业四情监测设备是一种高科技的农田监测工具&#xff0c;旨在实时监测和管理农田中的土壤墒情、作物生长、病虫害以及气象条件。具体来说&#xff0c;它主要包括以下组成部分&#xff1a; 气象站&#xff1a;用于监测气温、湿度、风速等气象数据&#xff0c;为农…...

Java面试题:请解释Java并发工具包中的主要组件及其应用场景,请描述一个使用Java并发框架(如Fork/Join框架)解决实际问题的编程实操问题

文章标题&#xff1a;《Java内存模型深入解析与多线程并发工具类应用》 引言&#xff1a; 在Java的世界里&#xff0c;掌握内存模型和多线程并发是高级开发者的必备技能。Java内存模型&#xff08;JMM&#xff09;和多线程并发工具包为开发者提供了强大的能力&#xff0c;同时…...

boot应用打包

1.创建项目 2.编写 3.native构建 报错&#xff1a; [WARNING] native:build goal is deprecated. Use native:compile-no-fork instead. [INFO] Found GraalVM installation from GRAALVM_HOME variable. [INFO] Executing: S:\Coding\graalvm-jdk-17_windows-x64_bin\graalv…...

探索数据可视化:Matplotlib 多图布局

多图布局 子视图 import numpy as np import matplotlib.pyplot as pltx np.linspace(0,2*np.pi)plt.figure(figsize(9,6))# 创建子视图 # subplot(2,1,1)表示将当前图形分割成 2 行 1 列的子图网格&#xff0c;并在第 1 个子图位置绘制图形 ax plt.subplot(2,1,1) ax.plot…...

springboot262基于spring boot的小型诊疗预约平台的设计与开发

小型诊疗预约平台 摘 要 现代经济快节奏发展以及不断完善升级的信息化技术&#xff0c;让传统数据信息的管理升级为软件存储&#xff0c;归纳&#xff0c;集中处理数据信息的管理方式。本小型诊疗预约平台就是在这样的大环境下诞生&#xff0c;其可以帮助管理者在短时间内处理…...

Java项目修改源码jar文件(无需反编译)

文章目录 应用场景实现方案实现原理注意事项 应用场景 在项目中用了第三方的jar包&#xff0c;但是jar包内某个类不符合项目业务需求&#xff0c;需要修改第三方jar包源码文件内容。 实现方案 首先我们尝试直接修改jar包源码文件内容时&#xff0c;页面上会提示文件是只读的&a…...

java使用BatchPoints批量写入Influxdb

前言 使用时序数据库influxdb时&#xff0c;我们经常需要写入大量的数据。而单单使用influxDB.write&#xff08;Point&#xff09;进行单条写入时&#xff0c;速度过慢&#xff0c;无法支撑时序数据大量写入的速度。 所以我们需要采用批量的方式进行存储&#xff0c;增加写入…...

Java 集合类的高级特性介绍

在 Java 编程中&#xff0c;了解集合类的高级特性对于编写高效和可维护的代码至关重要。以下是一些你应该知道的 Java 集合类的高级特性&#xff0c;以及简单的例子来说明它们的用法。 1. 迭代器&#xff08;Iterators&#xff09;和列表迭代器&#xff08;ListIterators&#…...

使用Docker搭建Caddy

使用Docker搭建Caddy&#xff0c;可以快速部署一个轻量级的、支持自动HTTPS的web服务器。下面将分别介绍使用Docker CLI和Docker Compose两种方式来搭建Caddy服务器&#xff0c;并给出配置文件示例以及参数解释。 使用Docker CLI搭建Caddy 首先&#xff0c;确保你的系统上已安…...

synchronized是重量级锁???

synchronized作为Java程序员最常用同步工具&#xff0c;很多人却对它的用法和实现原理一知半解&#xff0c;以至于还有不少人认为synchronized是重量级锁&#xff0c;性能较差&#xff0c;尽量少用。 但不可否认的是synchronized依然是并发首选工具&#xff0c;连volatile、CA…...

Linux之线程控制

目录 一、POSIX线程库 二、线程的创建 三、线程等待 四、线程终止 五、分离线程 六、线程ID&#xff1a;pthread_t 1、获取线程ID 2、pthread_t 七、线程局部存储&#xff1a;__thread 一、POSIX线程库 由于Linux下的线程并没有独立特有的结构&#xff0c;所以Linux并…...

Python实现线性查找算法

Python实现线性查找算法 以下是使用 Python 实现线性查找算法的示例代码&#xff1a; def linear_search(arr, target):"""线性查找算法:param arr: 要搜索的数组:param target: 目标值:return: 如果找到目标值&#xff0c;返回其索引&#xff1b;否则返回 -1…...

总结Redis的原理

一、为什么要使用Redis 缓解数据库访问压力mysql读请求进行磁盘I/O速度慢&#xff0c;给数据库加Redis缓存&#xff08;参考CPU缓存&#xff09;&#xff0c;将数据缓存在内存中&#xff0c;省略了I/O操作 二、Redis数据管理 2.1 redis数据的删除 定时删除惰性删除内存淘汰…...

计算机设计大赛 疲劳驾驶检测系统 python

文章目录 0 前言1 课题背景2 Dlib人脸识别2.1 简介2.2 Dlib优点2.3 相关代码2.4 人脸数据库2.5 人脸录入加识别效果 3 疲劳检测算法3.1 眼睛检测算法3.2 打哈欠检测算法3.3 点头检测算法 4 PyQt54.1 简介4.2相关界面代码 5 最后 0 前言 &#x1f525; 优质竞赛项目系列&#x…...

什么是智慧公厕?智慧公厕的应用价值有哪些?

在现代社会&#xff0c;城市的发展与人民生活质量息息相关。作为城市基础设施中的重要一环&#xff0c;公共厕所的建设及管理一直备受关注。智慧公厕作为一种公共厕所使用、运行、管理的综合应用解决方案&#xff0c;正逐渐在智慧城市的建设中崭露头角。那么&#xff0c;智慧公…...

VideoDubber时长可控的视频配音方法

本次分享由中国人民大学、微软亚洲研究院联合投稿于AAAI 2023的一篇专门为视频配音任务定制的机器翻译的工作《VideoDubber: Machine Translation with Speech-Aware Length Control for Video Dubbing》。这个工作将电影或电视节目中的原始语音翻译成目标语言。 论文地址&…...

中科数安|公司办公终端、电脑文件数据 \ 资料防泄密系统

#中科数安# 中科数安是一家专注于信息安全技术与产品研发的高新技术企业&#xff0c;其提供的公司办公终端、电脑文件数据及资料防泄密系统&#xff08;也称为终端数据防泄漏系统或简称DLP系统&#xff09;主要服务于企业对内部敏感信息的安全管理需求。 www.weaem.com 该系统…...

云原生核心技术 (7/12): K8s 核心概念白话解读(上):Pod 和 Deployment 究竟是什么?

大家好&#xff0c;欢迎来到《云原生核心技术》系列的第七篇&#xff01; 在上一篇&#xff0c;我们成功地使用 Minikube 或 kind 在自己的电脑上搭建起了一个迷你但功能完备的 Kubernetes 集群。现在&#xff0c;我们就像一个拥有了一块崭新数字土地的农场主&#xff0c;是时…...

C++实现分布式网络通信框架RPC(3)--rpc调用端

目录 一、前言 二、UserServiceRpc_Stub 三、 CallMethod方法的重写 头文件 实现 四、rpc调用端的调用 实现 五、 google::protobuf::RpcController *controller 头文件 实现 六、总结 一、前言 在前边的文章中&#xff0c;我们已经大致实现了rpc服务端的各项功能代…...

椭圆曲线密码学(ECC)

一、ECC算法概述 椭圆曲线密码学&#xff08;Elliptic Curve Cryptography&#xff09;是基于椭圆曲线数学理论的公钥密码系统&#xff0c;由Neal Koblitz和Victor Miller在1985年独立提出。相比RSA&#xff0c;ECC在相同安全强度下密钥更短&#xff08;256位ECC ≈ 3072位RSA…...

简易版抽奖活动的设计技术方案

1.前言 本技术方案旨在设计一套完整且可靠的抽奖活动逻辑,确保抽奖活动能够公平、公正、公开地进行,同时满足高并发访问、数据安全存储与高效处理等需求,为用户提供流畅的抽奖体验,助力业务顺利开展。本方案将涵盖抽奖活动的整体架构设计、核心流程逻辑、关键功能实现以及…...

React Native 导航系统实战(React Navigation)

导航系统实战&#xff08;React Navigation&#xff09; React Navigation 是 React Native 应用中最常用的导航库之一&#xff0c;它提供了多种导航模式&#xff0c;如堆栈导航&#xff08;Stack Navigator&#xff09;、标签导航&#xff08;Tab Navigator&#xff09;和抽屉…...

python如何将word的doc另存为docx

将 DOCX 文件另存为 DOCX 格式&#xff08;Python 实现&#xff09; 在 Python 中&#xff0c;你可以使用 python-docx 库来操作 Word 文档。不过需要注意的是&#xff0c;.doc 是旧的 Word 格式&#xff0c;而 .docx 是新的基于 XML 的格式。python-docx 只能处理 .docx 格式…...

Android第十三次面试总结(四大 组件基础)

Activity生命周期和四大启动模式详解 一、Activity 生命周期 Activity 的生命周期由一系列回调方法组成&#xff0c;用于管理其创建、可见性、焦点和销毁过程。以下是核心方法及其调用时机&#xff1a; ​onCreate()​​ ​调用时机​&#xff1a;Activity 首次创建时调用。​…...

免费PDF转图片工具

免费PDF转图片工具 一款简单易用的PDF转图片工具&#xff0c;可以将PDF文件快速转换为高质量PNG图片。无需安装复杂的软件&#xff0c;也不需要在线上传文件&#xff0c;保护您的隐私。 工具截图 主要特点 &#x1f680; 快速转换&#xff1a;本地转换&#xff0c;无需等待上…...

解决:Android studio 编译后报错\app\src\main\cpp\CMakeLists.txt‘ to exist

现象&#xff1a; android studio报错&#xff1a; [CXX1409] D:\GitLab\xxxxx\app.cxx\Debug\3f3w4y1i\arm64-v8a\android_gradle_build.json : expected buildFiles file ‘D:\GitLab\xxxxx\app\src\main\cpp\CMakeLists.txt’ to exist 解决&#xff1a; 不要动CMakeLists.…...

在 Spring Boot 项目里,MYSQL中json类型字段使用

前言&#xff1a; 因为程序特殊需求导致&#xff0c;需要mysql数据库存储json类型数据&#xff0c;因此记录一下使用流程 1.java实体中新增字段 private List<User> users 2.增加mybatis-plus注解 TableField(typeHandler FastjsonTypeHandler.class) private Lis…...