当前位置: 首页 > article >正文

【单层神经网络】基于MXNet的线性回归实现(底层实现)

写在前面

  1. 刚开始先从普通的寻优算法开始,熟悉一下学习训练过程
  2. 下面将使用梯度下降法寻优,但这大概只能是局部最优,它并不是一个十分优秀的寻优算法

整体流程

  1. 生成训练数据集(实际工程中,需要从实际对象身上采集数据)
  2. 确定模型及其参数(输入输出个数、阶次,偏置等)
  3. 确定学习方式(损失函数、优化算法,学习率,训练次数,终止条件等)
  4. 读取数据集(不同的读取方式会影响最终的训练效果)
  5. 训练模型

完整程序及注释

from IPython import display
from matplotlib import pyplot as plt
from mxnet import autograd, nd
import random'''
获取(生成)训练集
'''
input_num = 2				# 输入个数
examples_num = 1000			# 生成样本个数
# 确定真实模型参数
real_W = [10.9, -8.7]		
real_bias = 6.5	features = nd.random.normal(scale=1, shape=(examples_num, input_num))       # 标准差=1,均值缺省=0
labels = real_W[0]*features[:,0] + real_W[1]*features[:,1] + real_bias		# 根据特征和参数生成对应标签
labels_noise = labels + nd.random.normal(scale=0.1, shape=labels.shape)		# 为标签附加噪声,模拟真实情况# 绘制标签和特征的散点图(矢量图)
# def use_svg_display():
#     display.set_matplotlib_formats('svg')# def set_figure_size(figsize=(3.5,2.5)):
#     use_svg_display()
#     plt.rcParams['figure.figsize'] = figsize# set_figure_size()
# plt.scatter(features[:,0].asnumpy(), labels_noise.asnumpy(), 1)
# plt.scatter(features[:,1].asnumpy(), labels_noise.asnumpy(), 1)
# plt.show()# 创建一个迭代器(确定从数据集获取数据的方式)
def data_iter(batch_size, features, labels):num = len(features)indices = list(range(num))                                  # 生成索引数组random.shuffle(indices)                                     # 打乱indices# 该遍历方式同时确保了随机采样和无遗漏for i in range(0, num, batch_size):j = nd.array(indices[i: min(i+batch_size, num)])        # 对indices从i开始取,取batch_size个样本,并转换为列表yield features.take(j), labels.take(j)                  # take方法使用索引数组,从features和labels提取所需数据"""
训练的基础准备
"""
# 声明训练变量,并赋高斯随机初始值
w = nd.random.normal(scale=0.01, shape=(input_num))
b = nd.zeros(shape=(1,))
# b = nd.zeros(1)       # 不同写法,等价于上面的
w.attach_grad()         # 为需要迭代的参数申请求梯度空间
b.attach_grad()# 定义模型
def linreg(X, w, b):return nd.dot(X,w)+b# 定义损失函数
def squared_loss(y_hat, y):return (y_hat - y.reshape(y_hat.shape)) **2 /2# 定义寻优算法
def sgd(params, learning_rate, batch_size):for param in params:# 新参数 = 原参数 - 学习率*当前批量的参数梯度/当前批量的大小param[:] = param - learning_rate * param.grad / batch_size# 确定超参数和学习方式
lr = 0.03
num_iterations = 5
net = linreg				# 目标模型
loss = squared_loss			# 代价函数(损失函数)
batch_size = 10				# 每次随机小批量的大小'''
开始训练
'''
for iteration in range(num_iterations):		# 确定迭代次数for x, y in data_iter(batch_size, features, labels):with autograd.record():l = loss(net(x,w,b), y)			# 求当前小批量的总损失l.backward()						# 求梯度sgd([w,b], lr, batch_size)			# 梯度更新参数train_l = loss(net(features,w,b), labels)print("iteration %d, loss %f" % (iteration+1, train_l.mean().asnumpy()))
# 打印比较真实参数和训练得到的参数
print("real_w " + str(real_W) + "\n train_w " + str(w))
print("real_w " + str(real_bias) + "\n train_b " + str(b))

具体程序解释

param[:] = param - learning_rate * param.grad / batch_size
将batch_size与参数调整相关联的原因,是为了使得每次更新的步长不受批次大小的影响
具体来说,当计算一批数据的损失函数的梯度时,实际上是将这批数据中每个样本对损失函数的贡献累加起来。这意味着如果批次较大,梯度的模也会相应增大
故更新权值时,使用的是数据集的平均梯度,而不是总和

相关文章:

【单层神经网络】基于MXNet的线性回归实现(底层实现)

写在前面 刚开始先从普通的寻优算法开始,熟悉一下学习训练过程下面将使用梯度下降法寻优,但这大概只能是局部最优,它并不是一个十分优秀的寻优算法 整体流程 生成训练数据集(实际工程中,需要从实际对象身上采集数据…...

20250204将Ubuntu22.04的默认Dash的shell脚本更换为bash

20250204将Ubuntu22.04的默认Dash的shell脚本更换为bash 2025/2/4 23:45 百度:dash bash https://blog.csdn.net/2201_75772333/article/details/136955776 【Linux基础】dash和bash简介 Dash(Debian Almquist Shell)和 Bash(Bou…...

Golang 并发机制-3:通道(channels)机制详解

并发编程是一种创建性能优化且响应迅速的软件的强大方法。Golang(也称为 Go)通过通道(channels)这一特性,能够可靠且优雅地实现并发通信。本文将揭示通道的概念,解释其在并发编程中的作用,并提供…...

Python处理数据库:MySQL与SQLite详解

Python处理数据库:MySQL与SQLite详解 在数据处理和存储方面,数据库扮演着至关重要的角色。Python提供了多种与数据库交互的方式,其中pymysql库用于连接和操作MySQL数据库,而SQLite则是一种轻量级的嵌入式数据库,Pytho…...

可视化大屏在石油方面的应用。

可视化大屏通过整合石油工业全链条数据,构建数字孪生驱动的运营监控体系,显著提升油气勘探、开采、储运及炼化的管理效能。其技术架构依托工业物联网(IIoT)实时采集钻井参数、管道压力、储罐液位等数据,通过OPC UA协议…...

【学术投稿-2025年计算机视觉研究进展与应用国际学术会议 (ACVRA 2025)】从计算机基础到HTML开发:Web开发的第一步

会议官网:www.acvra.org 简介 2025年计算机视觉研究进展与应用(ACVRA 2025)将于2025年2月28-3月2日在中国广州召开,将汇聚世界各地的顶尖学者、研究人员和行业专家,聚焦计算机视觉领域的最新研究动态与应用成就。本次…...

【 AI agents】letta:2024年代理堆栈演进(中英文翻译)

The AI agents stack AI 代理堆栈 November 14, 2024 11月 14, 2024原文: The AI agents stack官方教程教程学习笔记: 【memgpt】letta 课程1/2:从头实现一个自我编辑、记忆和多步骤推理的代理Understanding the AI agents landscape 了解 AI 代理环境 Although we see a …...

Axure PR 9 旋转效果 设计交互

大家好,我是大明同学。 这期内容,我们将学习Axure中的旋转效果设计与交互技巧。 旋转 创建旋转效果所需的元件 1.打开一个新的 RP 文件并在画布上打开 Page 1。 2.在元件库中拖出一个按钮元件。 创建交互 创建按钮交互状态 1.选中按钮元件&#xf…...

Docker 部署教程jenkins

Docker 部署 jenkins 教程 Jenkins 官方网站 Jenkins 是一个开源的自动化服务器,主要用于持续集成(CI)和持续交付(CD)过程。它帮助开发人员自动化构建、测试和部署应用程序,显著提高软件开发的效率和质量…...

二、CSS笔记

(一)css概述 1、定义 CSS是Cascading Style Sheets的简称,中文称为层叠样式表,用来控制网页数据的表现,可以使网页的表现与数据内容分离。 2、要点 怎么找到标签怎么操作标签对象(element) 3、css的四种引入方式 3.1 行内式 在标签的style属性中设定CSS样式。这种方…...

计算图 Compute Graph 和自动求导 Autograd | PyTorch 深度学习实战

前一篇文章,Tensor 基本操作5 device 管理,使用 GPU 设备 | PyTorch 深度学习实战 本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started PyTorch 计算图和 Autograd 微积分之于机器学习Computational Graphs 计算图Autograd…...

基于 Java 开发的 MongoDB 企业级应用全解析

基于Java的MongoDB企业级应用开发实战 目录 背景与历史MongoDB的核心功能与特性企业级业务场景分析MongoDB的优缺点剖析开发环境搭建 5.1 JDK安装与配置5.2 MongoDB安装与集群配置5.3 开发工具选型 Java与MongoDB集成实战 6.1 项目依赖与驱动选择6.2 连接池与客户端配置6.3…...

构建一个测试助手Agent:提升测试效率的实践

在上一篇文章中,我们讨论了如何构建一个运维助手Agent。今天,我想分享另一个实际项目:如何构建一个测试助手Agent。这个项目源于我们一个大型互联网公司的真实需求 - 提升测试效率,保障产品质量。 从测试痛点说起 记得和测试团队讨论时的场景: 小张:每…...

ESXI虚拟机中部署docker会降低服务器性能

在 8 核 16GB 的 ESXi 虚拟机中部署 Docker 的性能影响分析 在 ESXi 虚拟机中运行 Docker 容器时,性能影响主要来自以下几个方面: 虚拟化开销:ESXi 虚拟化层和 Docker 容器化层的叠加。资源竞争:虚拟机与容器之间对 CPU、内存、…...

基于“蘑菇书”的强化学习知识点(五):条件期望

条件期望 摘要一、条件期望的定义二、条件期望的关键性质三、条件期望的直观理解四、条件期望的应用场景五、简单例子离散情况连续情况 摘要 本系列知识点讲解基于蘑菇书EasyRL中的内容进行详细的疑难点分析!具体内容请阅读蘑菇书EasyRL! 对应蘑菇书Eas…...

Linux抢占式内核:技术演进与源码解析

一、引言 Linux内核作为全球广泛使用的开源操作系统核心,其设计和实现一直是计算机科学领域的研究热点。从早期的非抢占式内核到2.6版本引入的抢占式内核,Linux在实时性和响应能力上取得了显著进步。本文将深入探讨Linux抢占式内核的引入背景、技术实现以及与非抢占式内核的…...

接入DeepSeek大模型

接入DeepSeek 下载并安装Ollamachatbox 软件配置大模型 下载并安装Ollama 下载并安装Ollama, 使用参数ollama -v查看是否安装成功。 输入命令ollama list, 可以看到已经存在4个目录了。 输入命令ollama pull deepseek-r1:1.5b, 下载deepse…...

【论文复现】粘菌算法在最优经济排放调度中的发展与应用

目录 1.摘要2.黏菌算法SMA原理3.改进策略4.结果展示5.参考文献6.代码获取 1.摘要 本文提出了一种改进粘菌算法(ISMA),并将其应用于考虑阀点效应的单目标和双目标经济与排放调度(EED)问题。为提升传统粘菌算法&#xf…...

SSM开发(十) SSM框架协同工作原理

目录 一、Spring扮演了一个整合者的角色 二、SSM拆解来看 三、SSM框架的核心优势 注: SSM框架(Spring + Spring MVC + MyBatis) 一、Spring扮演了一个整合者的角色 SSM框架中,Spring扮演了一个整合者的角色,它将Spring MVC的Web层和MyBatis的数据持久层连接起来。在SS…...

UE Bridge混合材质工具

打开虚幻内置Bridge 随便点个材质点右下角图标 就能打开材质混合工具 可以用来做顶点绘制...

基于 yolov8_pyqt5 自适应界面设计的火灾检测系统 demo:毕业设计参考

基于 yolov8_pyqt5 自适应界面设计的火灾检测系统 demo:毕业设计参考 【毕业设计参考】基于yolov8-pyqt5自适应界面设计的火灾检测系统demo.zip资源-CSDN文库 【毕业设计参考】基于yolov8-pyqt5自适应界面设计的火灾检测系统demo.zip资源-CSDN文库 一、项目背景 …...

Linux 传输层协议 UDP 和 TCP

UDP 协议 UDP 协议端格式 16 位 UDP 长度, 表示整个数据报(UDP 首部UDP 数据)的最大长度如果校验和出错, 就会直接丢弃 UDP 的特点 UDP 传输的过程类似于寄信 . 无连接: 知道对端的 IP 和端口号就直接进行传输, 不需要建立连接不可靠: 没有确认机制, 没有重传机制; 如果因…...

Android开发EventBus

Android开发EventBus 分享一个EventBus 工具类,封装一下,让你少写些代码 直接上代码: public class BaseEventBusUtils {public static void register(Object subscriber) {EventBus eventBus EventBus.getDefault();if (!eventBus.isReg…...

chrome浏览器chromedriver下载

chromedriver 下载地址 https://googlechromelabs.github.io/chrome-for-testing/ 上面的链接有和当前发布的chrome浏览器版本相近的chromedriver 实际使用感受 chrome浏览器会自动更新,可以去下载最新的chromedriver使用,自动化中使用新的chromedr…...

第一个Qt开发实例(一个Push Button按钮和两个Label)【包括如何在QtCreator中创建新工程、代码详解、编译、环境变量配置、测试程序运行等】

目录 Qt开发环境QtCreator的安装、配置在QtCreator中创建新工程在Forms→mainwindow.ui中拖曳出我们要的图形按钮查看拖曳出按钮后的代码为pushButton这个图形添加回调函数编译工程关闭开发板上QT的GUI(选做)禁止LCD黑屏(选做)设置Qt运行的环境变量运行Qt程序如何让程序在系统启…...

【react+redux】 react使用redux相关内容

首先说一下,文章中所提及的内容都是我自己的个人理解,是我理逻辑的时候,自我说服的方式,如果有问题有补充欢迎在评论区指出。 一、场景描述 为什么在react里面要使用redux,我的理解是因为想要使组件之间的通信更便捷…...

【435. 无重叠区间 中等】

题目: 给定一个区间的集合 intervals ,其中 intervals[i] [starti, endi] 。返回 需要移除区间的最小数量,使剩余区间互不重叠 。 注意 只在一点上接触的区间是 不重叠的。例如 [1, 2] 和 [2, 3] 是不重叠的。 示例 1: 输入: intervals …...

文献学习笔记:中风醒脑液(FYTF-919)临床试验解读:有效还是无效?

【中风醒脑液(FYTF-919)临床试验解读:有效还是无效?】 在发表于 The Lancet (2024 年 11 月 30 日,第 404 卷)的临床研究《Traditional Chinese medicine FYTF-919 (Zhongfeng Xingnao oral pr…...

4 前端前置技术(中):node.js环境

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言 前言...

5.角色基础移动

能帮到你的话,就给个赞吧 😘 文章目录 角色的xyz轴与移动方向拌合输入轴值add movement inputget controller rotationget right vectorget forward vector 发现模型的旋转改变后,xyz轴也会改变,所以需要旋转值来计算xyz轴方向。 …...