当前位置: 首页 > article >正文

2-1 MATLAB鮣鱼优化算法ROA优化LSTM超参数回归预测

本博客来源于CSDN机器鱼,未同意任何人转载。

更多内容,欢迎点击本专栏目录,查看更多内容。

目录

0.ROA原理

1.LSTM程序

2.ROA优化LSTM

3.主程序

4.结语


0.ROA原理

具体原理看原文,但是今天咱不用知道具体原理,只需要找到源码,然后改成优化LSTM的即可。下面是我从网上找到的源码。ROA是主要的代码,Cost是适应度函数,这个代码的是找Cost的最小值。

function [Fbest, Rbest,Convergence_curve]= ROA()
sizepop=30; % Number of search agents
maxgen=500; % Maximum number of iterations
lb=-100;
ub=100;
D=30;%maxgen为最大迭代次数,
%sizepop为种群规模
%记D为维度,lb、 ub分别为搜索上、下限
R=ones(sizepop,D);%预设种群
for i= 1:DR(:,i)=lb + (ub-lb)*rand(sizepop,1);
end
for k= 1:sizepopFitness(k)=Cost(R(k,:));%个体适应度
end
[Fbest,elite]= min(Fitness);%Fbest为最优适应度值
Rbest= R(elite,:);%最优个体位置
H=zeros(1,sizepop);%控制因子
ub=ones(1,D)*ub;
lb=ones(1,D)*lb;%主循环
for iter= 1:maxgenRpre= R;%记录上一代的位置V=2*(1-iter/maxgen);B= 2*V*rand-V;a=-(1 + iter/maxgen);alpha=rand*(a-1)+ 1;for i= 1:sizepopif H(i)==0dis = abs(Rbest-R(i,:));R(i,:)= R(i,:)+ dis* exp(alpha)*cos(2*pi* alpha);elseRAND= ceil(rand*sizepop);%随机选择一个个体R(i,:)= Rbest -(rand*0.5*(Rbest + R(RAND,:))- R(RAND,:));endRatt= R(i,:)+ (R(i,:)- Rpre(i,:))*randn;%作出小幅度移动%边界吸收for k=1:DFlag4ub= R(i,k)>ub(k);Flag4lb= R(i,k)<lb(k);R(i,k)=(R(i,k).*(~(Flag4ub + Flag4lb))) + ub(k).*Flag4ub + lb(k).*Flag4lb;Flag4ub= Ratt(1,k)> ub(k);Flag4lb= Ratt(1,k)<lb(k);Ratt(1,k)=(Ratt(1,k).*(~(Flag4ub + Flag4lb)))+ ub(k).*Flag4ub + lb(k).*Flag4lb;endFitness(i)=Cost(R(i,:));Fitness_Ratt= Cost(Ratt);if Fitness_Ratt < Fitness(i)%改变寄主if H(i)==1H(i)=0;elseH(i)=1;endelse %不改变寄主A= B*(R(i,:)-rand*0.3*Rbest);R(i,:)=R(i,:)+A;end%边界吸收for k=1:DFlag4ub= R(i,k)>ub(k);Flag4lb= R(i,k)<lb(k);R(i,k)=(R(i,k).*(~(Flag4ub+ Flag4lb))) + ub(k).*Flag4ub + lb(k).*Flag4lb;endend%更新适应度值、位置[fbest,elite] = min(Fitness);%更新最优个体if fbest< FbestFbest= fbest;Rbest= R(elite,:);endConvergence_curve(iter)= Fbest;
end
endfunction o = Cost(x)
o=sum(x.^2);
end

调用这个代码的主程序如下:

clear ;close all;clc;format compact
%鮣鱼优化算法(Remora Optimization Algorithm)
[BestF,BestP,Convergence_curve1]=ROA();
figure
semilogy(Convergence_curve1)

1.LSTM程序

首先建立一个LSTM网络,这次我们是做回归任务,数据是3输入1输出,构建一个含2个lstmlayer的LSTM网络,代码如下:

%% LSTM时间序列预测
clc;clear;close all
%%
load data
XTrain;%3*97
XTest;%3*68
YTrain;%1*97
YTest;%1*68
%% 参数设置
train=0;%为1就重新训练,否则加载训练好的模型进行预测
if train==1rng(0)numFeatures = size(XTrain,1);%输入节点数numResponses = size(YTrain,1);%输出节点数miniBatchSize = 16; %batchsizenumHiddenUnits1 = 20;numHiddenUnits2 = 20;maxEpochs=100;learning_rate=0.005;layers = [ ...sequenceInputLayer(numFeatures)lstmLayer(numHiddenUnits1)lstmLayer(numHiddenUnits2)fullyConnectedLayer(numResponses)regressionLayer];options = trainingOptions('adam', ...'ExecutionEnvironment', 'cpu',...'MaxEpochs',maxEpochs, ...'MiniBatchSize',miniBatchSize, ...'InitialLearnRate',learning_rate, ...'GradientThreshold',1, ...'Shuffle','every-epoch', ...'Verbose',true,...'Plots','training-progress');net = trainNetwork(XTrain,YTrain,layers,options);save model/lstm net
elseload model/lstm
end
YPred = predict(net,XTest,'ExecutionEnvironment', 'cpu');YPred=double(YPred);

从构建网络那里,我们发现,构建一个超级简单的LSTM,依旧有miniBatchSize ,numHiddenUnits1 ,numHiddenUnits2 ,maxEpochs,learning_rate共5个超参数需要设置,而网络越复杂,那需要优化的超参数也就更多,手动选择就算了,一般是选不出来,为此这篇博客采用ROA进行优化。

2.ROA优化LSTM

任意一个优化网路超参数的步骤都是通用的,步骤如下:

步骤1:知道要优化的参数的优化范围。显然就是上面提到的5个参数。代码如下,首先改写lb与ub,然后初始化的时候注意除了学习率,其他的都是整数。并将原来里面的边界判断,改成了Bounds函数,方便在计算适应度函数前转化成整数与小数。

function [Rbest,Convergence_curve,process]= ROAforlstm(X1,y1,Xt,yt)
D=5;
sizepop=5;%种群数量
maxgen=10;%寻优代数
%范围
lb=[1 1   1   1  0.001];%分别对batchsize、两个lstm隐含层节点 训练次数与学习率寻优
ub=[64 100 100 50  0.01];%这个分别代表5个参数的上下界,比如第一个参数的范围就是1-64%
%maxgen为最大迭代次数,
%sizepop为种群规模
%记D为维度,lb、 ub分别为搜索上、下限
R=ones(sizepop,D);%预设种群
for i=1:sizepop%随机初始化速度,随机初始化位置for j=1:Dif j==D%除了学习率 其他的都是整数R( i, j ) = (ub(j)-lb(j))*rand+lb(j);elseR( i, j ) = round((ub(j)-lb(j))*rand+lb(j));endend
endfor k= 1:sizepopFitness(k)=fitness(R(k,:),X1,y1,Xt,yt);%个体适应度
end
[Fbest,elite]= min(Fitness);%Fbest为最优适应度值
Rbest= R(elite,:);%最优个体位置
H=zeros(1,sizepop);%控制因子%主循环
for iter= 1:maxgenRpre= R;%记录上一代的位置V=2*(1-iter/maxgen);B= 2*V*rand-V;a=-(1 + iter/maxgen);alpha=rand*(a-1)+ 1;for i= 1:sizepopif H(i)==0dis = abs(Rbest-R(i,:));R(i,:)= R(i,:)+ dis* exp(alpha)*cos(2*pi* alpha);elseRAND= ceil(rand*sizepop);%随机选择一个个体R(i,:)= Rbest -(rand*0.5*(Rbest + R(RAND,:))- R(RAND,:));endRatt= R(i,:)+ (R(i,:)- Rpre(i,:))*randn;%作出小幅度移动%边界吸收R(i, : ) = Bounds( R(i, : ), lb, ub );%对超过边界的变量进行去除Ratt = Bounds( Ratt, lb, ub );%对超过边界的变量进行去除Fitness(i)=fitness(R(i,:),X1,y1,Xt,yt);Fitness_Ratt= fitness(Ratt,X1,y1,Xt,yt);if Fitness_Ratt < Fitness(i)%改变寄主if H(i)==1H(i)=0;elseH(i)=1;endelse %不改变寄主A= B*(R(i,:)-rand*0.3*Rbest);R(i,:)=R(i,:)+A;endR(i, : ) = Bounds( R(i, : ), lb, ub );%对超过边界的变量进行去除end%更新适应度值、位置[fbest,elite] = min(Fitness);%更新最优个体if fbest< FbestFbest= fbest;Rbest= R(elite,:);endprocess(iter,:)=Rbest;Convergence_curve(iter)= Fbest;iter,Fbest,Rbest
endendfunction s = Bounds( s, Lb, Ub)
temp = s;
dim=length(Lb);
for i=1:length(s)if i==dim%除了学习率 其他的都是整数temp(:,i) =temp(:,i);elsetemp(:,i) =round(temp(:,i));end
end% 判断参数是否超出设定的范围for i=1:length(s)if temp(:,i)>Ub(i) | temp(:,i)<Lb(i) if i==dim%除了学习率 其他的都是整数temp(:,i) =rand*(Ub(i)-Lb(i))+Lb(i);elsetemp(:,i) =round(rand*(Ub(i)-Lb(i))+Lb(i));endend
end
s = temp;
end

步骤2:知道优化的目标。优化的目标是提高的网络的准确率,而ROA代码我们这个代码是最小值优化的,所以我们的目标可以是最小化LSTM的预测误差。预测误差具体是,测试集(或验证集)的预测值与真实值之间的均方差。

步骤3:构建适应度函数。通过步骤2我们已经知道目标,即采用ROA去找到5个值,用这5个值构建的网络,误差最小化。观察下面的代码,首先我们将ROA的值传进来,然后转成需要的5个值,然后构建网络,训练集训练、测试集预测,计算预测值与真实值的mse,将mse作为结果传出去作为适应度值。

function y=fitness(x,p,t,pt,tt)
rng(0)
numFeatures = size(p,1);%输入节点数
numResponses = size(t,1);%输出节点数
miniBatchSize = x(1); %batchsize
numHiddenUnits1 = x(2);
numHiddenUnits2 = x(3);
maxEpochs=x(4);
learning_rate=x(5);
layers = [ ...sequenceInputLayer(numFeatures)lstmLayer(numHiddenUnits1)lstmLayer(numHiddenUnits2)fullyConnectedLayer(numResponses)regressionLayer];
options = trainingOptions('adam', ...'ExecutionEnvironment', 'cpu',...'MaxEpochs',maxEpochs, ...'MiniBatchSize',miniBatchSize, ...'InitialLearnRate',learning_rate, ...'GradientThreshold',1, ...'Shuffle','every-epoch', ...'Verbose',false);net = trainNetwork(p,t,layers,options);YPred = predict(net,pt,'ExecutionEnvironment', 'cpu');YPred=double(YPred);
[m,n]=size(YPred);
YPred=reshape(YPred,[1,m*n]);
tt=reshape(tt,[1,m*n]);y =mse(YPred-tt);
% 以mse为适应度函数,优化算法目的就是找到一组超参数 使网络的mse最低
rng((100*sum(clock)))

3.主程序

%% ROA优化LSTM时间序列预测
clc;clear;close all;format compact
%%
load data%% 采用ROA优化
optimization=1;%是否重新优化
if optimization==1[x ,fit_gen,process]=ROAforlstm(XTrain,YTrain,XTest,YTest);%分别对batchsize 隐含层节点 训练次数与学习率寻优save result/ROA_para_result x fit_gen process
elseload result/ROA_para_result
end
%% 利用优化得到的参数重新训练,得到预测值
train=1;%是否重新训练
if train==1rng(0)numFeatures = size(XTrain,1);%输入节点数numResponses = size(YTrain,1);%输出节点数miniBatchSize = x(1); %batchsizenumHiddenUnits1 = x(2);numHiddenUnits2 = x(3);maxEpochs=x(4);learning_rate=x(5);layers = [ ...sequenceInputLayer(numFeatures)lstmLayer(numHiddenUnits1)lstmLayer(numHiddenUnits2)fullyConnectedLayer(numResponses)regressionLayer];options = trainingOptions('adam', ...'ExecutionEnvironment', 'cpu',...'MaxEpochs',maxEpochs, ...'MiniBatchSize',miniBatchSize, ...'InitialLearnRate',learning_rate, ...'GradientThreshold',1, ...'Shuffle','every-epoch', ...'Verbose',true,...'Plots','training-progress');net = trainNetwork(XTrain,YTrain,layers,options);save model/ROAlstm net
elseload model/ROAlstm
end
% 预测
YPred = predict(net,XTest,'ExecutionEnvironment', 'cpu');
YPred=double(YPred);

4.结语

优化网络超参数的格式都是这样的!只要会改一种,那么随便拿一份能跑通的优化算法,在不管原理的情况下,都能用来优化网络的超参数。晚一点我们再来写一个简单的CNN,并用这个算法来优化。更多内容【点击专栏】目录。

相关文章:

2-1 MATLAB鮣鱼优化算法ROA优化LSTM超参数回归预测

本博客来源于CSDN机器鱼&#xff0c;未同意任何人转载。 更多内容&#xff0c;欢迎点击本专栏目录&#xff0c;查看更多内容。 目录 0.ROA原理 1.LSTM程序 2.ROA优化LSTM 3.主程序 4.结语 0.ROA原理 具体原理看原文&#xff0c;但是今天咱不用知道具体原理&#xff0c;只…...

fircrawl本地部署

企业内部的网站作为知识库给dify使用&#xff0c;使用fircrawl来爬虫并且转换为markdown。 ​ git clone https://github.com/mendableai/firecrawl.gitcd ./firecrawl/apps/api/ cp .env.example .env cd ~/firecrawl docker compose up -d 官方&#xff1a; https://githu…...

Labview学习记录

1.快捷键 ctrlR 运行 ctrlB 去除断线 ctrlH 即时帮助 ctrlE 前后面板切换 2.画面移动 ctrlshift鼠标左键...

【Golang】第八弹----面向对象编程

&#x1f525; 个人主页&#xff1a;星云爱编程 &#x1f525; 所属专栏&#xff1a;Golang &#x1f337;追光的人&#xff0c;终会万丈光芒 &#x1f389;欢迎大家点赞&#x1f44d;评论&#x1f4dd;收藏⭐文章 前言&#xff1a;Go语言面向对象编程说明 Golang也支持面向对…...

java基础以及内存图

java基础 命名&#xff1a; 大驼峰&#xff1a;类名 小驼峰&#xff1a;变量名方法名等其他的 全部大写&#xff1a;常量名字.. // 单行注释 /**/ 多行注释 变量类型 变量名 一、基本类型&#xff08;8个&#xff09; 整数&#xff1a;byte-8bit short-16bit int 32-b…...

【嵌入式学习3】TCP服务器客户端 - UDP发送端接收端

目录 1、TCP TCP特点 TCP三次握手&#xff08;建立TCP连接&#xff09;&#xff1a; TCP四次握手【TCP断开链接的时候需要经过4次确认】&#xff1a; TCP网络程序开发流程 客户端开发&#xff1a;用户设备上的程序 服务器开发&#xff1a;服务器设备上的程序 2、UDP 为…...

Linux之基础知识

目录 一、环境准备 1.1、常规登录 1.2、免密登录 二、Linux基本指令 2.1、ls命令 2.2、pwd命令 2.3、cd命令 2.4、touch命令 2.5、mkdir命令 2.6、rmdir和rm命令 2.7man命令 2.8、cp命令 2.9、mv命令 2.10、cat命令 2.11、echo命令 2.11.1、Ctrl r 快捷键 2…...

llamafactory微调效果与vllm部署效果不一致如何解决

在llamafactory框架训练好模型之后&#xff0c;自测chat时模型效果不错&#xff0c;但是部署到vllm模型上效果却很差 这实际上是因为llamafactory微调时与vllm部署时的对话模板不一致导致的。 对应的llamafactory的代码为 而vllm启动时会采用大模型自己本身设置的对话模板信息…...

Python控制结构详解

前言 一、控制结构概述 二、顺序结构 三、选择结构&#xff08;分支结构&#xff09; 1. 单分支 if 2. 双分支 if-else 3. 多分支 if-elif-else 4.实际应用: 四、循环结构 1. for循环 2. while循环 3. 循环控制语句 五、异常处理&#xff08;try-except&#xff09…...

Mysql-经典实战案例(11):深度解析Sysbench压测(从入门到MySQL服务器性能验证)

引言 如何用Sysbench压测满足mysql生产运行的服务器&#xff1f; Sysbench返回的压测结果如何解读&#xff1f; 别急&#xff0c;本文会教大家如何使用并且如何解读压测的结果信息&#xff0c;如何对mysql服务器进行压测&#xff01; 一、Sysbench核心功能全景解析 1.1 工…...

WebSocket通信的握手阶段

1. 客户端建立连接时&#xff0c;通过 http 发起请求报文&#xff0c;报文表示请求服务器端升级协议为 WebSocket&#xff0c;与普通的 http 请求协议略有区别的部分在于如下的这些协议头&#xff1a; 上述两个字段表示请求服务器端升级协议为 websocket 协议。 2. 服务器端响…...

分布式ID服务实现全面解析

分布式ID生成器是分布式系统中的关键基础设施&#xff0c;用于在分布式环境下生成全局唯一的标识符。以下是各种实现方案的深度解析和最佳实践。 一、核心需求与设计考量 1. 核心需求矩阵 需求 重要性 实现难点 全局唯一 必须保证 时钟回拨/节点冲突 高性能 高并发场景…...

dom0运行android_kernel: do_serror of panic----failed to stop secondary CPUs 0

问题描述&#xff1a; 从日志看出,dom0运行android_kernel&#xff0c;刚开始运行就会crash,引发panic 解决及其原因分析&#xff1a; 最终问题得到解决&#xff0c;发现是前期在调试汇编阶段代码时&#xff0c;增加了汇编打印的指令&#xff0c;注释掉这些指令,问题得到解决。…...

HarmonyOS NEXT——【鸿蒙原生应用加载Web页面】

鸿蒙客户端加载Web页面&#xff1a; 在鸿蒙原生应用中&#xff0c;我们需要使用前端页面做混合开发&#xff0c;方法之一是使用Web组件直接加载前端页面&#xff0c;其中WebView提供了一系列相关的方法适配鸿蒙原生与web之间的使用。 效果 web页面展示&#xff1a; Column()…...

HTML输出流

HTML 输出流 JavaScript 中**「直接写入 HTML 输出流」**的核心是通过 document.write() 方法向浏览器渲染过程中的数据流动态插入内容。以下是详细解释&#xff1a; 一、HTML 输出流的概念 1. 动态渲染过程 HTML 文档的加载是自上而下逐行解析的。当浏览器遇到 <script&…...

std::countr_zero

一 基本功能 1 作用 std::countr_zero 是 C++20 标准引入的位操作函数,用于计算无符号整数的二进制表示中末尾零(Trailing Zeros)的数量。 定义:位于 <bit> 头文件中,是标准库的一部分。 2 示例 #include <bit> unsigned int x = 12; // 二进…...

优选算法的慧根之翼:位运算专题

专栏&#xff1a;算法的魔法世界 个人主页&#xff1a;手握风云 一、位运算 基础位运算 共包含6种&(按位与&#xff0c;有0就是0)、|(按位或有1就是1)、^(按位异或&#xff0c;相同为0&#xff0c;相异为1)、~(按位取反&#xff0c;0变成1&#xff0c;1变成0)、<<(左…...

图论问题集合

图论问题集合 寻找特殊有向图&#xff08;一个节点最多有一个出边&#xff09;中最大环路问题特殊有向图解析算法解析步骤 1 &#xff1a;举例分析如何在一个连通块中找到环并使用时间戳计算大小步骤 2 &#xff1a;抽象成算法注意 实现 寻找特殊有向图&#xff08;一个节点最多…...

【数据结构】栈 与【LeetCode】20.有效的括号详解

目录 一、栈1、栈的概念及结构2、栈的实现3、初始化栈和销毁栈4、打印栈的数据5、入栈操作---栈顶6、出栈---栈顶6.1栈是否为空6.2出栈---栈顶 7、取栈顶元素8、获取栈中有效的元素个数 二、栈的相关练习1、练习2、AC代码 个人主页&#xff0c;点这里~ 数据结构专栏&#xff0c…...

实时目标检测新突破:AnytimeYOLO——随时中断的YOLO优化框架解析

目录 一、论文背景与核心价值 二、创新技术解析 2.1 网络结构革新:Transposed架构 2.2 动态路径优化算法 三、实验结果与性能对比 3.1 主要性能指标 3.2 关键发现 四、应用场景与部署实践 4.1 典型应用场景 4.2 部署注意事项 五、未来展望与挑战 一、论文背景与核心…...

Redis设计与实现-哨兵

哨兵模式 1、启动并初始化sentinel1.1 初始化服务器1.2 使用Sentinel代码1.3 初始化sentinel状态1.4 初始化sentinel状态的master属性1.5 创建连向主服务器的网络连接 2、获取主服务器信息3、获取从服务器的信息4、向主从服务器发送信息5、接受主从服务器的频道信息6、检测主观…...

C++进阶——封装哈希表实现unordered_map/set

与红黑树封装map/set基本相似&#xff0c;只是unordered_map/set是单向迭代器&#xff0c;模板多传一个HashFunc。 目录 1、源码及框架分析 2、模拟实现unordered_map/set 2.1 复用的哈希表框架及Insert 2.2 iterator的实现 2.2.1 iteartor的核心源码 2.2.2 iterator的实…...

第4.1节:使用正则表达式

1 第4.1节&#xff1a;使用正则表达式 将正则表达式用斜杠括起来&#xff0c;就能用作模式。随后&#xff0c;该正则表达式会与每条输入记录的完整文本进行比对。&#xff08;通常情况下&#xff0c;它只需匹配文本的部分内容就能视作匹配成功。&#xff09;例如&#xff0c;以…...

【算法day25】 最长有效括号——给你一个只包含 ‘(‘ 和 ‘)‘ 的字符串,找出最长有效(格式正确且连续)括号子串的长度。

32. 最长有效括号 给你一个只包含 ‘(’ 和 ‘)’ 的字符串&#xff0c;找出最长有效&#xff08;格式正确且连续&#xff09;括号子串的长度。 https://leetcode.cn/problems/longest-valid-parentheses/ 2.方法二&#xff1a;栈 class Solution { public:int longestValid…...

Jenkins + CICD流程一键自动部署Vue前端项目(保姆级)

git仓库地址&#xff1a;参考以下代码完成,或者采用自己的代码。 南泽/cicd-test 拉取项目代码到本地 使用云服务器或虚拟机采用docker部署jenkins 安装docker过程省略 采用docker部署jenkins&#xff0c;注意这里的命令&#xff0c;一定要映射docker路径&#xff0c;否则无…...

C 语言的未来:在变革中坚守核心价值

一、从 “古老” 到 “长青”&#xff1a;C 语言的不可替代性 诞生于 20 世纪 70 年代的 C 语言&#xff0c;历经半个世纪的技术浪潮&#xff0c;至今仍是编程世界的 “基石语言”。尽管 Python、Java 等高级语言在应用层开发中占据主流&#xff0c;但 C 语言在系统级编程和资…...

一款超级好用且开源免费的数据可视化工具——Superset

认识Superset 数字经济、数字化转型、大数据等等依旧是如今火热的领域&#xff0c;数据工作有一个重要的环节就是数据可视化。 看得见的数据才更有价值&#xff01; 现如今依旧有多数企业号称有多少多少数据&#xff0c;然而如果这些数据只是呆在冷冰冰的数据库或文件内则毫无…...

Vue3组合式API与选项式API的核心区别与适用场景

Vue.js作为现代前端开发的主流框架之一&#xff0c;在Vue3中引入了全新的组合式API(Composition API)&#xff0c;与传统的选项式API(Options API)形成了两种不同的开发范式。在当前开发中的两个项目中分别用到了组合式和选项式&#xff0c;故记录一下。本文将全面剖析这两种AP…...

RedHatLinux(2025.3.22)

1、创建/www目录&#xff0c;在/www目录下新建name和https目录&#xff0c;在name和https目录下分别创建一个index.htm1文件&#xff0c;name下面的index.html 文件中包含当前主机的主机名&#xff0c;https目录下的index.htm1文件中包含当前主机的ip地址。 &#xff08;1&…...

【C++篇】类与对象(上篇):从面向过程到面向对象的跨越

&#x1f4ac; 欢迎讨论&#xff1a;在阅读过程中有任何疑问&#xff0c;欢迎在评论区留言&#xff0c;我们一起交流学习&#xff01; &#x1f44d; 点赞、收藏与分享&#xff1a;如果你觉得这篇文章对你有帮助&#xff0c;记得点赞、收藏&#xff0c;并分享给更多对C感兴趣的…...