线形回归与小批量梯度下降实例
1、准备数据集
import numpy as np
import matplotlib.pyplot as pltfrom torch.utils.data import DataLoader
from torch.utils.data import TensorDataset#########################################################################
#################准备若干个随机的x和y#####################################
#########################################################################
np.random.seed(100) #使用random.seed,设置一个固定的随机种子,
data_size = 150 # 数据集大小
x_range = 5 # x的范围
iteration_count = 100 # 迭代次数# np.random.rand 是 NumPy 库中的一个函数,用于生成一个给定形状的数组,
# 数组中的元素是从一个均匀分布的样本中抽取的,这个均匀分布是在半开区间 [0, 1) 上。
# 这意味着产生的随机数将大于等于0且小于1。#随机生成data_size个横坐标x,范围在0到x_range之间
x=x_range * np.random.rand(data_size,1)#生成带有噪音y数据,基本分布在y=2x+6的附近
y=2*x + 6 + np.random.randn(data_size,1)*0.3plt.scatter(x,y,marker='x',color='green')#########################################################################
#################将训练数据转为张量#######################################
#########################################################################
#将训练数据转为张量
tensorX = torch.from_numpy(x).float()
tensorY = torch.from_numpy(y).float()#使用TensorDataset,将tensorX和tensorY组成训练集
dataset = TensorDataset(tensorX,tensorY)#使用DataLoader,构造随机的小批量数据
dataloader=DataLoader(dataset,batch_size = 20, #每一个小批量的数据规模是20shuffle =True ) #随机打乱数据的顺序
print("dataloader len =%d" %(len(dataloader)))for index,(data,label) in enumerate(dataloader):print("index=%d num = %d"%(index,len(data)))
2、线性回归模型的训练思路
2.1 初始化参数
设置是模型参数:权重w和偏置b
初始化为随机值
设置 requires_grad=True,PyTorch 将记录这些张量的操作历史,用于后续的自动求导
2.2 循环训练(epoch)
epoch 变量用于控制整个训练过程的迭代轮数
在机器学习和深度学习中,“epoch” 是一个常用的术语,指的是在整个数据集上完整地运行一次(即正向传播和反向传播)训练算法的过程。
定义:一个 epoch 是指训练过程中,训练集中每个样本都被使用过一次来更新模型的权重。
训练过程:在训练一个模型时,通常会将数据集分成多个批次(batches)。每个批次包含一定数量的样本。一个 epoch 完成意味着所有批次都已经过模型处理。
迭代与epoch:在一个 epoch 内,模型可能会多次迭代,每次迭代处理一个批次的数据。因此,一个 epoch 包含多个迭代(iterations)。
目的:通过多个 epochs 的训练,模型可以逐渐学习数据集中的模式,从而提高其性能。
数量:训练一个模型所需的 epochs 数量取决于多种因素,包括数据集的大小、模型的复杂度以及问题的难度。有时可能只需要几个 epochs,而有时可能需要数百甚至数千个 epochs。
监控:在训练过程中,通常会监控每个 epoch 的性能指标(如损失函数的值或准确率),以评估模型的学习进度。
过拟合与欠拟合:如果训练过多的 epochs,模型可能会过拟合(即模型学习到了数据中的噪声而非潜在的模式),而训练不足的 epochs 则可能导致欠拟合(即模型未能捕捉到数据中的关键模式)。
2.3 数据加载
内层循环通过 dataloader 遍历训练数据集的小批量数据。dataloader 是一个数据加载器,通常由 DataLoader 类创建,用于批量加载数据。
2.4 前向传播
假设 tensorX 是当前批次的数据
tensorY 是对应的真实标签
使用当前参数 w 和 b 计算预测值 h=w*tensorX +b。
2.5 计算损失
计算预测值 h 和真实值 tensorY 之间的均方误差(MSE),并保存到 loss
loss = torch.mean((h - tensorY) ** 2)
2.6 反向传播
调用 loss.backward() 进行反向传播,计算损失关于参数 w 和 b 的梯度
设置了 requires_grad=True,PyTorch 将记录这些张量的操作历史并自动求导
2.7 更新参数
使用梯度下降算法更新参数 w 和 b。学习率设置为0.01
w.data -= 0.01 * w.grad.data
b.data -= 0.01 * b.grad.data
沿着当前小批量计算的得到的梯度(导数)更新w和b
如果导数为0,则w、b保存不变
2.8 梯度清零
在每次迭代后,需要清空参数的梯度信息,以便下一次迭代计算
3、线性回归模型的实现
# 待送代的参数为w和b
w = torch.randn(1,requires_grad=True)
b = torch.randn(1,requires_grad=True)#进入模型的循环迭代
for epoch in range(1,iteration_count):#代表了整个训练数据集的迭代轮数# 在一个迭代轮次中,以小批量的方式,使用dataloader对数据# batch_index表示当前遍历的批次# data和label表示这个批次的训练数据和标记for batch_index,(data, label)in enumerate(dataloader):h = tensorX * w + b #计算当前直线的预测值,保存到h#计算预测值h和真实值y之间的均方误差,保存到loss中loss=torch.mean((h-tensorY)**2)#计算代价1oss关于参数w和b的偏导数,设置了 requires_grad=True,PyTorch 将记录这些张量的操作历史并自动求导loss.backward()#进行梯度下降,沿着梯度的反方向,更新w和b的值#沿着当前小批量计算的得到的梯度(导数)更新w和b#如果导数为0(Δw,Δb为0),则w、b保存不变w.data -=0.01 * w.grad.datab.data -=0.01 * b.grad.dataprint("epoch(%d) batch(%d) lossΔw,Δb,w,b, = %.3lf,%.3lf,%.3lf,%.3lf,%.3lf," %(epoch,batch_index,loss.item(),w.grad.data,b.grad.data,w.data,b.data))#清空张量w和b中的梯度信息,为下一次迭代做准备w.grad.zero_()b.grad.zero_()#每次迭代,都打印当前迭代的轮数epoch#数据的批次batch idx和loss损失值

相关文章:
线形回归与小批量梯度下降实例
1、准备数据集 import numpy as np import matplotlib.pyplot as pltfrom torch.utils.data import DataLoader from torch.utils.data import TensorDataset######################################################################### #################准备若干个随机的x和…...
SpringCloud微服务:基于Nacos组件,整合Dubbo框架
dubbo和fegin的差异 一、Feign与Dubbo概述 Feign是一个声明式的Web服务客户端,使得编写HTTP客户端变得更简单。通过简单的注解,Feign将自动生成HTTP请求,使得服务调用更加便捷。而Dubbo是一个高性能、轻量级的Java RPC框架,提供了…...
Golang 简要概述
文章目录 1. Golang 的学习方向2. Golang 的应用领域2.1 区块链的应用开发2.2 后台的服务应用2.3 云计算/云服务后台应用 1. Golang 的学习方向 Go 语言,我们可以简单的写成 Golang 2. Golang 的应用领域 2.1 区块链的应用开发 2.2 后台的服务应用 2.3 云计算/云服…...
web前端第三次作业---制作可提交的用户注册表
制作可提交的用户注册表: 代码: <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</tit…...
教育邮箱的魔力:免费获取Adobe和JetBrains软件
今天想和大家聊聊一个超级实用的话题——如何利用Edu教育邮箱来免费获取Photoshop等Adobe系列软件,以及JetBrains的各种开发工具。 Edu邮箱的价值 首先,Edu邮箱真的是个宝藏!如果你在学校或教育机构注册过,通常会获得一个这样的…...
sympy常用函数与错误笔记
文章目录 前言一、sympy基本函数介绍变量定义1. sp.Symbol("x") 或 sp.symbols("m n")2. sp.Function("y")3. func(x).diff(x, n) 定义方程与求解符号1. sp.Eq(lhs, rhs)2. 求解函数(*代表了常用且重要,其他部分作为拓展&…...
47_Lua文件IO操作
文件I/O(Input/Output)操作在Lua中用于与外部文件进行交互,包括读取文件中的数据和将数据写入文件。Lua提供了两种模式来进行文件操作:简单模式和完全模式。下面将详细介绍这两种模式的基本使用。 1.简单模式 1.1 简单模式介绍 简单模式提供了基本的文件操作功能,它主要…...
nginx-lua模块处理流程
一. 简述: nginx的模块化设计使得每一个http模块可以只专注于完成一个独立的,简单的功能。一个请求的完整处理过程可以由多个http模块共同协作完成,这种设计具有简单性,测试性,扩展性,灵活性。关于nginx 的…...
【大数据】机器学习-----最开始的引路
以下是关于机器学习的一些基本信息,包括基本术语、假设空间、归纳偏好、发展历程、应用现状和代码示例: 一、基本术语 样本(Sample): 也称为实例(Instance)或数据点(Data Point&…...
【前端】自学基础算法 -- 21.图的广度优先搜索
图的广度优先搜索 简介 图的广度优先搜索,沿着图的宽度遍历图的节点,先访问离起始节点最近的节点,然后逐渐向外扩展。 基本步骤: 选择一个起始节点作为当前节点。将当前节点加入队列。当队列不为空时,重复以下步骤…...
ChatGPT与Claude AI:两大生成式对话模型的比较分析
自ChatGPT推出以来,这款强大的AI聊天机器人迅速吸引了全球的关注。其出色的对话能力和多样化的应用场景,成为许多人初次体验基于大规模语言模型的潜力。然而,在这个快速发展的领域中,另一款AI也在悄然崭露头角,那就是由…...
前端开发:盒子模型、块元素
1.border边框 *{box-sizing:border-box; } //使所有边框不再撑大盒子模型 粗细 : border-width 样式 : border-style, 默认没边框 . solid 实线边框 dashed 虚线边框 dotted 点线边框 颜色 : border-color div { width : 200px ; height : 200px ; border : …...
升级 CentOS 7.x 系统内核到 4.4 版本
问题描述 在 CentOS 7.x 系统中,默认内核版本是 3.10.x,这个版本可能会带来一些与 Docker 和 Kubernetes 兼容性的问题,导致系统性能不稳定或功能异常。为了提高系统的稳定性和兼容性,建议升级到更高版本的内核,例如 …...
播放音频文件同步音频文本
播放音频同步音频文本 对应单个文本高亮显示 使用audio音频文件对应音频文本资源 音频文本内容(Json) [{"end": 4875,"index": 0,"speaker": 0,"start": 30,"text": "70号二啊,","tex…...
springboot使用Easy Excel导出列表数据为Excel
springboot使用Easy Excel导出列表数据为Excel Easy Excel官网:https://easyexcel.opensource.alibaba.com/docs/current/quickstart/write 主要记录一下引入时候的pom,直接引入会依赖冲突 解决方法: <!-- 引入Easy Excel的依赖 -->&l…...
day07_Spark SQL
文章目录 day07_Spark SQL课程笔记一、今日课程内容二、Spark SQL函数定义(掌握)1、窗口函数2、自定义函数背景2.1 回顾函数分类标准:SQL最开始是_内置函数&自定义函数_两种 2.2 自定义函数背景 3、Spark原生自定义UDF函数3.1 自定义函数流程&#x…...
高性能现代PHP全栈框架 Spiral
概述 Spiral Framework 诞生于现实世界的软件开发项目是一个现代 PHP 框架,旨在为更快、更清洁、更卓越的软件开发提供动力。 特性 高性能 由于其设计以及复杂精密的应用服务器,Spiral Framework框架在不影响代码质量以及与常用库的兼容性的情况下&a…...
LeetCode - #182 Swift 实现找出重复的电子邮件
网罗开发 (小红书、快手、视频号同名) 大家好,我是 展菲,目前在上市企业从事人工智能项目研发管理工作,平时热衷于分享各种编程领域的软硬技能知识以及前沿技术,包括iOS、前端、Harmony OS、Java、Python等…...
《解锁鸿蒙Next系统人工智能语音助手开发的关键步骤》
在当今数字化时代,鸿蒙Next系统与人工智能的融合为开发者带来了前所未有的机遇,开发一款人工智能语音助手应用更是备受关注。以下是在鸿蒙Next系统上开发人工智能语音助手应用的关键步骤: 环境搭建与权限申请 安装开发工具:首先需…...
【Linux网络编程】数据链路层 | MAC帧 | ARP协议
前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站 🌈个人主页: 南桥几晴秋 🌈C专栏: 南桥谈C 🌈C语言专栏: C语言学习系…...
浏览器访问 AWS ECS 上部署的 Docker 容器(监听 80 端口)
✅ 一、ECS 服务配置 Dockerfile 确保监听 80 端口 EXPOSE 80 CMD ["nginx", "-g", "daemon off;"]或 EXPOSE 80 CMD ["python3", "-m", "http.server", "80"]任务定义(Task Definition&…...
华为云AI开发平台ModelArts
华为云ModelArts:重塑AI开发流程的“智能引擎”与“创新加速器”! 在人工智能浪潮席卷全球的2025年,企业拥抱AI的意愿空前高涨,但技术门槛高、流程复杂、资源投入巨大的现实,却让许多创新构想止步于实验室。数据科学家…...
测试微信模版消息推送
进入“开发接口管理”--“公众平台测试账号”,无需申请公众账号、可在测试账号中体验并测试微信公众平台所有高级接口。 获取access_token: 自定义模版消息: 关注测试号:扫二维码关注测试号。 发送模版消息: import requests da…...
Zustand 状态管理库:极简而强大的解决方案
Zustand 是一个轻量级、快速和可扩展的状态管理库,特别适合 React 应用。它以简洁的 API 和高效的性能解决了 Redux 等状态管理方案中的繁琐问题。 核心优势对比 基本使用指南 1. 创建 Store // store.js import create from zustandconst useStore create((set)…...
《从零掌握MIPI CSI-2: 协议精解与FPGA摄像头开发实战》-- CSI-2 协议详细解析 (一)
CSI-2 协议详细解析 (一) 1. CSI-2层定义(CSI-2 Layer Definitions) 分层结构 :CSI-2协议分为6层: 物理层(PHY Layer) : 定义电气特性、时钟机制和传输介质(导线&#…...
Python爬虫实战:研究feedparser库相关技术
1. 引言 1.1 研究背景与意义 在当今信息爆炸的时代,互联网上存在着海量的信息资源。RSS(Really Simple Syndication)作为一种标准化的信息聚合技术,被广泛用于网站内容的发布和订阅。通过 RSS,用户可以方便地获取网站更新的内容,而无需频繁访问各个网站。 然而,互联网…...
VTK如何让部分单位不可见
最近遇到一个需求,需要让一个vtkDataSet中的部分单元不可见,查阅了一些资料大概有以下几种方式 1.通过颜色映射表来进行,是最正规的做法 vtkNew<vtkLookupTable> lut; //值为0不显示,主要是最后一个参数,透明度…...
leetcodeSQL解题:3564. 季节性销售分析
leetcodeSQL解题:3564. 季节性销售分析 题目: 表:sales ---------------------- | Column Name | Type | ---------------------- | sale_id | int | | product_id | int | | sale_date | date | | quantity | int | | price | decimal | -…...
SpringCloudGateway 自定义局部过滤器
场景: 将所有请求转化为同一路径请求(方便穿网配置)在请求头内标识原来路径,然后在将请求分发给不同服务 AllToOneGatewayFilterFactory import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; impor…...
Java多线程实现之Thread类深度解析
Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...

