当前位置: 首页 > news >正文

【一起撸个DL框架】5 实现:自适应线性单元

  • CSDN个人主页:清风莫追
  • 欢迎关注本专栏:《一起撸个DL框架》
  • GitHub获取源码:https://github.com/flying-forever/OurDL
  • blibli视频合集:https://space.bilibili.com/3493285974772098/channel/series

文章目录

  • 5 实现:自适应线性单元🍇
    • 1 简介
    • 2 损失函数
      • 2.1 梯度下降法
      • 2.2 补充
    • 3 整理项目结构
    • 4 损失函数的实现
    • 5 修改节点类(Node)
    • 6 自适应线性单元

5 实现:自适应线性单元🍇

1 简介

上一篇:【一起撸个DL框架】4 反向传播求梯度

上一节我们实现了计算图的反向传播,可以求结果节点关于任意节点的梯度。下面我们将使用梯度来更新参数,实现一个简单的自适应线性单元

我们本次拟合的目标函数是一个简单的线性函数: y = 2 x + 1 y=2x+1 y=2x+1,通过随机数生成一些训练数据,将许多组x和对应的结果y值输入模型,但是并不告诉模型具体函数中的系数参数“2”和偏置参数“1”,看看模型能否通过数据“学习”到参数的值。

图1:自适应线性单元的计算图

2 损失函数

2.1 梯度下降法

损失是对模型好坏的评价指标,表示模型输出结果与正确答案(也称为标签)之间的差距。所以损失值越小就说明模型越准确,训练过程的目的便是最小化损失函数的值

自适应线性单元是一个回归任务,我们这里将使用绝对值损失,将模型输出与正确答案之间的差的绝对值作为损失函数的值,即 l o s s = ∣ l − a d d ∣ loss=|l-add| loss=ladd

评价指标有了,可是如何才能达标呢?或者说如何才能降低损失函数的值?计算图中有四个变量: x , w , b , l x,w,b,l x,w,b,l,而我们训练过程的任务是调整参数 w , b w,b w,b的值,以降低损失。因此训练过程中的自变量是w和b,而把x和l看作常量。此时损失函数是关于w和b的二元函数 l o s s = f ( w , b ) loss=f(w,b) loss=f(w,b),我们只需要求函数的梯度 ▽ f ( w , b ) = ( ∂ f ∂ w , ∂ f ∂ b ) \triangledown f(w,b)=(\frac{\partial f}{\partial w},\frac{\partial f}{\partial b}) f(w,b)=(wf,bf),则梯度的反方向就是函数下降最快的方向。沿着梯度的方向更新参数w和b的值,就可以降低损失。这就是经典的优化算法:梯度下降法

2.2 补充

关于损失和优化的概念,大家可能还是有些模糊。上面损失只讲到了一个输入x值对应的模型输出与实际结果之间的差距,但使用整个数据集的平均差距可能更容易理解,就像中学的线性回归

图2所示,改变直线的斜率w,将改变直线与数据点的贴近程度,即改变了损失函数loss的值。

在这里插入图片描述
图2:损失与参数更新示意图

参考: 【深度学习】3-从模型到学习的思路整理_清风莫追的博客-CSDN博客

3 整理项目结构

我们的小项目的代码也渐渐多起来了,好的目录结构将使它更加易于扩展。关于python包结构的知识大家可以自行去了解,大致目录结构如下:

- example
- ourdl- core- __init__.py- node.py- ops- __init__.py- loss.py- ops.py__init__.py

给这个简单框架的名字叫做OurDL,使用框架搭建的计算图等程序放在example目录下。在ourdl/core/node.py中存放了节点基类和变量类的定义,在ourdl/ops/下存放了运算节点的定义,包括损失函数和加法、乘法节点等。

4 损失函数的实现

/ourdl/ops/loss.py中,

from ..core import Nodeclass ValueLoss(Node):'''损失函数:作差取绝对值'''def compute(self):self.value = self.parent1.value - self.parent2.valueself.flag = self.value > 0if not self.flag:self.value = -self.valuedef get_parent_grad(self, parent):a = 1 if self.flag else -1b = 1 if parent == self.parent1 else -1return a * b

其中compute()方法很显然就是对两个输入作差取绝对值;get_parent_grad()方法求本节点关于父节点的梯度。有绝对值如何求梯度?大家可以画一画绝对值函数的图像。

5 修改节点类(Node)

ourdl/core/node.py

class Node:pass  # 省略了一些方法的定义,大家可以查看上一篇文章def clear(self):'''递归清除父节点的值和梯度信息'''self.grad = Noneif self.parent1 is not None:  # 清空非变量节点的值self.value = Nonefor parent in [self.parent1, self.parent2]:if parent is not None:parent.clear()def update(self, lr=0.001):'''根据本节点的梯度,更新本节点的值'''self.value -= lr * self.grad  # 减号表示梯度的反方向

我在节点类中新增了两个方法,其中clear()用于清除多余的节点值和梯度信息,因为当节点值或梯度已经存在时会直接返回结果而不会递归去求了(get_grad()forward()的代码)。update()有一个学习率参数lr,更新幅度太大可能导致参数值一直在目标值左右晃悠,无法收敛

6 自适应线性单元

/example/01_esay/自适应线性单元.py

import sys
sys.path.append('../..')
from ourdl.core import Varrible
from ourdl.ops import Mul, Add
from ourdl.ops.loss import ValueLossif __name__ == '__main__':# 搭建计算图x = Varrible()w = Varrible()mul = Mul(parent1=x, parent2=w)b = Varrible()add = Add(parent1=mul, parent2=b)label = Varrible()loss = ValueLoss(parent1=label, parent2=add)# 参数初始化w.set_value(0)b.set_value(0)# 生成训练数据import randomdata_x = [random.uniform(-10, 10) for i in range(10)]  # 按均匀分布生成[-10, 10]范围内的随机实数data_label = [2 * data_x_one + 1 for data_x_one in data_x]# 开始训练for i in range(len(data_x)):x.set_value(data_x[i])label.set_value(data_label[i])loss.forward()  # 前向传播 --> 求梯度会用到损失函数的值w.get_grad()b.get_grad()w.update(lr=0.05)b.update(lr=0.1)loss.clear()print("w:{:.2f}, b:{:.2f}".format(w.value, b.value))print("最终结果:{:.2f}x+{:.2f}".format(w.value, b.value))

运行结果:

w:0.13, b:0.10
w:0.36, b:0.20
w:0.58, b:0.10
w:0.74, b:0.00
w:1.13, b:0.10
w:1.43, b:0.20
w:1.62, b:0.30
w:1.94, b:0.20
w:1.50, b:0.30
w:1.87, b:0.40
最终结果:1.87x+0.40

上面自适应线性单元的训练,已经能够大致展现深度学习模型的训练流程:

  • 搭建模型 --> 初始化参数 --> 准备数据 --> 使用数据更新参数的值

我们这里参数只更新了10次,结果就已经大致接近了我们的目标函数 y = 2 x + 1 y=2x+1 y=2x+1。大家可以试试更改学习率lr,训练数据集的大小,观察运行结果会发生怎样的变化。(必备技能:调参)


下一篇:【一起撸个深度学习框架】6 折与曲的相会——激活函数

相关文章:

【一起撸个DL框架】5 实现:自适应线性单元

CSDN个人主页:清风莫追欢迎关注本专栏:《一起撸个DL框架》GitHub获取源码:https://github.com/flying-forever/OurDLblibli视频合集:https://space.bilibili.com/3493285974772098/channel/series 文章目录 5 实现:自适…...

开箱即用的工具函数库xijs更新指南(v1.2.6)

xijs 是一款开箱即用的 js 业务工具库, 聚集于解决业务中遇到的常用函数逻辑问题, 帮助开发者更高效的开展业务开发. 接下来就和大家一起分享一下 v1.2.6 版本的更新内容以及后续的更新方向. 贡献者列表: 1. 计算变量内存calculateMemory 该模块主要由 zhengsixsix 贡献, 我们可…...

【Netty】ChannelPipeline源码分析(五)

文章目录 前言一、ChannelPipeline 接口1.1 创建 ChannelPipeline1.2 ChannelPipeline 事件传输机制1.2.1 处理出站事件1.2.2 处理入站事件 二、ChannelPipeline 中的 ChannelHandler三、ChannelHandlerContext 接口3.1 ChannelHandlerContext 与其他组件的关系3.2 跳过某些 Ch…...

并行计算技术解密:MPI和OpenMP的学习和应用指南

欢迎来到并行计算技术的奇妙世界!本指南将带您深入了解MPI(Message Passing Interface)和OpenMP(Open Multi-Processing)两种重要的并行计算技术,并为您提供学习和应用的指南。无论您是一个科研工作者、开发…...

什么是自动化测试框架?我们该如何搭建自动化测试框架?

无论是在自动化测试实践,还是日常交流中,经常听到一个词:框架。之前学习自动化测试的过程中,一直对“框架”这个词知其然不知其所以然。 最近看了很多自动化相关的资料,加上自己的一些实践,算是对“框架”…...

Debezium报错处理系列之六十七:TopicAuthorizationException: Not authorized to access topics

Debezium报错处理系列之六十七:TopicAuthorizationException: Not authorized to access topics 一、完整报错二、错误原因三、解决方法Debezium报错处理系列一:The db history topic is missing. Debezium报错处理系列二:Make sure that the same history topic isn‘t sha…...

javaWebssh中小学课件资源系统myeclipse开发mysql数据库MVC模式java编程计算机网页设计

一、源码特点 java ssh中小学课件资源系统是一套完善的web设计系统(系统采用ssh框架进行设计开发),对理解JSP java编程开发语言有帮助,系统具有完整的源代码和数据库,系统主要采用 B/S模式开发。开发环境为TOMCAT…...

MySQL高级查询操作

文章目录 前言聚集函数分组查询:GROUP BY过滤:HAVING嵌套子查询比较运算中使用子查询带有IN的子查询SOME(子查询)ALL(子查询)EXISTS子查询 前言 查询语句书写顺序: 1、select 2、from 3、where 4、group by 5、having 6、order by 7、limit …...

Day53【动态规划】1143.最长公共子序列、1035.不相交的线、53.最大子序和

1143.最长公共子序列 力扣题目链接/文章讲解 视频讲解 本题最大的难点还是定义 dp 数组 本题和718.最长重复子数组区别在于这里不要求是连续的了,但要有相对顺序 直接动态规划五部曲! 1、确定 dp 数组下标及值含义 dp[i][j]:取 text1…...

Three.js--》实现3d地球模型展示

目录 项目搭建 实现网页简单布局 初始化three.js基础代码 创建环境背景 加载地球模型 实现光柱效果 添加月球模型 今天简单实现一个three.js的小Demo,加强自己对three知识的掌握与学习,只有在项目中才能灵活将所学知识运用起来,话不多…...

<SQL>《SQL命令(含例句)精心整理版(6)》

《SQL命令(含例句)精心整理版(6)》 18 DB2查询语句18.1 查询数据库大小18.2 查看表占表空间大小18.3 查看正在执行的语句18.4 db2expln 查看执行计划18.5 db2advis 查看优化建议 19 空值19.1 NULL19.2 TRIM 18 DB2查询语句 18.1 …...

信息系统建设和服务能力评估证书CS

信息系统建设和服务能力评估体系CS简介 简介:本标准(团标T/CITIF 001-2019)是信息系统建设和服务能力评估体系系列标准的第一个,提出了对信息系统建设和服务提供者的综合能力要求。 发证单位:中国电子信息行业联合会。…...

vue3引入路由

1.首先在项目中安装路由 npm install vue-router -S 2.src文件夹下新建》views文件夹》新建home文件夹》新建Home.vue文件 在src文件夹下》新建router文件夹》新建index.js import { createRouter,createWebHashHistory } from vue-router const route s[ { path:/, compo…...

前后端联调跨域问题

文章目录 什么是同源策略如何判断是否同源?跨域资源共享(CORS)如何解决跨域问题 什么是同源策略 同源策略限制了从同一个源加载的文档或脚本如何与来自另一个源的资源进行交互。这是一个用于隔离潜在恶意文件的重要安全机制。 如何判断是否同源? 如果…...

day11 - 手写数字笔迹细化

手写数字笔迹细化 对于手写数字识别实验中,经常会遇到因为笔迹较粗导致误识别的情况,所以我们通常会先将笔迹进行细化,笔迹变细以后,数字的特征会更明显,后续进行识别的准确率就会更高。 例如数字7 和 1 &#xff0c…...

C++ QT QDBus基操

以下是使用QDBus进行跨进程通信的具体用法&#xff1a; 1. 创建DBus服务 在服务端进程中&#xff0c;需要创建一个DBus服务&#xff0c;并注册DBus对象。示例代码如下&#xff1a; #include <QDBusConnection> #include <QDBusMessage> #include <QDBusInterf…...

STM32的SPI外设

文章目录 1. STM32 的 SPI 外设简介2. STM32 的 SPI 架构剖析2.1 通讯引脚2.2 时钟控制逻辑2.3 数据控制逻辑2.4 整体控制逻辑 3. 通讯过程4. SPI 初始化结构体详解 1. STM32 的 SPI 外设简介 STM32 的 SPI 外设可用作通讯的主机及从机&#xff0c;支持最高的 SCK 时钟频率为 …...

VMWare ESXI6.7创建虚拟机

VMware ESXi&#xff1a;专门构建的裸机 管理程序 首先开启ESXI主机 登录ESXI 打开浏览器输入物理机ip&#xff0c;输入账号密码进行登录 创建虚拟机 选择创建类型 创建RedHat7.6 选择存储类型和数据存储 仅一个存储&#xff0c;直接点下一页即可 配置虚拟机硬件和虚拟机附…...

TensorFlow 1.x学习(系列二 :4):自实现线性回归

目录 线性回归基本介绍常用的op自实现线性回归预测tensorflow 变量作用域模型的保存和加载 线性回归基本介绍 线性回归&#xff1a; w 1 ∗ x 1 w 2 ∗ x 2 w 3 ∗ x 3 . . . w n ∗ x n b i a s w_1 * x_1 w_2 * x_2 w_3 * x_3 ... w_n * x_n bias w1​∗x1​w2​∗…...

Openwrt折腾记6-网络摄像头

前言&#xff1a; 前几天买了个电视机上的摄像头&#xff0c;但是估计是电视配置或软件不好&#xff0c;视频通话太卡顿。今天把它装的极路由4的usb上了。由于当初挑的是电视免驱的&#xff0c;所以我猜想是通用的芯片。 调查驱动 LINUX uvc支持型号的列表里 http://www.ide…...

从零实现富文本编辑器#5-编辑器选区模型的状态结构表达

先前我们总结了浏览器选区模型的交互策略&#xff0c;并且实现了基本的选区操作&#xff0c;还调研了自绘选区的实现。那么相对的&#xff0c;我们还需要设计编辑器的选区表达&#xff0c;也可以称为模型选区。编辑器中应用变更时的操作范围&#xff0c;就是以模型选区为基准来…...

前端倒计时误差!

提示:记录工作中遇到的需求及解决办法 文章目录 前言一、误差从何而来?二、五大解决方案1. 动态校准法(基础版)2. Web Worker 计时3. 服务器时间同步4. Performance API 高精度计时5. 页面可见性API优化三、生产环境最佳实践四、终极解决方案架构前言 前几天听说公司某个项…...

【HarmonyOS 5.0】DevEco Testing:鸿蒙应用质量保障的终极武器

——全方位测试解决方案与代码实战 一、工具定位与核心能力 DevEco Testing是HarmonyOS官方推出的​​一体化测试平台​​&#xff0c;覆盖应用全生命周期测试需求&#xff0c;主要提供五大核心能力&#xff1a; ​​测试类型​​​​检测目标​​​​关键指标​​功能体验基…...

大语言模型如何处理长文本?常用文本分割技术详解

为什么需要文本分割? 引言:为什么需要文本分割?一、基础文本分割方法1. 按段落分割(Paragraph Splitting)2. 按句子分割(Sentence Splitting)二、高级文本分割策略3. 重叠分割(Sliding Window)4. 递归分割(Recursive Splitting)三、生产级工具推荐5. 使用LangChain的…...

MMaDA: Multimodal Large Diffusion Language Models

CODE &#xff1a; https://github.com/Gen-Verse/MMaDA Abstract 我们介绍了一种新型的多模态扩散基础模型MMaDA&#xff0c;它被设计用于在文本推理、多模态理解和文本到图像生成等不同领域实现卓越的性能。该方法的特点是三个关键创新:(i) MMaDA采用统一的扩散架构&#xf…...

【决胜公务员考试】求职OMG——见面课测验1

2025最新版&#xff01;&#xff01;&#xff01;6.8截至答题&#xff0c;大家注意呀&#xff01; 博主码字不易点个关注吧,祝期末顺利~~ 1.单选题(2分) 下列说法错误的是:&#xff08; B &#xff09; A.选调生属于公务员系统 B.公务员属于事业编 C.选调生有基层锻炼的要求 D…...

UR 协作机器人「三剑客」:精密轻量担当(UR7e)、全能协作主力(UR12e)、重型任务专家(UR15)

UR协作机器人正以其卓越性能在现代制造业自动化中扮演重要角色。UR7e、UR12e和UR15通过创新技术和精准设计满足了不同行业的多样化需求。其中&#xff0c;UR15以其速度、精度及人工智能准备能力成为自动化领域的重要突破。UR7e和UR12e则在负载规格和市场定位上不断优化&#xf…...

(转)什么是DockerCompose?它有什么作用?

一、什么是DockerCompose? DockerCompose可以基于Compose文件帮我们快速的部署分布式应用&#xff0c;而无需手动一个个创建和运行容器。 Compose文件是一个文本文件&#xff0c;通过指令定义集群中的每个容器如何运行。 DockerCompose就是把DockerFile转换成指令去运行。 …...

代理篇12|深入理解 Vite中的Proxy接口代理配置

在前端开发中,常常会遇到 跨域请求接口 的情况。为了解决这个问题,Vite 和 Webpack 都提供了 proxy 代理功能,用于将本地开发请求转发到后端服务器。 什么是代理(proxy)? 代理是在开发过程中,前端项目通过开发服务器,将指定的请求“转发”到真实的后端服务器,从而绕…...

学校时钟系统,标准考场时钟系统,AI亮相2025高考,赛思时钟系统为教育公平筑起“精准防线”

2025年#高考 将在近日拉开帷幕&#xff0c;#AI 监考一度冲上热搜。当AI深度融入高考&#xff0c;#时间同步 不再是辅助功能&#xff0c;而是决定AI监考系统成败的“生命线”。 AI亮相2025高考&#xff0c;40种异常行为0.5秒精准识别 2025年高考即将拉开帷幕&#xff0c;江西、…...