当前位置: 首页 > news >正文

【单层神经网络】基于MXNet的线性回归实现(底层实现)

写在前面

  1. 基于亚马逊的MXNet库
  2. 本专栏是对李沐博士的《动手学深度学习》的笔记,仅用于分享个人学习思考
  3. 以下是本专栏所需的环境(放进一个environment.yml,然后用conda虚拟环境统一配置即可)
  4. 刚开始先从普通的寻优算法开始,熟悉一下学习训练过程
  5. 下面将使用梯度下降法寻优,但这大概只能是局部最优,它并不是一个十分优秀的寻优算法
name: gluon
dependencies:
- python=3.6
- pip:- mxnet==1.5.0- d2lzh==1.0.0- jupyter==1.0.0- matplotlib==2.2.2- pandas==0.23.4

整体流程

  1. 生成训练数据集(实际工程中,需要从实际对象身上采集数据)
  2. 确定模型及其参数(输入输出个数、阶次,偏置等)
  3. 确定学习方式(损失函数、优化算法,学习率,训练次数,终止条件等)
  4. 读取数据集(不同的读取方式会影响最终的训练效果)
  5. 训练模型

完整程序及注释

from IPython import display
from matplotlib import pyplot as plt
from mxnet import autograd, nd
import random'''
获取(生成)训练集
'''
input_num = 2				# 输入个数
examples_num = 1000			# 生成样本个数
# 确定真实模型参数
real_W = [10.9, -8.7]		
real_bias = 6.5	features = nd.random.normal(scale=1, shape=(examples_num, input_num))       # 标准差=1,均值缺省=0
labels = real_W[0]*features[:,0] + real_W[1]*features[:,1] + real_bias		# 根据特征和参数生成对应标签
labels_noise = labels + nd.random.normal(scale=0.1, shape=labels.shape)		# 为标签附加噪声,模拟真实情况# 绘制标签和特征的散点图(矢量图)
# def use_svg_display():
#     display.set_matplotlib_formats('svg')# def set_figure_size(figsize=(3.5,2.5)):
#     use_svg_display()
#     plt.rcParams['figure.figsize'] = figsize# set_figure_size()
# plt.scatter(features[:,0].asnumpy(), labels_noise.asnumpy(), 1)
# plt.scatter(features[:,1].asnumpy(), labels_noise.asnumpy(), 1)
# plt.show()# 创建一个迭代器(确定从数据集获取数据的方式)
def data_iter(batch_size, features, labels):num = len(features)indices = list(range(num))                                  # 生成索引数组random.shuffle(indices)                                     # 打乱indices# 该遍历方式同时确保了随机采样和无遗漏for i in range(0, num, batch_size):j = nd.array(indices[i: min(i+batch_size, num)])        # 对indices从i开始取,取batch_size个样本,并转换为列表yield features.take(j), labels.take(j)                  # take方法使用索引数组,从features和labels提取所需数据"""
训练的基础准备
"""
# 声明训练变量,并赋高斯随机初始值
w = nd.random.normal(scale=0.01, shape=(input_num))
b = nd.zeros(shape=(1,))
# b = nd.zeros(1)       # 不同写法,等价于上面的
w.attach_grad()         # 为需要迭代的参数申请求梯度空间
b.attach_grad()# 定义模型
def linreg(X, w, b):return nd.dot(X,w)+b# 定义损失函数
def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) **2 /2# 定义寻优算法
def sgd(params, learning_rate, batch_size):for param in params:# 新参数 = 原参数 - 学习率*当前批量的参数梯度/当前批量的大小param[:] = param - learning_rate * param.grad / batch_size# 确定超参数和学习方式
lr = 0.03
num_iterations = 5
net = linreg				# 目标模型
loss = squared_loss			# 代价函数(损失函数)
batch_size = 10				# 每次随机小批量的大小'''
开始训练
'''
for iteration in range(num_iterations):		# 确定迭代次数for x, y in data_iter(batch_size, features, labels):with autograd.record():l = loss(net(x,w,b), y)			# 求当前小批量的总损失l.backward()						# 求梯度sgd([w,b], lr, batch_size)			# 梯度更新参数train_l = loss(net(features,w,b), labels)print("iteration %d, loss %f" % (iteration+1, train_l.mean().asnumpy()))
# 打印比较真实参数和训练得到的参数
print("real_w " + str(real_W) + "\n train_w " + str(w))
print("real_w " + str(real_bias) + "\n train_b " + str(b))

具体程序解释

param[:] = param - learning_rate * param.grad / batch_size
将batch_size与参数调整相关联的原因,是为了使得每次更新的步长不受批次大小的影响
具体来说,当计算一批数据的损失函数的梯度时,实际上是将这批数据中每个样本对损失函数的贡献累加起来。这意味着如果批次较大,梯度的模也会相应增大
故更新权值时,使用的是数据集的平均梯度,而不是总和

相关文章:

【单层神经网络】基于MXNet的线性回归实现(底层实现)

写在前面 基于亚马逊的MXNet库本专栏是对李沐博士的《动手学深度学习》的笔记,仅用于分享个人学习思考以下是本专栏所需的环境(放进一个environment.yml,然后用conda虚拟环境统一配置即可)刚开始先从普通的寻优算法开始&#xff…...

unity中的动画混合树

为什么需要动画混合树,动画混合树有什么作用? 在Unity中,动画混合树(Animation Blend Tree)是一种用于管理和混合多个动画状态的工具,包括1D和2D两种类型,以下是其作用及使用必要性的介绍&…...

《基于deepseek R1开源大模型的电子数据取证技术发展研究》

《基于deepseek R1开源大模型的电子数据取证技术发展研究》 摘要 本文探讨了基于deepseek R1开源大模型的电子数据取证技术发展前景。随着人工智能技术的快速发展,AI大模型在电子数据取证领域的应用潜力日益凸显。本研究首先分析了电子数据取证的现状和挑战&#xf…...

Potplayer常用快捷键

Potplayer是一个非常好用的播放器,功能强大 功能快捷键播放/暂停空格键退出Esc下一帧F上一帧D快进10秒→快退10秒←快进30秒Ctrl →快退30秒Ctrl ←快进1分钟Alt →快退1分钟Alt ←增加播放速度C减少播放速度X恢复正常速度Z增加音量↑减少音量↓静音M显示/隐藏字幕Ctrl A…...

C++ Primer 自定义数据结构

欢迎阅读我的 【CPrimer】专栏 专栏简介:本专栏主要面向C初学者,解释C的一些基本概念和基础语言特性,涉及C标准库的用法,面向对象特性,泛型特性高级用法。通过使用标准库中定义的抽象设施,使你更加适应高级…...

35.Word:公积金管理中心文员小谢【37】

目录 Word1.docx ​ Word2.docx Word2.docx ​ 注意本套题还是与上一套存在不同之处 Word1.docx 布局样式的应用设计页眉页脚位置在水平/垂直方向上均相对于外边距居中排列:格式→大小对话框→位置→水平/垂直 按下表所列要求将原文中的手动纯文本编号分别替换…...

北京钟鼓楼:立春“鞭春牛”,钟鼓迎春来

仁风导和气,勾芒御昊春。“钟鼓迎春”立春鞭春牛民俗体验活动于立春当日在北京钟鼓楼隆重举办。此次活动由北京市钟鼓楼文物保管所主办,京睿文(北京)文化科技有限公司承办,通过礼官报春、击鼓鸣钟、春娃喊春、中国时间文化角色巡游、鞭春牛等一系列精彩的活动环节,为观众呈现了…...

股票入门知识

股票入门(更适合中国宝宝体制) 股市基础知识 本文介绍了股票的基础知识,股票的分类,各板块发行上市条件,股票代码,交易时间,交易规则,炒股术语,影响股价的因素&#xf…...

Java自定义IO密集型和CPU密集型线程池

文章目录 前言线程池各类场景描述常见场景案例设计思路公共类自定义工厂类-MyThreadFactory自定义拒绝策略-RejectedExecutionHandlerFactory自定义阻塞队列-TaskQueue(实现 核心线程->最大线程数->队列) 场景1:CPU密集型场景思路&…...

Git的安装步骤详解(复杂的安装界面该如何勾选?)

目录 一、下载与安装 1.官网下载git 2、下载完成之后,双击下载好的exe文件进行安装 3、选择Git的安装路径 4、选择在安装 Git 时要包含的组件和功能 5、选择 Git 快捷方式在 Windows 开始菜单中的位置。 6、选择 Git 使用的默认编辑器 7、调整新仓库中初始分…...

文本预处理

一、文本的基本单位 1、Token 定义:文本的最小单位,例如单词、标点符号。 示例: 原句: "I love NLP." 分词结果: [I, love, NLP, .] 2、语法与语义 语法:词的结构和句子的组合规则。 语义&a…...

SQLAlchemy 2.0的简单使用教程

SQLAlchemy 2.0相比1.x进行了很大的更新,目前网上的教程不多,以下以链接mysql为例介绍一下基本的使用方法 环境及依赖 Python:3.8 mysql:8.3 Flask:3.0.3 SQLAlchemy:2.0.37 PyMySQL:1.1.1使用步骤 1、创建引擎,链接到mysql engine crea…...

基于RAG的知识库问答系统

基于RAG的知识库问答系统 结合语义检索与大语言模型技术,实现基于私有知识库的智能问答解决方案。采用两阶段处理架构,可快速定位相关文档并生成精准回答。 核心功能 知识向量化引擎 支持多语言文本嵌入(all-MiniLM-L6-v2模型)自…...

SQL/Panda映射关系

Pandas教程(非常详细)_pandas 教程-CSDN博客 SQL:使用SELECT col_1, col_2 FROM tab; Pandas:使用df[[col_1, col_2]]。 SQL:使用SELECT * FROM tab WHERE col_1 11 AND col_2 > 5; Pandas:使用df…...

自定义数据集 使用paddlepaddle框架实现逻辑回归

导入必要的库 import numpy as np import paddle import paddle.nn as nn 数据准备: seed1 paddle.seed(seed)# 1.散点输入 定义输入数据 data [[-0.5, 7.7], [1.8, 98.5], [0.9, 57.8], [0.4, 39.2], [-1.4, -15.7], [-1.4, -37.3], [-1.8, -49.1], [1.5, 75.6…...

Docker入门篇(Docker基础概念与Linux安装教程)

目录 一、什么是Docker、有什么作用 二、Docker与虚拟机(对比) 三、Docker基础概念 四、CentOS安装Docker 一、从零认识Docker、有什么作用 1.项目部署可能的问题: 大型项目组件较多,运行环境也较为复杂,部署时会碰到一些问题&#xff1…...

c/c++高级编程

1.避免变量冗余初始化 结构体初始化为0,等价于对该内存进行一次memset,对于较大的结构体或者热点函数,重复的赋值带来冗余的性能开销。现代编译器对此类冗余初始化代码具有一定的优化能力,因此,打开相关的编译选项的优…...

2024-我的学习成长之路

因为热爱,无畏山海...

vscode软件操作界面UI布局@各个功能区域划分及其名称称呼

文章目录 abstract检查用户界面的主要区域官方文档关于UI的介绍 abstract 检查 Visual Studio Code 用户界面 - Training | Microsoft Learn 本质上,Visual Studio Code 是一个代码编辑器,其用户界面和布局与许多其他代码编辑器相似。 界面左侧是用于访…...

xmind使用教程

xmind使用教程 前言xmind版本信息“xmind使用教程”的xmind思维导图 前言 首先xmind是什么?XMind 是一款思维导图和头脑风暴工具,用于帮助用户组织和可视化思维、创意和信息。它允许用户通过图形化的方式来创建、整理和分享思维导图,可以用于…...

Day33【AI思考】-分层递进式结构 对数学数系的 终极系统分类

文章目录 **分层递进式结构** 对数学数系的 **终极系统分类**总览**一、数系演化树(纵向维度)**数系扩展逻辑树**数系扩展逻辑** **二、代数结构对照表(横向维度)**数系扩展的数学意义 **三、几何对应图谱(空间维度&am…...

k8s二进制集群之ETCD集群证书生成

安装cfssl工具配置CA证书请求文件创建CA证书创建CA证书策略配置etcd证书请求文件生成etcd证书 继续上一篇文章《负载均衡器高可用部署》下面介绍一下etcd证书生成配置。其中涉及到的ip地址和证书基本信息请替换成你自己的信息。 安装cfssl工具 下载cfssl安装包 https://github…...

MySQL5.5升级到MySQL5.7

【卸载原来的MySQL】 cmd打开命令提示符窗口(管理员身份)net stop mysql(先停止MySQL服务) 3.卸载 切换到原来5.5版本的bin目录,输入mysqld remove卸载服务 测试mysql -V查看Mysql版本还是5.5 查看了环境变量里的…...

Golang Gin系列-9:Gin 集成Swagger生成文档

文档一直是一项乏味的工作(以我个人的拙见),但也是编码过程中最重要的任务之一。在本文中,我们将学习如何将Swagger规范与Gin框架集成。我们将实现JWT认证,请求体作为表单数据和JSON。这里唯一的先决条件是Gin服务器。…...

利用Python高效处理大规模词汇数据

在本篇博客中,我们将探讨如何使用Python及其强大的库来处理和分析大规模的词汇数据。我们将介绍如何从多个.pkl文件中读取数据,并应用一系列算法来筛选和扩展一个核心词汇列表。这个过程涉及到使用Pandas、Polars以及tqdm等库来实现高效的数据处理。 引…...

【PyQt】超级超级笨的pyqt计算器案例

计算器 1.QT Designer设计外观 1.pushButton2.textEdit3.groupBox4.布局设计 2.加载ui文件 导入模块: sys:用于处理命令行参数。 QApplication:PyQt5 应用程序类。 QWidget:窗口基类。 uic:用于加载 .ui 文件。…...

Git 的起源与发展

序章:版本控制的前世今生 在软件开发的漫长旅程中,版本控制犹如一位忠诚的伙伴,始终陪伴着开发者们。它的存在,解决了软件开发过程中代码管理的诸多难题,让团队协作更加高效,代码的演进更加有序。 简单来…...

预防和应对DDoS的方法

DDoS发起者通过大量的网络流量来中断服务器、服务或网络的正常运行,通常由多个受感染的计算机或联网设备(包括物联网设备)发起。 换种通俗的说法,可以将其想象成高速公路上的一次突然的大规模交通堵塞,阻止了正常的通勤…...

51单片机开发:独立按键实验

实验目的:按下键盘1时,点亮LED灯1。 键盘原理图如下图所示,可见,由于接GND,当键盘按下时,P3相应的端口为低电平。 键盘按下时会出现抖动,时间通常为5-10ms,代码中通过延时函数delay…...

02.04 数据类型

请写出以下几个数据的类型: 整数 a ----->int a的地址 ----->int* 存放a的数组b ----->int[] 存放a的地址的数组c ----->int*[] b的地址 ----->int* c的地址 ----->int** 指向printf函数的指针d ----->int (*)(const char*, ...) …...