当前位置: 首页 > news >正文

【单层神经网络】基于MXNet的线性回归实现(底层实现)

写在前面

  1. 基于亚马逊的MXNet库
  2. 本专栏是对李沐博士的《动手学深度学习》的笔记,仅用于分享个人学习思考
  3. 以下是本专栏所需的环境(放进一个environment.yml,然后用conda虚拟环境统一配置即可)
  4. 刚开始先从普通的寻优算法开始,熟悉一下学习训练过程
  5. 下面将使用梯度下降法寻优,但这大概只能是局部最优,它并不是一个十分优秀的寻优算法
name: gluon
dependencies:
- python=3.6
- pip:- mxnet==1.5.0- d2lzh==1.0.0- jupyter==1.0.0- matplotlib==2.2.2- pandas==0.23.4

整体流程

  1. 生成训练数据集(实际工程中,需要从实际对象身上采集数据)
  2. 确定模型及其参数(输入输出个数、阶次,偏置等)
  3. 确定学习方式(损失函数、优化算法,学习率,训练次数,终止条件等)
  4. 读取数据集(不同的读取方式会影响最终的训练效果)
  5. 训练模型

完整程序及注释

from IPython import display
from matplotlib import pyplot as plt
from mxnet import autograd, nd
import random'''
获取(生成)训练集
'''
input_num = 2				# 输入个数
examples_num = 1000			# 生成样本个数
# 确定真实模型参数
real_W = [10.9, -8.7]		
real_bias = 6.5	features = nd.random.normal(scale=1, shape=(examples_num, input_num))       # 标准差=1,均值缺省=0
labels = real_W[0]*features[:,0] + real_W[1]*features[:,1] + real_bias		# 根据特征和参数生成对应标签
labels_noise = labels + nd.random.normal(scale=0.1, shape=labels.shape)		# 为标签附加噪声,模拟真实情况# 绘制标签和特征的散点图(矢量图)
# def use_svg_display():
#     display.set_matplotlib_formats('svg')# def set_figure_size(figsize=(3.5,2.5)):
#     use_svg_display()
#     plt.rcParams['figure.figsize'] = figsize# set_figure_size()
# plt.scatter(features[:,0].asnumpy(), labels_noise.asnumpy(), 1)
# plt.scatter(features[:,1].asnumpy(), labels_noise.asnumpy(), 1)
# plt.show()# 创建一个迭代器(确定从数据集获取数据的方式)
def data_iter(batch_size, features, labels):num = len(features)indices = list(range(num))                                  # 生成索引数组random.shuffle(indices)                                     # 打乱indices# 该遍历方式同时确保了随机采样和无遗漏for i in range(0, num, batch_size):j = nd.array(indices[i: min(i+batch_size, num)])        # 对indices从i开始取,取batch_size个样本,并转换为列表yield features.take(j), labels.take(j)                  # take方法使用索引数组,从features和labels提取所需数据"""
训练的基础准备
"""
# 声明训练变量,并赋高斯随机初始值
w = nd.random.normal(scale=0.01, shape=(input_num))
b = nd.zeros(shape=(1,))
# b = nd.zeros(1)       # 不同写法,等价于上面的
w.attach_grad()         # 为需要迭代的参数申请求梯度空间
b.attach_grad()# 定义模型
def linreg(X, w, b):return nd.dot(X,w)+b# 定义损失函数
def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) **2 /2# 定义寻优算法
def sgd(params, learning_rate, batch_size):for param in params:# 新参数 = 原参数 - 学习率*当前批量的参数梯度/当前批量的大小param[:] = param - learning_rate * param.grad / batch_size# 确定超参数和学习方式
lr = 0.03
num_iterations = 5
net = linreg				# 目标模型
loss = squared_loss			# 代价函数(损失函数)
batch_size = 10				# 每次随机小批量的大小'''
开始训练
'''
for iteration in range(num_iterations):		# 确定迭代次数for x, y in data_iter(batch_size, features, labels):with autograd.record():l = loss(net(x,w,b), y)			# 求当前小批量的总损失l.backward()						# 求梯度sgd([w,b], lr, batch_size)			# 梯度更新参数train_l = loss(net(features,w,b), labels)print("iteration %d, loss %f" % (iteration+1, train_l.mean().asnumpy()))
# 打印比较真实参数和训练得到的参数
print("real_w " + str(real_W) + "\n train_w " + str(w))
print("real_w " + str(real_bias) + "\n train_b " + str(b))

具体程序解释

param[:] = param - learning_rate * param.grad / batch_size
将batch_size与参数调整相关联的原因,是为了使得每次更新的步长不受批次大小的影响
具体来说,当计算一批数据的损失函数的梯度时,实际上是将这批数据中每个样本对损失函数的贡献累加起来。这意味着如果批次较大,梯度的模也会相应增大
故更新权值时,使用的是数据集的平均梯度,而不是总和

相关文章:

【单层神经网络】基于MXNet的线性回归实现(底层实现)

写在前面 基于亚马逊的MXNet库本专栏是对李沐博士的《动手学深度学习》的笔记,仅用于分享个人学习思考以下是本专栏所需的环境(放进一个environment.yml,然后用conda虚拟环境统一配置即可)刚开始先从普通的寻优算法开始&#xff…...

unity中的动画混合树

为什么需要动画混合树,动画混合树有什么作用? 在Unity中,动画混合树(Animation Blend Tree)是一种用于管理和混合多个动画状态的工具,包括1D和2D两种类型,以下是其作用及使用必要性的介绍&…...

《基于deepseek R1开源大模型的电子数据取证技术发展研究》

《基于deepseek R1开源大模型的电子数据取证技术发展研究》 摘要 本文探讨了基于deepseek R1开源大模型的电子数据取证技术发展前景。随着人工智能技术的快速发展,AI大模型在电子数据取证领域的应用潜力日益凸显。本研究首先分析了电子数据取证的现状和挑战&#xf…...

Potplayer常用快捷键

Potplayer是一个非常好用的播放器,功能强大 功能快捷键播放/暂停空格键退出Esc下一帧F上一帧D快进10秒→快退10秒←快进30秒Ctrl →快退30秒Ctrl ←快进1分钟Alt →快退1分钟Alt ←增加播放速度C减少播放速度X恢复正常速度Z增加音量↑减少音量↓静音M显示/隐藏字幕Ctrl A…...

C++ Primer 自定义数据结构

欢迎阅读我的 【CPrimer】专栏 专栏简介:本专栏主要面向C初学者,解释C的一些基本概念和基础语言特性,涉及C标准库的用法,面向对象特性,泛型特性高级用法。通过使用标准库中定义的抽象设施,使你更加适应高级…...

35.Word:公积金管理中心文员小谢【37】

目录 Word1.docx ​ Word2.docx Word2.docx ​ 注意本套题还是与上一套存在不同之处 Word1.docx 布局样式的应用设计页眉页脚位置在水平/垂直方向上均相对于外边距居中排列:格式→大小对话框→位置→水平/垂直 按下表所列要求将原文中的手动纯文本编号分别替换…...

北京钟鼓楼:立春“鞭春牛”,钟鼓迎春来

仁风导和气,勾芒御昊春。“钟鼓迎春”立春鞭春牛民俗体验活动于立春当日在北京钟鼓楼隆重举办。此次活动由北京市钟鼓楼文物保管所主办,京睿文(北京)文化科技有限公司承办,通过礼官报春、击鼓鸣钟、春娃喊春、中国时间文化角色巡游、鞭春牛等一系列精彩的活动环节,为观众呈现了…...

股票入门知识

股票入门(更适合中国宝宝体制) 股市基础知识 本文介绍了股票的基础知识,股票的分类,各板块发行上市条件,股票代码,交易时间,交易规则,炒股术语,影响股价的因素&#xf…...

Java自定义IO密集型和CPU密集型线程池

文章目录 前言线程池各类场景描述常见场景案例设计思路公共类自定义工厂类-MyThreadFactory自定义拒绝策略-RejectedExecutionHandlerFactory自定义阻塞队列-TaskQueue(实现 核心线程->最大线程数->队列) 场景1:CPU密集型场景思路&…...

Git的安装步骤详解(复杂的安装界面该如何勾选?)

目录 一、下载与安装 1.官网下载git 2、下载完成之后,双击下载好的exe文件进行安装 3、选择Git的安装路径 4、选择在安装 Git 时要包含的组件和功能 5、选择 Git 快捷方式在 Windows 开始菜单中的位置。 6、选择 Git 使用的默认编辑器 7、调整新仓库中初始分…...

文本预处理

一、文本的基本单位 1、Token 定义:文本的最小单位,例如单词、标点符号。 示例: 原句: "I love NLP." 分词结果: [I, love, NLP, .] 2、语法与语义 语法:词的结构和句子的组合规则。 语义&a…...

SQLAlchemy 2.0的简单使用教程

SQLAlchemy 2.0相比1.x进行了很大的更新,目前网上的教程不多,以下以链接mysql为例介绍一下基本的使用方法 环境及依赖 Python:3.8 mysql:8.3 Flask:3.0.3 SQLAlchemy:2.0.37 PyMySQL:1.1.1使用步骤 1、创建引擎,链接到mysql engine crea…...

基于RAG的知识库问答系统

基于RAG的知识库问答系统 结合语义检索与大语言模型技术,实现基于私有知识库的智能问答解决方案。采用两阶段处理架构,可快速定位相关文档并生成精准回答。 核心功能 知识向量化引擎 支持多语言文本嵌入(all-MiniLM-L6-v2模型)自…...

SQL/Panda映射关系

Pandas教程(非常详细)_pandas 教程-CSDN博客 SQL:使用SELECT col_1, col_2 FROM tab; Pandas:使用df[[col_1, col_2]]。 SQL:使用SELECT * FROM tab WHERE col_1 11 AND col_2 > 5; Pandas:使用df…...

自定义数据集 使用paddlepaddle框架实现逻辑回归

导入必要的库 import numpy as np import paddle import paddle.nn as nn 数据准备: seed1 paddle.seed(seed)# 1.散点输入 定义输入数据 data [[-0.5, 7.7], [1.8, 98.5], [0.9, 57.8], [0.4, 39.2], [-1.4, -15.7], [-1.4, -37.3], [-1.8, -49.1], [1.5, 75.6…...

Docker入门篇(Docker基础概念与Linux安装教程)

目录 一、什么是Docker、有什么作用 二、Docker与虚拟机(对比) 三、Docker基础概念 四、CentOS安装Docker 一、从零认识Docker、有什么作用 1.项目部署可能的问题: 大型项目组件较多,运行环境也较为复杂,部署时会碰到一些问题&#xff1…...

c/c++高级编程

1.避免变量冗余初始化 结构体初始化为0,等价于对该内存进行一次memset,对于较大的结构体或者热点函数,重复的赋值带来冗余的性能开销。现代编译器对此类冗余初始化代码具有一定的优化能力,因此,打开相关的编译选项的优…...

2024-我的学习成长之路

因为热爱,无畏山海...

vscode软件操作界面UI布局@各个功能区域划分及其名称称呼

文章目录 abstract检查用户界面的主要区域官方文档关于UI的介绍 abstract 检查 Visual Studio Code 用户界面 - Training | Microsoft Learn 本质上,Visual Studio Code 是一个代码编辑器,其用户界面和布局与许多其他代码编辑器相似。 界面左侧是用于访…...

xmind使用教程

xmind使用教程 前言xmind版本信息“xmind使用教程”的xmind思维导图 前言 首先xmind是什么?XMind 是一款思维导图和头脑风暴工具,用于帮助用户组织和可视化思维、创意和信息。它允许用户通过图形化的方式来创建、整理和分享思维导图,可以用于…...

前端倒计时误差!

提示:记录工作中遇到的需求及解决办法 文章目录 前言一、误差从何而来?二、五大解决方案1. 动态校准法(基础版)2. Web Worker 计时3. 服务器时间同步4. Performance API 高精度计时5. 页面可见性API优化三、生产环境最佳实践四、终极解决方案架构前言 前几天听说公司某个项…...

可靠性+灵活性:电力载波技术在楼宇自控中的核心价值

可靠性灵活性:电力载波技术在楼宇自控中的核心价值 在智能楼宇的自动化控制中,电力载波技术(PLC)凭借其独特的优势,正成为构建高效、稳定、灵活系统的核心解决方案。它利用现有电力线路传输数据,无需额外布…...

LeetCode - 394. 字符串解码

题目 394. 字符串解码 - 力扣(LeetCode) 思路 使用两个栈:一个存储重复次数,一个存储字符串 遍历输入字符串: 数字处理:遇到数字时,累积计算重复次数左括号处理:保存当前状态&a…...

零基础设计模式——行为型模式 - 责任链模式

第四部分:行为型模式 - 责任链模式 (Chain of Responsibility Pattern) 欢迎来到行为型模式的学习!行为型模式关注对象之间的职责分配、算法封装和对象间的交互。我们将学习的第一个行为型模式是责任链模式。 核心思想:使多个对象都有机会处…...

【碎碎念】宝可梦 Mesh GO : 基于MESH网络的口袋妖怪 宝可梦GO游戏自组网系统

目录 游戏说明《宝可梦 Mesh GO》 —— 局域宝可梦探索Pokmon GO 类游戏核心理念应用场景Mesh 特性 宝可梦玩法融合设计游戏构想要素1. 地图探索(基于物理空间 广播范围)2. 野生宝可梦生成与广播3. 对战系统4. 道具与通信5. 延伸玩法 安全性设计 技术选…...

关键领域软件测试的突围之路:如何破解安全与效率的平衡难题

在数字化浪潮席卷全球的今天,软件系统已成为国家关键领域的核心战斗力。不同于普通商业软件,这些承载着国家安全使命的软件系统面临着前所未有的质量挑战——如何在确保绝对安全的前提下,实现高效测试与快速迭代?这一命题正考验着…...

Android第十三次面试总结(四大 组件基础)

Activity生命周期和四大启动模式详解 一、Activity 生命周期 Activity 的生命周期由一系列回调方法组成,用于管理其创建、可见性、焦点和销毁过程。以下是核心方法及其调用时机: ​onCreate()​​ ​调用时机​:Activity 首次创建时调用。​…...

佰力博科技与您探讨热释电测量的几种方法

热释电的测量主要涉及热释电系数的测定,这是表征热释电材料性能的重要参数。热释电系数的测量方法主要包括静态法、动态法和积分电荷法。其中,积分电荷法最为常用,其原理是通过测量在电容器上积累的热释电电荷,从而确定热释电系数…...

视觉slam十四讲实践部分记录——ch2、ch3

ch2 一、使用g++编译.cpp为可执行文件并运行(P30) g++ helloSLAM.cpp ./a.out运行 二、使用cmake编译 mkdir build cd build cmake .. makeCMakeCache.txt 文件仍然指向旧的目录。这表明在源代码目录中可能还存在旧的 CMakeCache.txt 文件,或者在构建过程中仍然引用了旧的路…...

宇树科技,改名了!

提到国内具身智能和机器人领域的代表企业,那宇树科技(Unitree)必须名列其榜。 最近,宇树科技的一项新变动消息在业界引发了不少关注和讨论,即: 宇树向其合作伙伴发布了一封公司名称变更函称,因…...