线形回归与小批量梯度下降实例
1、准备数据集
import numpy as np
import matplotlib.pyplot as pltfrom torch.utils.data import DataLoader
from torch.utils.data import TensorDataset#########################################################################
#################准备若干个随机的x和y#####################################
#########################################################################
np.random.seed(100) #使用random.seed,设置一个固定的随机种子,
data_size = 150 # 数据集大小
x_range = 5 # x的范围
iteration_count = 100 # 迭代次数# np.random.rand 是 NumPy 库中的一个函数,用于生成一个给定形状的数组,
# 数组中的元素是从一个均匀分布的样本中抽取的,这个均匀分布是在半开区间 [0, 1) 上。
# 这意味着产生的随机数将大于等于0且小于1。#随机生成data_size个横坐标x,范围在0到x_range之间
x=x_range * np.random.rand(data_size,1)#生成带有噪音y数据,基本分布在y=2x+6的附近
y=2*x + 6 + np.random.randn(data_size,1)*0.3plt.scatter(x,y,marker='x',color='green')#########################################################################
#################将训练数据转为张量#######################################
#########################################################################
#将训练数据转为张量
tensorX = torch.from_numpy(x).float()
tensorY = torch.from_numpy(y).float()#使用TensorDataset,将tensorX和tensorY组成训练集
dataset = TensorDataset(tensorX,tensorY)#使用DataLoader,构造随机的小批量数据
dataloader=DataLoader(dataset,batch_size = 20, #每一个小批量的数据规模是20shuffle =True ) #随机打乱数据的顺序
print("dataloader len =%d" %(len(dataloader)))for index,(data,label) in enumerate(dataloader):print("index=%d num = %d"%(index,len(data)))
2、线性回归模型的训练思路
2.1 初始化参数
设置是模型参数:权重w
和偏置b
初始化为随机值
设置 requires_grad=True
,PyTorch 将记录这些张量的操作历史,用于后续的自动求导
2.2 循环训练(epoch)
epoch
变量用于控制整个训练过程的迭代轮数
在机器学习和深度学习中,“epoch” 是一个常用的术语,指的是在整个数据集上完整地运行一次(即正向传播和反向传播)训练算法的过程。
定义:一个 epoch 是指训练过程中,训练集中每个样本都被使用过一次来更新模型的权重。
训练过程:在训练一个模型时,通常会将数据集分成多个批次(batches)。每个批次包含一定数量的样本。一个 epoch 完成意味着所有批次都已经过模型处理。
迭代与epoch:在一个 epoch 内,模型可能会多次迭代,每次迭代处理一个批次的数据。因此,一个 epoch 包含多个迭代(iterations)。
目的:通过多个 epochs 的训练,模型可以逐渐学习数据集中的模式,从而提高其性能。
数量:训练一个模型所需的 epochs 数量取决于多种因素,包括数据集的大小、模型的复杂度以及问题的难度。有时可能只需要几个 epochs,而有时可能需要数百甚至数千个 epochs。
监控:在训练过程中,通常会监控每个 epoch 的性能指标(如损失函数的值或准确率),以评估模型的学习进度。
过拟合与欠拟合:如果训练过多的 epochs,模型可能会过拟合(即模型学习到了数据中的噪声而非潜在的模式),而训练不足的 epochs 则可能导致欠拟合(即模型未能捕捉到数据中的关键模式)。
2.3 数据加载
内层循环通过 dataloader
遍历训练数据集的小批量数据。dataloader
是一个数据加载器,通常由 DataLoader
类创建,用于批量加载数据。
2.4 前向传播
假设 tensorX
是当前批次的数据
tensorY
是对应的真实标签
使用当前参数 w
和 b
计算预测值 h=w*tensorX
+b。
2.5 计算损失
计算预测值 h
和真实值 tensorY
之间的均方误差(MSE),并保存到 loss
loss = torch.mean((h - tensorY) ** 2)
2.6 反向传播
调用 loss.backward()
进行反向传播,计算损失关于参数 w
和 b
的梯度
设置了 requires_grad=True
,PyTorch 将记录这些张量的操作历史并自动求导
2.7 更新参数
使用梯度下降算法更新参数 w
和 b
。学习率设置为0.01
w.data -= 0.01 * w.grad.data
b.data -= 0.01 * b.grad.data
沿着当前小批量计算的得到的梯度(导数)更新w和b
如果导数为0,则w、b保存不变
2.8 梯度清零
在每次迭代后,需要清空参数的梯度信息,以便下一次迭代计算
3、线性回归模型的实现
# 待送代的参数为w和b
w = torch.randn(1,requires_grad=True)
b = torch.randn(1,requires_grad=True)#进入模型的循环迭代
for epoch in range(1,iteration_count):#代表了整个训练数据集的迭代轮数# 在一个迭代轮次中,以小批量的方式,使用dataloader对数据# batch_index表示当前遍历的批次# data和label表示这个批次的训练数据和标记for batch_index,(data, label)in enumerate(dataloader):h = tensorX * w + b #计算当前直线的预测值,保存到h#计算预测值h和真实值y之间的均方误差,保存到loss中loss=torch.mean((h-tensorY)**2)#计算代价1oss关于参数w和b的偏导数,设置了 requires_grad=True,PyTorch 将记录这些张量的操作历史并自动求导loss.backward()#进行梯度下降,沿着梯度的反方向,更新w和b的值#沿着当前小批量计算的得到的梯度(导数)更新w和b#如果导数为0(Δw,Δb为0),则w、b保存不变w.data -=0.01 * w.grad.datab.data -=0.01 * b.grad.dataprint("epoch(%d) batch(%d) lossΔw,Δb,w,b, = %.3lf,%.3lf,%.3lf,%.3lf,%.3lf," %(epoch,batch_index,loss.item(),w.grad.data,b.grad.data,w.data,b.data))#清空张量w和b中的梯度信息,为下一次迭代做准备w.grad.zero_()b.grad.zero_()#每次迭代,都打印当前迭代的轮数epoch#数据的批次batch idx和loss损失值
相关文章:

线形回归与小批量梯度下降实例
1、准备数据集 import numpy as np import matplotlib.pyplot as pltfrom torch.utils.data import DataLoader from torch.utils.data import TensorDataset######################################################################### #################准备若干个随机的x和…...

SpringCloud微服务:基于Nacos组件,整合Dubbo框架
dubbo和fegin的差异 一、Feign与Dubbo概述 Feign是一个声明式的Web服务客户端,使得编写HTTP客户端变得更简单。通过简单的注解,Feign将自动生成HTTP请求,使得服务调用更加便捷。而Dubbo是一个高性能、轻量级的Java RPC框架,提供了…...

Golang 简要概述
文章目录 1. Golang 的学习方向2. Golang 的应用领域2.1 区块链的应用开发2.2 后台的服务应用2.3 云计算/云服务后台应用 1. Golang 的学习方向 Go 语言,我们可以简单的写成 Golang 2. Golang 的应用领域 2.1 区块链的应用开发 2.2 后台的服务应用 2.3 云计算/云服…...

web前端第三次作业---制作可提交的用户注册表
制作可提交的用户注册表: 代码: <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</tit…...
教育邮箱的魔力:免费获取Adobe和JetBrains软件
今天想和大家聊聊一个超级实用的话题——如何利用Edu教育邮箱来免费获取Photoshop等Adobe系列软件,以及JetBrains的各种开发工具。 Edu邮箱的价值 首先,Edu邮箱真的是个宝藏!如果你在学校或教育机构注册过,通常会获得一个这样的…...
sympy常用函数与错误笔记
文章目录 前言一、sympy基本函数介绍变量定义1. sp.Symbol("x") 或 sp.symbols("m n")2. sp.Function("y")3. func(x).diff(x, n) 定义方程与求解符号1. sp.Eq(lhs, rhs)2. 求解函数(*代表了常用且重要,其他部分作为拓展&…...
47_Lua文件IO操作
文件I/O(Input/Output)操作在Lua中用于与外部文件进行交互,包括读取文件中的数据和将数据写入文件。Lua提供了两种模式来进行文件操作:简单模式和完全模式。下面将详细介绍这两种模式的基本使用。 1.简单模式 1.1 简单模式介绍 简单模式提供了基本的文件操作功能,它主要…...
nginx-lua模块处理流程
一. 简述: nginx的模块化设计使得每一个http模块可以只专注于完成一个独立的,简单的功能。一个请求的完整处理过程可以由多个http模块共同协作完成,这种设计具有简单性,测试性,扩展性,灵活性。关于nginx 的…...

【大数据】机器学习-----最开始的引路
以下是关于机器学习的一些基本信息,包括基本术语、假设空间、归纳偏好、发展历程、应用现状和代码示例: 一、基本术语 样本(Sample): 也称为实例(Instance)或数据点(Data Point&…...
【前端】自学基础算法 -- 21.图的广度优先搜索
图的广度优先搜索 简介 图的广度优先搜索,沿着图的宽度遍历图的节点,先访问离起始节点最近的节点,然后逐渐向外扩展。 基本步骤: 选择一个起始节点作为当前节点。将当前节点加入队列。当队列不为空时,重复以下步骤…...

ChatGPT与Claude AI:两大生成式对话模型的比较分析
自ChatGPT推出以来,这款强大的AI聊天机器人迅速吸引了全球的关注。其出色的对话能力和多样化的应用场景,成为许多人初次体验基于大规模语言模型的潜力。然而,在这个快速发展的领域中,另一款AI也在悄然崭露头角,那就是由…...
前端开发:盒子模型、块元素
1.border边框 *{box-sizing:border-box; } //使所有边框不再撑大盒子模型 粗细 : border-width 样式 : border-style, 默认没边框 . solid 实线边框 dashed 虚线边框 dotted 点线边框 颜色 : border-color div { width : 200px ; height : 200px ; border : …...
升级 CentOS 7.x 系统内核到 4.4 版本
问题描述 在 CentOS 7.x 系统中,默认内核版本是 3.10.x,这个版本可能会带来一些与 Docker 和 Kubernetes 兼容性的问题,导致系统性能不稳定或功能异常。为了提高系统的稳定性和兼容性,建议升级到更高版本的内核,例如 …...

播放音频文件同步音频文本
播放音频同步音频文本 对应单个文本高亮显示 使用audio音频文件对应音频文本资源 音频文本内容(Json) [{"end": 4875,"index": 0,"speaker": 0,"start": 30,"text": "70号二啊,","tex…...

springboot使用Easy Excel导出列表数据为Excel
springboot使用Easy Excel导出列表数据为Excel Easy Excel官网:https://easyexcel.opensource.alibaba.com/docs/current/quickstart/write 主要记录一下引入时候的pom,直接引入会依赖冲突 解决方法: <!-- 引入Easy Excel的依赖 -->&l…...

day07_Spark SQL
文章目录 day07_Spark SQL课程笔记一、今日课程内容二、Spark SQL函数定义(掌握)1、窗口函数2、自定义函数背景2.1 回顾函数分类标准:SQL最开始是_内置函数&自定义函数_两种 2.2 自定义函数背景 3、Spark原生自定义UDF函数3.1 自定义函数流程&#x…...

高性能现代PHP全栈框架 Spiral
概述 Spiral Framework 诞生于现实世界的软件开发项目是一个现代 PHP 框架,旨在为更快、更清洁、更卓越的软件开发提供动力。 特性 高性能 由于其设计以及复杂精密的应用服务器,Spiral Framework框架在不影响代码质量以及与常用库的兼容性的情况下&a…...

LeetCode - #182 Swift 实现找出重复的电子邮件
网罗开发 (小红书、快手、视频号同名) 大家好,我是 展菲,目前在上市企业从事人工智能项目研发管理工作,平时热衷于分享各种编程领域的软硬技能知识以及前沿技术,包括iOS、前端、Harmony OS、Java、Python等…...
《解锁鸿蒙Next系统人工智能语音助手开发的关键步骤》
在当今数字化时代,鸿蒙Next系统与人工智能的融合为开发者带来了前所未有的机遇,开发一款人工智能语音助手应用更是备受关注。以下是在鸿蒙Next系统上开发人工智能语音助手应用的关键步骤: 环境搭建与权限申请 安装开发工具:首先需…...

【Linux网络编程】数据链路层 | MAC帧 | ARP协议
前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站 🌈个人主页: 南桥几晴秋 🌈C专栏: 南桥谈C 🌈C语言专栏: C语言学习系…...

SpringBoot-17-MyBatis动态SQL标签之常用标签
文章目录 1 代码1.1 实体User.java1.2 接口UserMapper.java1.3 映射UserMapper.xml1.3.1 标签if1.3.2 标签if和where1.3.3 标签choose和when和otherwise1.4 UserController.java2 常用动态SQL标签2.1 标签set2.1.1 UserMapper.java2.1.2 UserMapper.xml2.1.3 UserController.ja…...

C++_核心编程_多态案例二-制作饮品
#include <iostream> #include <string> using namespace std;/*制作饮品的大致流程为:煮水 - 冲泡 - 倒入杯中 - 加入辅料 利用多态技术实现本案例,提供抽象制作饮品基类,提供子类制作咖啡和茶叶*//*基类*/ class AbstractDr…...

【网络安全产品大调研系列】2. 体验漏洞扫描
前言 2023 年漏洞扫描服务市场规模预计为 3.06(十亿美元)。漏洞扫描服务市场行业预计将从 2024 年的 3.48(十亿美元)增长到 2032 年的 9.54(十亿美元)。预测期内漏洞扫描服务市场 CAGR(增长率&…...

el-switch文字内置
el-switch文字内置 效果 vue <div style"color:#ffffff;font-size:14px;float:left;margin-bottom:5px;margin-right:5px;">自动加载</div> <el-switch v-model"value" active-color"#3E99FB" inactive-color"#DCDFE6"…...

新能源汽车智慧充电桩管理方案:新能源充电桩散热问题及消防安全监管方案
随着新能源汽车的快速普及,充电桩作为核心配套设施,其安全性与可靠性备受关注。然而,在高温、高负荷运行环境下,充电桩的散热问题与消防安全隐患日益凸显,成为制约行业发展的关键瓶颈。 如何通过智慧化管理手段优化散…...
LLM基础1_语言模型如何处理文本
基于GitHub项目:https://github.com/datawhalechina/llms-from-scratch-cn 工具介绍 tiktoken:OpenAI开发的专业"分词器" torch:Facebook开发的强力计算引擎,相当于超级计算器 理解词嵌入:给词语画"…...
稳定币的深度剖析与展望
一、引言 在当今数字化浪潮席卷全球的时代,加密货币作为一种新兴的金融现象,正以前所未有的速度改变着我们对传统货币和金融体系的认知。然而,加密货币市场的高度波动性却成为了其广泛应用和普及的一大障碍。在这样的背景下,稳定…...
React---day11
14.4 react-redux第三方库 提供connect、thunk之类的函数 以获取一个banner数据为例子 store: 我们在使用异步的时候理应是要使用中间件的,但是configureStore 已经自动集成了 redux-thunk,注意action里面要返回函数 import { configureS…...

HashMap中的put方法执行流程(流程图)
1 put操作整体流程 HashMap 的 put 操作是其最核心的功能之一。在 JDK 1.8 及以后版本中,其主要逻辑封装在 putVal 这个内部方法中。整个过程大致如下: 初始判断与哈希计算: 首先,putVal 方法会检查当前的 table(也就…...
Java求职者面试指南:Spring、Spring Boot、Spring MVC与MyBatis技术解析
Java求职者面试指南:Spring、Spring Boot、Spring MVC与MyBatis技术解析 一、第一轮基础概念问题 1. Spring框架的核心容器是什么?它的作用是什么? Spring框架的核心容器是IoC(控制反转)容器。它的主要作用是管理对…...