当前位置: 首页 > article >正文

神经网络常见层Numpy封装参考(4):优化器

目录前置层优化器SGD优化器Adam优化器测试演示完整代码下载 神经网络常见层Numpy封装参考 - 常见层前置层- 神经网络常见层Numpy封装参考1损失层- 神经网络常见层Numpy封装参考2线性层- 神经网络常见层Numpy封装参考3激活层优化器SGD优化器classSGD:SGD优化器支持动量和权重衰减def__init__(self,params:List[Parameter],lr:float0.01,momentum:float0.0,weight_decay:float0.0,nesterov:boolFalse): Args: params: 参数列表 lr: 学习率 momentum: 动量因子 weight_decay: 权重衰减系数 nesterov: 是否使用Nesterov动量 self.paramsparams self.lrlr self.momentummomentum self.weight_decayweight_decay self.nesterovnesterov self.state{}# 存储每个参数的状态如动量缓冲区defstep(self):更新参数foridx,paraminenumerate(self.params):ifparam.gradisNone:continue# 获取梯度gradparam.grad.copy()# 权重衰减ifself.weight_decay!0:gradgradself.weight_decay*param.data# 获取或创建动量缓冲区ifself.momentum!0:ifidxnotinself.state:self.state[idx]{momentum_buffer:np.zeros_like(param.data)}momentum_bufferself.state[idx][momentum_buffer]# 更新动量缓冲区momentum_bufferself.momentum*momentum_buffer-self.lr*gradifself.nesterov:# Nesterov动量: 使用校正后的梯度param.dataself.momentum*momentum_buffer-self.lr*gradelse:param.datamomentum_buffer self.state[idx][momentum_buffer]momentum_bufferelse:# 标准SGDparam.data-self.lr*graddefzero_grad(self):forparaminself.params:param.zero_grad()def__repr__(self):returnself.__class__.__name__()Adam优化器classAdam:Adam优化器支持权重衰减def__init__(self,params:List[Parameter],lr:float0.001,betas:tuple(0.9,0.999),eps:float1e-8,weight_decay:float0.0): Args: params: 参数列表 lr: 学习率 betas: 动量衰减系数 (beta1, beta2) eps: 数值稳定项 weight_decay: 权重衰减系数 self.paramsparams self.lrlr self.betasbetas self.epseps self.weight_decayweight_decay self.state{}# 存储每个参数的状态一阶动量、二阶动量、时间步self.t0# 全局时间步用于偏差校正defstep(self):更新参数self.t1# 每调用一次 step时间步加 1beta1,beta2self.betasforidx,paraminenumerate(self.params):ifparam.gradisNone:continue# 获取梯度gradparam.grad.copy()# 权重衰减L2 正则化ifself.weight_decay!0:gradgradself.weight_decay*param.data# 初始化状态如果尚未初始化ifidxnotinself.state:self.state[idx]{m:np.zeros_like(param.data),# 一阶动量v:np.zeros_like(param.data)# 二阶动量}# 取出动量缓冲区mself.state[idx][m]vself.state[idx][v]# 更新一阶动量带指数衰减mbeta1*m(1-beta1)*grad# 更新二阶动量带指数衰减vbeta2*v(1-beta2)*(grad**2)# 偏差校正m_hatm/(1-beta1**self.t)v_hatv/(1-beta2**self.t)# 更新参数param.data-self.lr*m_hat/(np.sqrt(v_hat)self.eps)# 保存更新后的动量self.state[idx][m]m self.state[idx][v]vdefzero_grad(self):清零所有参数的梯度forparaminself.params:param.zero_grad()def__repr__(self):返回优化器的字符串表示returnself.__class__.__name__()测试%matplotlib inline# 数据点满足非线性关系xnp.linspace(0,1,500).reshape(-1,1)y2*x**30.25*np.random.randn(500).reshape(-1,1)plt.scatter(x,y)plt.show()​​# 定义模型和损失函数modelSequential(Linear(1,10),Tanh(),Linear(10,1),)criterionMSELoss()# 定义优化器optimizer1SGD(paramsmodel.parameters(),lr1e-2,momentum0.8,weight_decay0)optimizer2Adam(paramsmodel.parameters(),lr1e-2)# 训练流程forepochinrange(1000):y_predmodel(x)loss,gradcriterion(y_pred,y)model.backward(grad)optimizer2.step()optimizer2.zero_grad()if(epoch%2000):print(loss)1.1548328265032672 0.11623908399811281 0.10961048087231211 0.07929787391344985 0.06122113711571879%matplotlib inline# 检验拟合效果x_plotnp.linspace(0,1.2,50).reshape(-1,1)plt.plot(x_plot,model(x_plot),cr)plt.scatter(x,y)plt.show()​演示frommpl_toolkits.mplot3dimportAxes3Dimportcopyfromtqdmimporttqdmclassgrad_descent_demo:def__init__(self,optimizers:List,surface_type:strlinear,init_pos:int0): Args: optimizers: 传入需要模拟的优化器列表定义时参数设为空数组即可 surface_type: 3D表面类型可选linear和unlinear init_pos: 路径点初始位置可选0~3 self.optimizersoptimizers# 使用MSE损失self.criterionMSELoss()# 最优解self.target_w118self.target_w2-8# 100个二维数据self.Xnp.random.randn(100,2)# 线性目标值self.y(self.target_w1*self.X[:,0]self.target_w2*self.X[:,1]).reshape(-1,1)# 线性层if(surface_typelinear):modelSequential(Linear(2,1,biasFalse),)# 模拟前向传播过程self.forward_fnlambdaw1,w2:(w1*self.X[:,0]w2*self.X[:,1]).reshape(-1,1)# 线性层Tanh层elif(surface_typeunlinear):modelSequential(Linear(2,1,biasFalse),Tanh(),)self.forward_fnlambdaw1,w2:np.tanh(w1*self.X[:,0]w2*self.X[:,1]).reshape(-1,1)else:raiseValueError(Unsupported surface type!)# 每个优化器独立创建一个模型self.models[copy.deepcopy(model)foroptiminself.optimizers]# 为每个优化器关联对应的模型fori,optiminenumerate(self.optimizers):optim.paramsself.models[i].parameters()# 最优损失self.target_loss,_criterion(self.forward_fn(self.target_w1,self.target_w2),self.y)# 初始位置if(init_pos0):passelif(init_pos1):formodelinself.models:model[0].weight.datanp.array([[-5.0,7.0]])elif(init_pos2):formodelinself.models:model[0].weight.datanp.array([[-8.0,-20.0]])elif(init_pos3):formodelinself.models:model[0].weight.datanp.array([[2.0,-10.0]])else:raiseValueError(Unidentified initial position!)# 预计算损失曲面网格w1_rangenp.linspace(-15,50,25)w2_rangenp.linspace(-30,15,25)self.W1,self.W2np.meshgrid(w1_range,w2_range)self.Znp.zeros_like(self.W1)# 取出网格中的每个数据点计算损失foriinrange(self.W1.shape[0]):forjinrange(self.W1.shape[1]):y_predself.forward_fn(self.W1[i,j],self.W2[i,j])loss,_self.criterion(y_pred,self.y)self.Z[i,j]loss# 模拟阶段用于储存路径点self.w1_dict{}self.w2_dict{}self.loss_dict{}defstep(self,epoches:int60): 模拟梯度下降产生路径点 self.epochesepoches# 遍历每个模型和优化器fori,(model,optimizer)inenumerate(zip(self.models,self.optimizers)):w1_arr[]w2_arr[]loss_arr[]forepochintqdm(range(epoches)):# 训练流程y_predmodel(self.X)loss,gradself.criterion(y_pred,self.y)model.backward(grad)optimizer.step()optimizer.zero_grad()# 储存w1_arr.append(model[0].weight.data[0][0].copy())w2_arr.append(model[0].weight.data[0][1].copy())loss_arr.append(loss)# 每个索引储存一个优化器和对应模型的路径数据self.w1_dict[i]w1_arr self.w2_dict[i]w2_arr self.loss_dict[i]loss_arr# 绘制模拟图defplot(self): 绘制路径点动画 if(self.w1_dict{}):raiseValueError(Please run step() first!)# 开启交互模式plt.ion()# 创建固定图形对象figplt.figure(figsize(6,6))axfig.add_subplot(111,projection3d)# 绘制损失曲面ax.plot_surface(self.W1,self.W2,self.Z,cmapcoolwarm,alpha0.8,linewidth0)# 绘制最低点ax.scatter(self.target_w1,self.target_w2,self.target_loss,cg,s2)# 绘制文字说明ax.set_xlabel(w1)ax.set_ylabel(w2)ax.set_zlabel(MSE Loss)plt.show()# 为每个优化器创建一个路径点和路径线points[]lines[]colors[red,blue,green,orange,purple,brown,pink,gray]foridx,(w1_key,w2_key,loss_key)inenumerate(zip(self.w1_dict.keys(),self.w2_dict.keys(),self.loss_dict.keys())):# 图例说明直接获取优化器类名labelstr(self.optimizers[idx])[:-2]# 循环使用颜色列表colorcolors[idx%len(colors)]# 路径点和路径线point,ax.plot(self.w1_dict[w1_key][0],self.w2_dict[w2_key][0],self.loss_dict[loss_key][0],o,markersize4,colorcolor)line,ax.plot(self.w1_dict[w1_key][0],self.w2_dict[w2_key][0],self.loss_dict[loss_key][0],colorcolor,labellabel)# 储存points.append(point)lines.append(line)# 添加图例axplt.gca()ax.legend()# 更新运动轨迹forepochinrange(self.epoches):# 更新标题ax.set_title(fCurrent epoch:{epoch1}/{self.epoches},fontsize12)# 取出对应的路径数据和路径点forw1_key,w2_key,loss_key,point,lineinzip(self.w1_dict.keys(),self.w2_dict.keys(),self.loss_dict.keys(),points,lines):# 更新路径点point.set_xdata([self.w1_dict[w1_key][epoch]])point.set_ydata([self.w2_dict[w2_key][epoch]])point.set_3d_properties([self.loss_dict[loss_key][epoch]])# 更新路径线line.set_xdata(self.w1_dict[w1_key][:epoch])line.set_ydata(self.w2_dict[w2_key][:epoch])line.set_3d_properties(self.loss_dict[loss_key][:epoch])# 重新绘图plt.draw()# 暂停以实现动画效果plt.pause(0.05)# plt.savefig(imgstr(epoch).png)# 关闭交互模式plt.ioff()%matplotlib tk# 模拟三种优化器optimizer1SGD(params[],lr1e-1,momentum0,weight_decay0.6)optimizer2SGD(params[],lr1e-1,momentum0.92,weight_decay0)optimizer3Adam(params[],lr1,betas(0.9,0.99))# 初始化gddgrad_descent_demo([optimizer1,optimizer2,optimizer3],surface_typeunlinear,init_pos2)# 模拟gdd.step(epoches60)# 绘图gdd.plot()100%|███████████████████████████████████████████████████████████████████████████████| 60/60 [00:0000:00, 10899.49it/s] 100%|███████████████████████████████████████████████████████████████████████████████| 60/60 [00:0000:00, 13318.77it/s] 100%|███████████████████████████████████████████████████████████████████████████████| 60/60 [00:0000:00, 30009.33it/s]下一篇 - 神经网络常见层Numpy封装参考5其他层

相关文章:

神经网络常见层Numpy封装参考(4):优化器

目录前置层优化器SGD优化器Adam优化器测试演示完整代码下载 :神经网络常见层Numpy封装参考 - 常见层 前置层 - 神经网络常见层Numpy封装参考(1):损失层 - 神经网络常见层Numpy封装参考(2):线性…...

别再死磕PID了!用Python+MPC给机械臂做个‘未来视’控制器(附ROS2实战代码)

用PythonMPC为机械臂打造预测未来能力的智能控制器 机械臂控制领域正在经历一场静默革命——当大多数工程师还在用PID控制器解决90%的基础问题时,前沿实验室和科技公司早已将目光转向了更具前瞻性的控制策略。想象一下,如果你的控制器不仅能对当前误差做…...

如何快速解决Blender与3D打印机兼容问题:完整Blender3mfFormat使用指南

如何快速解决Blender与3D打印机兼容问题:完整Blender3mfFormat使用指南 【免费下载链接】Blender3mfFormat Blender add-on to import/export 3MF files 项目地址: https://gitcode.com/gh_mirrors/bl/Blender3mfFormat 您是否曾在Blender中精心设计了一个3D…...

QMCDecode终极指南:如何快速解密QQ音乐加密文件实现跨平台播放

QMCDecode终极指南:如何快速解密QQ音乐加密文件实现跨平台播放 【免费下载链接】QMCDecode QQ音乐QMC格式转换为普通格式(qmcflac转flac,qmc0,qmc3转mp3, mflac,mflac0等转flac),仅支持macOS,可自动识别到QQ音乐下载目录&#xff…...

ARGO:开源本地优先AI智能体平台部署与应用全指南

1. 项目概述:为什么我们需要一个“本地优先”的超级AI助手? 最近几年,AI助手的发展速度让人眼花缭乱。从最初的简单问答,到能联网搜索,再到能调用各种工具完成复杂任务,能力边界在不断拓宽。但一个核心问题…...

【高届数机械工程会议】第十二届机械工程、材料和自动化技术国际学术会议(MMEAT 2026)

第六届机器学习与智能系统工程国际学术会议(MLISE 2026) 2026 6th International Conference on Machine Learning and Intelligent Systems Engineering 北京航空航天大学主办 高届数机械工程会议推荐 往届检索稳定快速 会议官网: 第十二届…...

使用VS + VS Code + Cocos2d-x写游戏

Cocos2d-x是跨平台的2D游戏开发框架。 注意:必须用VS才能编译。 1 环境 1.1 Python 2.7 注意:必须下载Python2.7,3.x不行。 Python2.7下载地址,需要勾选Add python.exe to Path, 否则需要在系统环境变量Path添加Pyt…...

Advantech工业连接器国产替代方案与选型实践解析

在工业计算机与嵌入式系统领域,连接器不仅是基础互连器件,更是系统稳定运行的重要保障。Advantech 作为工业计算机行业的代表厂商,其产品广泛应用于工业自动化、智能制造、医疗设备、交通系统及物联网等领域。虽然 Advantech 本身并非传统意义…...

从 ng-content 到聚合机制,SAP UI5 里有没有 Angular 式内容投影

我每次把一个 Angular 组件的思路搬到 SAP UI5 里,最容易卡住的地方,往往不是属性绑定,也不是事件,而是这种很像 slot 的内容投放能力。Angular 官方把 ng-content 定义得非常明确,它不是一个普通的 DOM 元素,也不是组件,而是一个专门告诉框架把外部子内容渲染到哪里去的…...

SAP UI5 里到底有没有类似 Angular ng-container 的东西

我最近在把一套前端思维从 Angular 往 SAP UI5 映射的时候,最容易让人下意识去找的一个东西,就是 ng-container。这个标签很特别,平时写 Angular 模板时它经常出现,可浏览器里最后又看不到它。问题也就卡在这里,SAP UI5 里到底有没有一个几乎一模一样的角色,既能把一段内…...

把 SAP Cloud Connector 连接故障拆开看,为什么同样是连不上,卡点却可能完全不同

今天这类场景很常见,我们在 SAP HANA Cloud 里执行 CREATE REMOTE SOURCE,目标端明明已经在 Cloud Connector 里配好了虚拟主机和内部地址,结果系统还是抛出 Cannot resolve host name、Connection refused、Network unreachable,甚至 Socket closed by peer。表面上看,所…...

从 Cloud Connector 到 abapodbc,把 ABAP On-Premise Remote Source 真正搭起来

这类连接最近在很多混合架构项目里都会出现,业务数据还放在本地部署的 SAP S/4HANA 或其他 ABAP 系统里,分析、联合查询、虚拟化访问却已经放到了 SAP HANA Cloud。到了这个阶段,我们常见的诉求不是把所有数据一股脑搬到云上,而是先把访问链路打通,让 SAP HANA Cloud 以远…...

把 SAP HANA Cloud 连回机房, 创建 SAP HANA On-Premise Remote Source 的完整落地笔记

项目走到混合架构这一步时,最磨人的地方往往不是 SQL 本身,而是云上的 SAP HANA Cloud 已经准备好了,机房里的 SAP HANA On-Premise 也跑得很稳,可两边像隔着一道无形的墙。业务侧希望直接在云端做联邦查询,架构侧又不想把机房数据库直接暴露到公网,这时候,Remote Sourc…...

每日算法-线性dp、递归

1.跳台阶拓展问题(线性dp)题目:分析:第一种解法(线性dp):根据线性dp的经验可以定义状态表示为:dp[i]:跳到i级台阶总共有多少总跳法因为一次青蛙可以跳任意级台阶&#xf…...

uni-app x 中组件宽高使用百分比单位的问题

1. uni-app x 中组件宽高使用百分比单位的问题 关于 uni-app x 中组件宽高使用百分比单位的问题,建议如下: 1.1. 建议使用 flex:1 替代百分比 在 uni-app x 中,官方推荐尽量使用 px 配合 flex:1 来实现自适应布局,而非百分比单位…...

DeepSeek LeetCode 1755 最接近目标值的子序列和 public int minAbsDifference(int[] nums, int goal)

这个问题可以通过将数组分成两半并枚举所有子序列和,然后排序和二分查找来高效解决,时间复杂度为 O(2^{n/2} \cdot n)。算法思路1. 将数组 nums 分成两部分 left 和 right,长度分别为 n/2 和 n - n/2。 2. 分别枚举两部分的所有子序列&#x…...

FLUX.1-Krea-Extracted-LoRA效果对比:Krea风格在人像/产品/室内三类场景表现

FLUX.1-Krea-Extracted-LoRA效果对比:Krea风格在人像/产品/室内三类场景表现 1. 模型概述与核心价值 FLUX.1-Krea-Extracted-LoRA 是从 FLUX.1-Krea-dev 基础模型中提取的 LoRA 风格权重,专为 FLUX.1-dev 设计。这个模型通过精细的光影模拟和材质表现&…...

SVD降维技术:原理、实现与实战应用

1. 降维的本质与SVD的数学之美当你的数据集列数突破1000维时,每个数据点就像被困在千米高维空间里的蚂蚁——你明明知道这些维度里藏着规律,却根本看不清它们的全貌。这就是为什么我们需要降维技术,而奇异值分解(SVD)正…...

别再傻傻用加法器了!Verilog里这个‘分治’数1技巧,帮你省下FPGA的宝贵资源

Verilog资源优化实战:分治法高效统计二进制位中1的个数 在FPGA和ASIC设计中,资源优化从来都不是可有可无的选项。想象一下,当你面对一个需要处理大量并行数据流的项目时,每个模块节省下来的LUT(查找表)和寄…...

安全与权限管理:保障模型与数据资产的安全

008、安全与权限管理:保障模型与数据资产的安全 上周帮同事排查一个诡异的问题:微调好的7B模型在测试集上表现正常,部署到生产环境却突然“胡言乱语”。排查了三小时,最后发现是部署脚本误加载了同目录下一个旧版本的权重文件——那个文件是三个月前一次失败实验的残留。权…...

Real Anime Z开源价值解读:Z-Image底座+Real Anime Z微调的协同优势

Real Anime Z开源价值解读:Z-Image底座Real Anime Z微调的协同优势 1. 项目核心价值 Real Anime Z是一款基于阿里云通义Z-Image底座模型与Real Anime Z专属微调权重开发的高精度二次元图像生成工具。它专为真实系二次元风格优化,通过创新的技术方案解决…...

YOLOv11改进 | Neck篇 | CVPR最新低照度图像增强模块HVI改进YOLOv11(有效涨点)

一、本文介绍 本文给大家带来的最新改进机制是CVPR顶会中的一种新型颜色空间HVI机制,针对低照度图像增强任务中的红色区域断裂和暗区噪声问题。HVI通过极化映射重构色相表示,解决HSV中红色不连续问题,并引入可学习的强度塌缩机制稳定暗区几何分布。核心设计包括:1) 极坐标…...

基于STM32与互感器的智能电表远程监控系统设计(附WiFi通信与过载保护)

1. 智能电表远程监控系统设计概述 想象一下,你出差在外突然想起家里空调可能没关,或者想远程监控工厂设备的用电情况——这正是智能电表远程监控系统要解决的问题。基于STM32与互感器的设计方案,就像给传统电表装上"大脑"和"千…...

滚动即艺术|Paxgon高端创意官网:极简美学×沉浸式交互的品牌表达范本

合作背景 2026年1月,作为前端技术领域的资深探索者,武汉优联前端科技有限公司与马来西亚多元化顶级创意机构Paxgon签署合作协议,正式承担Paxgon官网升级项目的设计与开发。在数字化浪潮席卷全球的今天,品牌建设不再是单一的视觉呈…...

Strix AI 安全测试工具完整使用指南

Strix AI 安全测试工具完整使用指南 一、核心优势 Strix 是AI 驱动的开源安全测试工具,核心亮点: AI 自动识别漏洞,无需手动编写复杂测试规则 支持 Web 网站、本地代码、云端服务全场景扫描 提供命令行 终端图形界面 (TUI) 双模式 支持…...

Simulink参数设置避坑指南:get_param/set_param用错?变量和参数对象傻傻分不清?

Simulink参数设置避坑指南:get_param/set_param用错?变量和参数对象傻傻分不清? 在Simulink建模过程中,参数设置看似简单却暗藏玄机。许多工程师在尝试自动化参数配置时,常常陷入性能陷阱、变量作用域混乱或代码生成问…...

办公党必备:如何快速创建ZIP压缩包

当你需要发送一堆照片给朋友、归档项目文档,或只是想节省点硬盘空间时,ZIP压缩就是最好的选择。作为最通用的压缩格式,ZIP几乎能在所有设备上直接打开,而且操作十分简单。下面小编分享两种方法,让你可以快速创建ZIP压缩…...

元器件特性-二/三极管

1.二极管介绍 二极管是用半导体材料 (硅、硒、锗等)制成的一种电子元器件。 它具有单向导电性能特性 (具有正向特性和反向特性),即给二极管阳极和阴极加上正向电压时,二极管导通。 当给阳极和阴极加上反向电压时,二极管截止。 因此&#xff…...

研发leader如何增强自身在外部就业市场的竞争力

“在公司的价值”和“在市场的价值”并不完全等同。 公司可能因为业务收缩、政治变化或战略调整而“不需要你”,但这不代表你没有市场价值。你现在要做的,不是只服务于当前公司,而是在日常工作中同步为自己积累“可迁移的资产”。 下面是一个研发Leader可以持续准备的五个核…...

手把手教你用Debian Live OS救活CentOS 8:GLIBC升级翻车后的机房急救实录

深夜机房的生死时速:用Debian Live OS拯救GLIBC升级崩溃的CentOS 8服务器 凌晨2:17,刺耳的告警铃声划破寂静。监控系统显示,核心业务服务器突然离线。当我远程连接时,SSH会话在输入密码后立即断开——这是典型的GLIBC版本冲突症状…...