当前位置: 首页 > news >正文

梯度下降法、模拟训练、拟合二次曲线、最小二乘法、MSELoss、拟合:f(x)=ax^2+bx+c

本文目标:

f(x)=a*x^2+b*x+c

以这个公式为例,设计一个算法,用梯度下降法来模拟训练过程,最终得出参数a,b,c

原理介绍

目标函数:h_{\theta}(x) = a^{2}+bx+c

损失函数:Loss=\frac{1}{2m}\sum_{1}^{m}(h_{\theta}(x^{i})-y^{i})^2,就是mse

损失函数展开:Loss=\frac{1}{2m}\sum_{1}^{m}((ax^{i})^{2}+bx^i+c-y^{i})^2

损失函数对a,b,c求导数:

{L_{a}}^{'}=\frac{1}{m}\sum_{1}^{m}(ax^2+bx+c-y)*x^2

{L_{b}}^{'}=\frac{1}{m}\sum_{1}^{m}(ax^2+bx+c-y)*x

{L_{c}}^{'}=\frac{1}{m}\sum_{1}^{m}(ax^2+bx+c-y)

导数就是梯度,也就是目标参数与当前参数的差异,这个差异需要用梯度下降法更新

\Delta a{L_{a}}^{'}   \Delta b={L_{b}}^{'}     \Delta c={L_{c}}^{'}

a = a - lr*\Delta a

b = b - lr*\Delta b

c = c - lr*\Delta c

重复上面的过程,参数就可以更新了,然后可以看看新参数的效果,也就是损失有没有降低

具体流程

  1. 预设模型的表达式为:f(x)=a*x^2+b*x+c,也就是二次函数。同时随机初始化模型参数a,b,c。如果是其他函数如f(x)=ax^3+bx^2+cx,就无法在本版本适用(修改求导方式后才可用)。即本模型需要提前知道模型的表达式。
  2. 通过不断喂入(x_input,y_true),得出y_{out} = ax_{input}^{2}+bx_{input}+c.而y_out与y_true之间具有差异。
  3. 将差异封装成一个loss函数,并分别对a,b,c进行求导。得到a,b,c的梯度\Delta a\Delta b\Delta c
  4. \Delta a\Delta b\Delta c和原始的参数a,b,c和学习率作为输入,用梯度下降法来对a,b,c参数进行更新.
  5. 重复2,3,4过程。直到训练结束或者loss降低到较小值

python实现

  # 初始化a,b,c为:-11/6 , -395/3,-2400  目标a,b,c为:(2,-4,3)

class QuadraticFunc():def drew(self,w,name="show"):a,b,c = wx1 = np.array(range(-80,80))y1 = a*x1*x1 + b*x1 + cy2 = 2*x1*x1 - 4*x1 + 3plt.clf()plt.plot(x1, y1)plt.plot(x1, y2)plt.scatter(x1, y1, c='r')# set colorplt.xlim((-50,50))plt.ylim((-500,500))plt.xlabel('X Axis')plt.ylabel('Y Axis')if name == "first":plt.pause(3)else:plt.pause(0.01)plt.ioff()#计算lossdef cal_loss(self,y_out,y_true):# return np.dot((y_out - y_true),(y_out - y_true)) * 0.5return np.mean((y_out - y_true)**2)#计算梯度  def cal_grad(self,x,y_out,y_true):# x(batch),y_out(batch),y_true(batch)a_grad = (y_out-y_true)*x**2 #b_grad = (y_out-y_true)*xc_grad = (y_out-y_true)return np.array([np.mean(a_grad),np.mean(b_grad),np.mean(c_grad)])        #梯度下降法更新参数def update_theta(self,step,w,grad):new_w = w - step*gradreturn new_wdef run(self):feed_x = np.array(range(-400,400))/400feed_y = 2*feed_x*feed_x - 4*feed_x + 3step = 0.5base_lr = 0.5lr = base_lr# 初始化参数a,b,c = -11/6 , -395/3,-2400#-1,10,26w = np.array([a,b,c])self.drew(w,"first")epochs = 100for epoch in range(epochs):# 每隔10轮 降低一半的学习率lr = base_lr/(2**(int((epoch+1)/10)))for i in range(len(feed_x)):x_input = feed_x[i]y_true = feed_y[i]y_out = w[0]*x_input*x_input +w[1]*x_input + w[2]#计算lossloss = self.cal_loss(y_out,y_true)#计算梯度grad = self.cal_grad(x_input,y_out,y_true)#更新参数,梯度下降w = self.update_theta(lr,w,grad)# self.drew(w)grad = np.round(grad,2)loss = np.round(loss,2)w = np.round(w,2)print("train times is:",epoch,"  grad is:",grad,"   loss is:","%.4f"%loss, "  w is:",w,"\n")self.drew(w)if loss<1e-5:print("train finish:",w)breakdef run_batch(self):feed_x = np.array(range(-400,400))/400feed_y = 2*feed_x*feed_x - 4*feed_x + 3x_y = [[feed_x[i],feed_y[i]] for i in range(len(feed_x))]base_lr = 0.5lr = base_lr# 初始化参数a,b,c = -11/6 , -395/3,-2400#-1,10,26w = np.array([a,b,c])self.drew(w,"first")batch_size = 16data_len = len(x_y)//batch_sizeepochs = 100for epoch in range(epochs):random.shuffle(x_y)# 每隔10轮 降低一半的学习率lr = base_lr/(2**(int((epoch+1)/10)))print("epoch,lr:",epoch,lr)for i in range(data_len):x_y_list = x_y[i*batch_size:(i+1)*batch_size]x_y_np = np.array(x_y_list)x_input = x_y_np[:,0]y_true = x_y_np[:,1]y_out = w[0]*x_input*x_input +w[1]*x_input + w[2]#计算lossloss = self.cal_loss(y_out,y_true)#计算梯度grad = self.cal_grad(x_input,y_out,y_true)#更新参数,梯度下降w = self.update_theta(lr,w,grad)grad = np.round(grad,2)loss = np.round(loss,2)w = np.round(w,2)print("train times is:",epoch,"  grad is:",grad,"   loss is:","%.4f"%loss, "  w is:",w,"\n")self.drew(w)if loss<1e-5:print("train finish:",w)# breaktime.sleep(0.1)if __name__ == "__main__":qf = QuadraticFunc()qf.run()

相关文章:

梯度下降法、模拟训练、拟合二次曲线、最小二乘法、MSELoss、拟合:f(x)=ax^2+bx+c

本文目标&#xff1a; 以这个公式为例&#xff0c;设计一个算法&#xff0c;用梯度下降法来模拟训练过程&#xff0c;最终得出参数a,b,c 原理介绍 目标函数&#xff1a; 损失函数&#xff1a;&#xff0c;就是mse 损失函数展开&#xff1a; 损失函数对a,b,c求导数: 导数就是梯度…...

Web3.0投票如何做到公平公正且不泄露个人隐私

在当前的数字时代&#xff0c;社交平台举办投票活动已成为了一种普遍现象。然而&#xff0c;随之而来的是一些隐私和安全方面的顾虑&#xff0c;特别是关于个人信息泄露和电话骚扰的问题。期望建立一个既公平公正又能保护个人隐私的投票系统。Web3.0的出现为实现这一目标提供了…...

灰度图像的自动阈值分割

第一种&#xff1a;Otsu &#xff08;大津法&#xff09; 一、基于cv2的API调用 1、代码实现 直接给出相关代码&#xff1a; import cv2 import matplotlib.pylab as pltpath r"D:\Desktop\00aa\1.png" img cv2.imread(path, 0)def main2():ret, thresh1 cv2.…...

利用Maven获取jar包

我有一个习惯&#xff0c;就是程序不在线依赖网络的任何包。以前用C#时候虽然用Nuget找包&#xff0c;但是添加引用后又马上把Nuget引用删了&#xff0c;再把Nuget下载的dll拷贝到工程再引用dll。 这样做的好处是&#xff1a; 1.别人得到程序代码可以直接编译&#xff0c;不用…...

将vue组件发布成npm包

文章目录 前言一、环境准备1.首先最基本的需要安装nodejs&#xff0c;版本推荐 v10 以上&#xff0c;因为需要安装vue-cli2.安装vue-cli 二、初始化项目1.构建项目2.开发组件/加入组件3. 修改配置文件 三、调试1、执行打包命令2、发布本地连接包3、测试项目 四、发布使用1、注册…...

江科大STM32 中

目录 6、TIM&#xff08;Timer&#xff09;定时器基本定时器通用定时器高级定时器示例程序&#xff08;定时器定时中断&定时器外部时钟&#xff09;TIM输出比较示例程序&#xff08;PWM驱动LED呼吸灯&PWM驱动舵机&PWM驱动直流电机&#xff09;TIM输入捕获示例程序&…...

vue+draggable+el-upload上传图片拖拽重排方法

vuedraggableel-upload上传图片拖拽重排方法 1.html <el-row><el-col><el-form-item label"添加视频/图片" prop"device_id"><div class"image-upload"><draggable v-model"fileList" update"dataDr…...

微信的新版canvas绘制的图案发生变形和偏移的问题

一,现象 this.context.beginPath(); this.context.moveTo(10, 10); this.context.lineTo(10, 100); this.context.lineTo(100, 100); this.context.lineTo(100, 10); this.context.lineTo(10, 10); this.context.stroke();本来绘制的是正方形,结果绘制出来是个矩形,边的宽度也…...

[ACM学习] 进制转换

进制的本质 本质是每一位的数位上的数字乘上这一位的权重 将任意进制转换为十进制 原来还很疑惑为什么从高位开始&#xff0c;原来从高位开始的&#xff0c;可以被滚动地乘很多遍。 将十进制转换为任意进制...

redis + 拦截器 :防止数据重复提交

1.项目用到,不是核心 我们干系统开发,不免要考虑一个点&#xff0c;数据的重复提交。 我想我们之前如果要校验数据重复提交要求&#xff0c;会怎么干?会在业务层&#xff0c;对数据库操作&#xff0c;查询数据是否存在,存在就禁止插入数据; 但是吧,我们每次crud操作都会连接…...

如何进行H.265视频播放器EasyPlayer.js的中性化设置?

H5无插件流媒体播放器EasyPlayer属于一款高效、精炼、稳定且免费的流媒体播放器&#xff0c;可支持多种流媒体协议播放&#xff0c;可支持H.264与H.265编码格式&#xff0c;性能稳定、播放流畅&#xff0c;能支持WebSocket-FLV、HTTP-FLV&#xff0c;HLS&#xff08;m3u8&#…...

Ubuntu22.04安装4090显卡驱动

1、安装完Ubuntu系统&#xff0c;打完所有补丁后再进行后续操作 2、下载系统所需要的版本的NV显卡驱动&#xff0c;本次由于使用CUDA12.1&#xff0c;故选用的驱动版本为NVIDIA-Linux-x86_64-530.41.03.run 3、卸载NV驱动&#xff08;只是保险起见&#xff0c;并不是一定会卸…...

YOLOv8优化策略:注意力涨点系列篇 | 一种轻量级的加强通道信息和空间信息提取能力的MLCA注意力

🚀🚀🚀本文改进:一种轻量级的加强通道信息和空间信息提取能力 MLCA注意力 🚀🚀🚀在YOLOv8中如何使用 1)作为注意力机制使用;2)与c2f结合使用; 🚀🚀🚀YOLOv8改进专栏:http://t.csdnimg.cn/hGhVK 学姐带你学习YOLOv8,从入门到创新,轻轻松松搞定科研…...

【新书推荐】2.5节 有符号整数和无符号整数

本节内容&#xff1a;整数的编码规则。 ■数据的编码规则&#xff1a;计算机的二进制数对于计算机本身而言仅仅表示0和1。人们按照不同的编码规则赋予二进制数不同的含义。整数的编码规则分为有符号整数和无符号整数。 ■数据的存储规则&#xff1a;x86计算机以字节为单位&…...

RT-Thread: 串口操作、增加串口、串口函数

说明&#xff1a;本文记录RT-Thread添加串口的步骤和串口的使用。 1.新增串口 官方链接&#xff1a;https://www.rt-thread.org/document/site/rtthread-studio/drivers/uart/v4.0.2/rtthread-studio-uart-v4.0.2/ 新增串口只需要在 board.h 文件中定义相关串口的宏定…...

自然语言处理的新突破:如何推动语音助手和机器翻译的进步

一、语音助手方面的进展 语音助手作为人机交互的重要入口之一,其性能的提升离不开自然语言处理技术的进步。基于深度学习的语音识别和语义理解技术,使得语音助手可以更准确地分析用户意图,提供个性化服务。 语音识别精度的持续提高 语音识别是语音助手的基础。随着深度神经网…...

vue3 + jeecgBoot 获取项目IP地址

封装的useGlobSetting 函数 引入并使用 import { useGlobSetting } from //hooks/setting;const glob useGlobSetting();console.log(glob.uploadUrl) //http://192.168.105.57:7900/bs-axfd...

Java Server-Sent Events通信

Server-Sent Events特点与优势 后端可以向前端发送信息&#xff0c;类似于websocket&#xff0c;但是websocket是双向通信&#xff0c;但是sse为单向通信&#xff0c;服务器只能向客户端发送文本信息&#xff0c;效率比websocket高。 单向通信&#xff1a;SSE只支持服务器到客…...

[蓝桥杯]真题讲解:冶炼金属(暴力+二分)

蓝桥杯真题视频讲解&#xff1a;冶炼金属&#xff08;暴力做法与二分做法&#xff09; 一、视频讲解二、暴力代码三、正解代码 一、视频讲解 视频讲解 二、暴力代码 //暴力代码 #include<bits/stdc.h> #define endl \n #define deb(x) cout << #x << &qu…...

Fastbee开源物联网项目RoadMap

架构优化 代码简化业务&协议解耦关键组件支持横向拓展网络协议支持横向拓展&#xff0c;包括&#xff1a;mqtt broker,tcp,coap,udp,sip等协议插件化编码脚本化业务代码模版化消息总线 功能优化 网关/子网关&#xff1a;上线&#xff0c;绑定&#xff0c;拓扑&#xff0…...

MPNet:旋转机械轻量化故障诊断模型详解python代码复现

目录 一、问题背景与挑战 二、MPNet核心架构 2.1 多分支特征融合模块(MBFM) 2.2 残差注意力金字塔模块(RAPM) 2.2.1 空间金字塔注意力(SPA) 2.2.2 金字塔残差块(PRBlock) 2.3 分类器设计 三、关键技术突破 3.1 多尺度特征融合 3.2 轻量化设计策略 3.3 抗噪声…...

《Qt C++ 与 OpenCV:解锁视频播放程序设计的奥秘》

引言:探索视频播放程序设计之旅 在当今数字化时代,多媒体应用已渗透到我们生活的方方面面,从日常的视频娱乐到专业的视频监控、视频会议系统,视频播放程序作为多媒体应用的核心组成部分,扮演着至关重要的角色。无论是在个人电脑、移动设备还是智能电视等平台上,用户都期望…...

UDP(Echoserver)

网络命令 Ping 命令 检测网络是否连通 使用方法: ping -c 次数 网址ping -c 3 www.baidu.comnetstat 命令 netstat 是一个用来查看网络状态的重要工具. 语法&#xff1a;netstat [选项] 功能&#xff1a;查看网络状态 常用选项&#xff1a; n 拒绝显示别名&#…...

稳定币的深度剖析与展望

一、引言 在当今数字化浪潮席卷全球的时代&#xff0c;加密货币作为一种新兴的金融现象&#xff0c;正以前所未有的速度改变着我们对传统货币和金融体系的认知。然而&#xff0c;加密货币市场的高度波动性却成为了其广泛应用和普及的一大障碍。在这样的背景下&#xff0c;稳定…...

Python Ovito统计金刚石结构数量

大家好,我是小马老师。 本文介绍python ovito方法统计金刚石结构的方法。 Ovito Identify diamond structure命令可以识别和统计金刚石结构,但是无法直接输出结构的变化情况。 本文使用python调用ovito包的方法,可以持续统计各步的金刚石结构,具体代码如下: from ovito…...

免费PDF转图片工具

免费PDF转图片工具 一款简单易用的PDF转图片工具&#xff0c;可以将PDF文件快速转换为高质量PNG图片。无需安装复杂的软件&#xff0c;也不需要在线上传文件&#xff0c;保护您的隐私。 工具截图 主要特点 &#x1f680; 快速转换&#xff1a;本地转换&#xff0c;无需等待上…...

C++.OpenGL (20/64)混合(Blending)

混合(Blending) 透明效果核心原理 #mermaid-svg-SWG0UzVfJms7Sm3e {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-SWG0UzVfJms7Sm3e .error-icon{fill:#552222;}#mermaid-svg-SWG0UzVfJms7Sm3e .error-text{fill…...

Chromium 136 编译指南 Windows篇:depot_tools 配置与源码获取(二)

引言 工欲善其事&#xff0c;必先利其器。在完成了 Visual Studio 2022 和 Windows SDK 的安装后&#xff0c;我们即将接触到 Chromium 开发生态中最核心的工具——depot_tools。这个由 Google 精心打造的工具集&#xff0c;就像是连接开发者与 Chromium 庞大代码库的智能桥梁…...

6.9-QT模拟计算器

源码: 头文件: widget.h #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include <QMouseEvent>QT_BEGIN_NAMESPACE namespace Ui { class Widget; } QT_END_NAMESPACEclass Widget : public QWidget {Q_OBJECTpublic:Widget(QWidget *parent nullptr);…...

数据结构第5章:树和二叉树完全指南(自整理详细图文笔记)

名人说&#xff1a;莫道桑榆晚&#xff0c;为霞尚满天。——刘禹锡&#xff08;刘梦得&#xff0c;诗豪&#xff09; 原创笔记&#xff1a;Code_流苏(CSDN)&#xff08;一个喜欢古诗词和编程的Coder&#x1f60a;&#xff09; 上一篇&#xff1a;《数据结构第4章 数组和广义表》…...