当前位置: 首页 > news >正文

【优化算法】使用遗传算法优化MLP神经网络参数(TensorFlow2)

文章目录

  • 任务
  • 查看当前的准确率情况
  • 使用遗传算法进行优化
  • 完整代码

任务

使用启发式优化算法遗传算法对多层感知机中中间层神经个数进行优化,以提高模型的准确率。

待优化的模型:
基于TensorFlow2实现的Mnist手写数字识别多层感知机MLP

# MLP手写数字识别模型,待优化的参数为layer1、layer2
model = tf.keras.Sequential([tf.keras.layers.Flatten(input_shape=(28, 28, 1)), tf.keras.layers.Dense(layer1, activation='relu'),tf.keras.layers.Dense(layer2, activation='relu'),tf.keras.layers.Dense(10,activation='softmax') # 对应0-9这10个数字
])

查看当前的准确率情况

设置随机树种子,避免相同结构的神经网络其结果不同的影响。

random_seed = 10
np.random.seed(random_seed)
tf.random.set_seed(random_seed)
random.seed(random_seed)
model = tf.keras.Sequential([tf.keras.layers.Flatten(input_shape=(28, 28, 1)), tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(32, activation='relu'),tf.keras.layers.Dense(10,activation='softmax') # 对应0-9这10个数字
])
optimizer = tf.keras.optimizers.Adam()
loss_func = tf.keras.losses.SparseCategoricalCrossentropy()
model.compile(optimizer=optimizer,loss=loss_func,metrics=['acc'])
history = model.fit(dataset, validation_data=test_dataset, epochs=5, verbose=1)
score = model.evaluate(test_dataset)
print("测试集准确率:",score) # 输出 [损失率,准确率]
Epoch 1/5
235/235 [==============================] - 1s 5ms/step - loss: 0.4663 - acc: 0.8703 - val_loss: 0.2173 - val_acc: 0.9361
Epoch 2/5
235/235 [==============================] - 1s 4ms/step - loss: 0.1882 - acc: 0.9465 - val_loss: 0.1604 - val_acc: 0.9521
Epoch 3/5
235/235 [==============================] - 1s 4ms/step - loss: 0.1362 - acc: 0.9608 - val_loss: 0.1278 - val_acc: 0.9629
Epoch 4/5
235/235 [==============================] - 1s 4ms/step - loss: 0.1084 - acc: 0.9689 - val_loss: 0.1086 - val_acc: 0.9681
Epoch 5/5
235/235 [==============================] - 1s 4ms/step - loss: 0.0878 - acc: 0.9740 - val_loss: 0.1077 - val_acc: 0.9675
40/40 [==============================] - 0s 2ms/step - loss: 0.1077 - acc: 0.9675
测试集准确率: [0.10773944109678268, 0.9674999713897705]

准确率为96.7%

使用遗传算法进行优化

使用scikit-opt提供的遗传算法库进行优化。(pip install scikit-opt
官方文档:https://scikit-opt.github.io/scikit-opt/#/zh/README

# 遗传算法调用
ga = GA(func=loss_func, n_dim=2, size_pop=4, max_iter=3, prob_mut=0.15, lb=[10, 10], ub=[256, 256], precision=1)

优化目标函数loss_func:1 - 模型准确率(求优化目标函数的最小值
待优化的参数数量n_dim:2
种群数量size_pop:4
迭代次数max_iter:3
变异概率prob_mut:0.15
待优化两个参数的取值范围均为10-256
精确度precision:1

# 定义多层感知机模型函数
def mlp_model(layer1, layer2):model = tf.keras.Sequential([tf.keras.layers.Flatten(input_shape=(28, 28, 1)), tf.keras.layers.Dense(layer1, activation='relu'),tf.keras.layers.Dense(layer2, activation='relu'),tf.keras.layers.Dense(10,activation='softmax') # 对应0-9这10个数字])optimizer = tf.keras.optimizers.Adam()loss_func = tf.keras.losses.SparseCategoricalCrossentropy()model.compile(optimizer=optimizer,loss=loss_func,metrics=['acc'])return model# 定义损失函数
def loss_func(params):random_seed = 10np.random.seed(random_seed)tf.random.set_seed(random_seed)random.seed(random_seed)layer1 = int(params[0])layer2 = int(params[1])model = mlp_model(layer1, layer2)history = model.fit(dataset, validation_data=test_dataset, epochs=5, verbose=0)return 1 - history.history['val_acc'][-1]ga = GA(func=loss_func, n_dim=2, size_pop=4, max_iter=3, prob_mut=0.15, lb=[10, 10], ub=[256, 256], precision=1)
bestx, besty = ga.run()
print('bestx:', bestx, '\n', 'besty:', besty)

结果:

bestx: [165. 155.] 
besty: [0.02310002]

通过迭代,找到layer1、layer2的最好值为165、155,此时准确率为1-0.0231=0.9769。

查看每轮迭代的损失函数值图(1-准确率):

Y_history = pd.DataFrame(ga.all_history_Y)
fig, ax = plt.subplots(2, 1)
ax[0].plot(Y_history.index, Y_history.values, '.', color='red')
Y_history.min(axis=1).cummin().plot(kind='line')
plt.show()

在这里插入图片描述
上图为三次迭代种群中,种群每个个体的损失函数值(每个种群4个个体)。
下图为三次迭代种群中,种群个体中的最佳损失函数值。

可以看出,通过遗传算法,其准确率有一定的提升。

增加种群数量及迭代次数效果可更加明显。

完整代码

# python3.6.9
import tensorflow as tf # 2.3
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random
from sko.GA import GA# 数据导入,获取训练集和测试集
(train_image, train_labels), (test_image, test_labels) = tf.keras.datasets.mnist.load_data()# 增加通道维度
train_image = tf.expand_dims(train_image, -1)
test_image = tf.expand_dims(test_image, -1)# 归一化 类型转换
train_image = tf.cast(train_image/255, tf.float32)
test_image = tf.cast(test_image/255, tf.float32)
train_labels = tf.cast(train_labels, tf.int64)
test_labels = tf.cast(test_labels, tf.int64)# 创建Dataset
dataset = tf.data.Dataset.from_tensor_slices((train_image, train_labels)).shuffle(60000).batch(256)
test_dataset = tf.data.Dataset.from_tensor_slices((test_image, test_labels)).batch(256)# 定义多层感知机模型函数
def mlp_model(layer1, layer2):model = tf.keras.Sequential([tf.keras.layers.Flatten(input_shape=(28, 28, 1)), tf.keras.layers.Dense(layer1, activation='relu'),tf.keras.layers.Dense(layer2, activation='relu'),tf.keras.layers.Dense(10,activation='softmax') # 对应0-9这10个数字])optimizer = tf.keras.optimizers.Adam()loss_func = tf.keras.losses.SparseCategoricalCrossentropy()model.compile(optimizer=optimizer,loss=loss_func,metrics=['acc'])return model# 定义损失函数
def loss_func(params):random_seed = 10np.random.seed(random_seed)tf.random.set_seed(random_seed)random.seed(random_seed)layer1 = int(params[0])layer2 = int(params[1])model = mlp_model(layer1, layer2)history = model.fit(dataset, validation_data=test_dataset, epochs=5, verbose=0)return 1 - history.history['val_acc'][-1]ga = GA(func=loss_func, n_dim=2, size_pop=4, max_iter=3, prob_mut=0.15, lb=[10, 10], ub=[256, 256], precision=1)
bestx, besty = ga.run()
print('bestx:', bestx, '\n', 'besty:', besty)

相关文章:

【优化算法】使用遗传算法优化MLP神经网络参数(TensorFlow2)

文章目录任务查看当前的准确率情况使用遗传算法进行优化完整代码任务 使用启发式优化算法遗传算法对多层感知机中中间层神经个数进行优化,以提高模型的准确率。 待优化的模型: 基于TensorFlow2实现的Mnist手写数字识别多层感知机MLP # MLP手写数字识别…...

CAM类激活映射 |神经网络可视化 | 热力图

文章目录前言:安装库:分类案例--ResNet50分割案例AttributeError: ‘tuple‘ object has no attribute ‘cpu‘RuntimeError: grad can be implicitly created only for scalar outputsTypeError: cant convert cuda:0 device type tensor to numpy. Use…...

RecyclerView+BaseRecyclerViewAdapterHelper显示不全只显示第一行item的解决问题

RecyclerViewBaseRecyclerViewAdapterHelper显示不全只显示第一行item,我懵了…,我不说多,直接说吧 先看一下适配器代码中的convert()方法: class MineRadioAdapter(layoutResId: Int R.layout.item_my_live) :BaseQuickAdapte…...

解决后端无法对前端的ajax请求重定向

本章目录: 问题描述 AJAX请求后端直接重定向失败解决方案 后端拦截请为响应头添加重定向标志后端拦截器为响应头添加重定向路径前端响应拦截器获取响应头数据,并通过location.href url 完成页面跳转一、问题描述 本来想在拦截器里设置未登录用户访问指…...

【Python】1分钟就能制作精美的框架图?太棒啦

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录前言一、准备二、基本使用与例子1.初始化与导出2.节点类型3.集群块4.自定义线的颜色与属性总结前言 Diagrams 是一个基于Python绘制云系统架构的模块,它能…...

淘宝必备的补单技巧及注意事项!

补单,是优化善后的s单。单只是模拟用户的购物习惯,而补单同时还要模拟整个店铺的综合数据,包括点击率、转化率等等,补到略高于同行、竞品的平均数据时,淘宝会判断为买家比较喜欢你的商品,从而给你更多推荐机…...

【实用篇】SpringCloud+RabbitMQ+Docker+Redis+搜索+分布式,系统详解springcloud分布式

文章目录一、服务拆分1.1 服务拆分Demo1.2 微服务远程调用二、Eureka2.1 Eureka原理2.2 Eureka-server服务搭建2.3 eureka-client服务注册2.4 eureka-client服务复制2.5 eureka服务发现三、Ribbon负载均衡3.1 负载均衡原理3.2 负载均衡策略3.3 自定义负载均衡策略3.4 饥饿加载与…...

私人飞机、公务机包机会成为富豪圈的主流出行方式吗?

从炫耀性消费到按需使用,私人飞机的消费群体正在被拓宽,但离“成为主流”还有一段距离。“时间就是金钱”为有钱人消费私人飞机提供合理动机,而这群高净值人群的数量增长则成为撑起市场基本面。据相关数据显示,2018年全球超级富豪…...

Oracle组织架构

组织架构 (一)业务组(BG) (二)法律实体(LE) (三)业务实体(OU) (四)库存组织(INV) …...

最小公倍数

目录 最小公倍数 程序设计 程序分析 最小公倍数 【问题描述】给定两个正整数,计算这两个数的最小公倍数。 【输入形式】输入包含多组测试数据,每组只有一行,包括两个不大于1000的正整数. 【输出形式】 对于每个测试用例,给出这两个数的最小公倍数,每个实例输出一行。…...

二叉树的后序遍历(力扣145)

目录 题目描述: 解法一:递归法 解法二:迭代法 解法三:Morris遍历 二叉树的后序遍历 题目描述: 给你一棵二叉树的根节点 root ,返回其节点值的 后序遍历 。 示例 1: 输入:root …...

《Effective C++》读书纪实 -- 诸君同享

文章目录《Effective C》是一本经典的C编程指南,共包含50条C编程的最佳实践。 确定你的构造函数的行为 在构造函数中,应该尽可能地避免调用虚函数、非静态成员函数和虚基类的函数。 尽量使用const、enum、inline替换#define 使用const、enum、inline可以…...

【云原生】K8S-ConfigMap 实现应用和配置分离

文章目录前言ConfigMap 背景ConfigMap 创建方式ConfigMap 的使用使用 ConfigMap 的注意事项总结前言 Kubernetes 是目前最流行的容器编排系统之一,它提供了丰富的功能来支持容器化应用程序的管理和部署。 ConfigMap 是 Kubernetes 中重要的资源对象,用…...

java -测距工具(经纬度)

代码 /*** 测距工具* author qb*/ public class DistanceUtils {/*** 赤道半径*/private static final double EARTH_RADIUS 6378.137;private static double rad(double d) {return d * Math.PI / 180.0;}/*** Description : 通过经纬度获取距离(单位:米)* Group…...

postgres分区表的创建-基于继承

参考文档: http://postgres.cn/docs/12/ddl-partitioning.html 创建基于继承的分区表的步骤 1 创建父表 2 创建子表,从父表继承过来 3 创建函数及触发器,使插入的数据根据规则,插入到对应的子表中 -- 创建父表 CREATE TABLE a…...

Docker应用部署

文章目录Docker 应用部署一、部署MySQL二、部署Tomcat三、部署Nginx四、部署RedisDocker 应用部署 一、部署MySQL 搜索mysql镜像 docker search mysql拉取mysql镜像 docker pull mysql:5.6创建容器,设置端口映射、目录映射 # 在/root目录下创建mysql目录用于存…...

使用golang实现日志收集系统的logagent

整体架构 参考 七米老师的日志收集项目 主要用go实现logagent的部分,logagent的作用主要是实时监控日志追加的变化,并将变化发送到kafka中。 之前我们已经实现了 用go连接kafka并向其中发送数据,也实现了使用tail库监控日志追加操作。 我们…...

小红书点赞不显示怎么回事?小红书笔记评论被吞怎么办

小红书作为一个互联网产品,是一个软件。既然是软件就会有一定的程序漏洞,这是无法避免的。但是很多时候其实并不一定是漏洞的问题。今天就来和大家谈谈小红书点赞不显示怎么回事,小红书评论被吞又是怎么一回事,这些难道都是程序性…...

地址变换和缺页置换习题

1.设某进程页面的访问序列为4,3,2,1,4,3,5,4,3,2,1,5,当分配给该进程的内存页框数分别为3和4时,对于先进先出,最近最少使用,最佳页面置换算法,分别发生多少次缺页中断? 答: 分配的…...

PAT 乙级 1010 一元多项式求导(解题思路+AC代码)

题目: 设计函数求一元多项式的导数。(注:xn(n为整数)的一阶导数为nxn−1。) 输入格式: 以指数递降方式输入多项式非零项系数和指数(绝对值均为不超过 1000 的整数)。数字间以空格分…...

django filter 统计数量 按属性去重

在Django中,如果你想要根据某个属性对查询集进行去重并统计数量,你可以使用values()方法配合annotate()方法来实现。这里有两种常见的方法来完成这个需求: 方法1:使用annotate()和Count 假设你有一个模型Item,并且你想…...

WordPress插件:AI多语言写作与智能配图、免费AI模型、SEO文章生成

厌倦手动写WordPress文章?AI自动生成,效率提升10倍! 支持多语言、自动配图、定时发布,让内容创作更轻松! AI内容生成 → 不想每天写文章?AI一键生成高质量内容!多语言支持 → 跨境电商必备&am…...

[Java恶补day16] 238.除自身以外数组的乘积

给你一个整数数组 nums,返回 数组 answer ,其中 answer[i] 等于 nums 中除 nums[i] 之外其余各元素的乘积 。 题目数据 保证 数组 nums之中任意元素的全部前缀元素和后缀的乘积都在 32 位 整数范围内。 请 不要使用除法,且在 O(n) 时间复杂度…...

Java线上CPU飙高问题排查全指南

一、引言 在Java应用的线上运行环境中,CPU飙高是一个常见且棘手的性能问题。当系统出现CPU飙高时,通常会导致应用响应缓慢,甚至服务不可用,严重影响用户体验和业务运行。因此,掌握一套科学有效的CPU飙高问题排查方法&…...

RabbitMQ入门4.1.0版本(基于java、SpringBoot操作)

RabbitMQ 一、RabbitMQ概述 RabbitMQ RabbitMQ最初由LShift和CohesiveFT于2007年开发,后来由Pivotal Software Inc.(现为VMware子公司)接管。RabbitMQ 是一个开源的消息代理和队列服务器,用 Erlang 语言编写。广泛应用于各种分布…...

django blank 与 null的区别

1.blank blank控制表单验证时是否允许字段为空 2.null null控制数据库层面是否为空 但是,要注意以下几点: Django的表单验证与null无关:null参数控制的是数据库层面字段是否可以为NULL,而blank参数控制的是Django表单验证时字…...

Cilium动手实验室: 精通之旅---13.Cilium LoadBalancer IPAM and L2 Service Announcement

Cilium动手实验室: 精通之旅---13.Cilium LoadBalancer IPAM and L2 Service Announcement 1. LAB环境2. L2公告策略2.1 部署Death Star2.2 访问服务2.3 部署L2公告策略2.4 服务宣告 3. 可视化 ARP 流量3.1 部署新服务3.2 准备可视化3.3 再次请求 4. 自动IPAM4.1 IPAM Pool4.2 …...

Spring Security 认证流程——补充

一、认证流程概述 Spring Security 的认证流程基于 过滤器链(Filter Chain),核心组件包括 UsernamePasswordAuthenticationFilter、AuthenticationManager、UserDetailsService 等。整个流程可分为以下步骤: 用户提交登录请求拦…...

6️⃣Go 语言中的哈希、加密与序列化:通往区块链世界的钥匙

Go 语言中的哈希、加密与序列化:通往区块链世界的钥匙 一、前言:离区块链还有多远? 区块链听起来可能遥不可及,似乎是只有密码学专家和资深工程师才能涉足的领域。但事实上,构建一个区块链的核心并不复杂,尤其当你已经掌握了一门系统编程语言,比如 Go。 要真正理解区…...

密码学基础——SM4算法

博客主页:christine-rr-CSDN博客 ​​​​专栏主页:密码学 📌 【今日更新】📌 对称密码算法——SM4 目录 一、国密SM系列算法概述 二、SM4算法 2.1算法背景 2.2算法特点 2.3 基本部件 2.3.1 S盒 2.3.2 非线性变换 ​编辑…...