神经网络的建立-TensorFlow2.x
要学习深度强化学习,就要学会使用神经网络,建立神经网络可以使用TensorFlow和pytorch,今天先学习以TensorFlow建立网络。
直接上代码
import tensorflow as tf# 定义神经网络模型
model = tf.keras.models.Sequential([tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),tf.keras.layers.Dropout(0.2),tf.keras.layers.Dense(10)
])# 编译模型
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])# 加载数据集
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255# 训练模型
model.fit(x_train, y_train, epochs=5)# 评估模型
model.evaluate(x_test, y_test)
然后解释一下代码里的具体步骤
定义神经网络模型
当使用 tf.keras.models.Sequential 创建神经网络时,可以按顺序添加多个层,每个层都会顺序连接在一起,构成整个神经网络模型。在这个例子中,我们添加了两个全连接层(Dense层)和一个Dropout层。这个模型的结构如下所示:
输入层:由 tf.keras.layers.Dense(128, activation=‘relu’, input_shape=(784,)) 创建,包含128个神经元。输入数据的形状是 (None, 784),其中 None 表示任意的批次大小。激活函数为 ReLU。
Dropout层:由 tf.keras.layers.Dropout(0.2) 创建,其作用是随机断开一定比例的输入神经元,以防止过拟合。
输出层:由 tf.keras.layers.Dense(10) 创建,包含10个神经元,对应于10个分类。激活函数为空,因为我们将使用 logits 值来进行计算,而不是经过 softmax 转换后的概率。
因为我们没有指定激活函数的名称,所以默认情况下 Dense 层使用线性激活函数。由于我们需要 logits 值来计算交叉熵损失函数,因此输出层没有指定激活函数。
总的来说,这个模型是一个具有 1 个输入层,1 个输出层和 1 个 Dropout 层的简单全连接神经网络。
编译模型
在 TensorFlow 中,使用 compile() 方法来配置模型的训练过程,其中包括选择优化器、损失函数和评估指标等。
optimizer=‘adam’:指定使用 Adam 优化器进行模型训练。Adam 是一种常用的自适应学习率优化算法,可以更快地收敛和更好地处理不同的学习率。
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True):指定使用交叉熵损失函数来计算模型在训练期间的误差。在这里我们使用的是 SparseCategoricalCrossentropy,它适用于多分类问题且标签为整数编码的情况。由于输出层没有使用 softmax 激活函数,所以设置参数 from_logits=True,表示我们将使用 logits 值来计算交叉熵损失函数。
metrics=[‘accuracy’]:指定使用准确率作为模型的评估指标,以在训练期间监视模型的性能。我们可以指定多个评估指标,例如 metrics=[‘accuracy’, ‘mse’],以同时监视模型的准确率和均方误差。
这些配置将应用于后续的模型训练中。
然后直接从Keras这个高级API加载数据集,这里讲一下x_train和x_test
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255
这两行代码是将输入的训练数据 x_train 和测试数据 x_test 进行预处理,使它们具有相同的数据形状和数据类型,并进行了归一化处理。
x_train.reshape(60000, 784):将训练数据的形状从原来的 (60000, 28, 28) 重塑为 (60000, 784),其中 784 表示每个图像的像素数量,也就是将每个图像转换为一个长度为 784 的一维数组。
.astype(‘float32’):将数据类型转换为浮点型,因为在后续的归一化处理中需要进行除法操作。
/ 255:将像素值的范围从原来的 0 到 255 之间的整数转换为 0 到 1 之间的浮点数。这是一种常见的归一化方法,它可以使得输入数据的数值范围更加稳定,更容易被神经网络学习。
这样处理后,训练数据 x_train 和测试数据 x_test 都变成了形状为 (60000, 784) 和 (10000, 784) 的浮点数数组。这样的数据可以作为神经网络的输入,并可以更好地被模型学习。60000和10000分别表示训练和测试的数据量。
最后训练模型
model.fit(x_train, y_train, epochs=5) 是用来训练模型的代码。它将训练数据集 x_train 和训练标签 y_train 作为输入,并对模型进行多轮(5 轮)的训练。
具体来说,fit() 方法将对模型进行以下操作:
按照指定的轮数(即 epochs=5)对整个数据集进行多次迭代训练。
在每一轮训练中,将数据集划分为多个小批量数据,每个小批量包含一定数量的样本(默认情况下是 32 个样本)。
使用优化器(在 model.compile() 中指定)对模型进行优化,即更新模型的权重和偏置以最小化损失函数。
计算在每个小批量数据上的损失值和评估指标值,并在屏幕上输出模型的训练进度信息。
在每一轮训练结束后,使用测试数据集进行模型评估(在 model.evaluate() 中指定)。
通过反复迭代训练,模型的权重和偏置将不断被更新,以使得模型能够更好地适应数据集,并获得更好的性能。
评估模型model.evaluate(x_test, y_test)
相关文章:
神经网络的建立-TensorFlow2.x
要学习深度强化学习,就要学会使用神经网络,建立神经网络可以使用TensorFlow和pytorch,今天先学习以TensorFlow建立网络。 直接上代码 import tensorflow as tf# 定义神经网络模型 model tf.keras.models.Sequential([tf.keras.layers.Dense…...
python基于卷积神经网络实现自定义数据集训练与测试
注意: 如何更改图像尺寸在这篇文章中,修改完之后你就可以把你自己的数据集应用到网络。如果你的训练集与测试集也分别为30和5,并且样本类别也为3类,那么你只需要更改图像标签文件地址以及标签内容(如下面两图所示&…...
跟着LearnOpenGL学习3--四边形绘制
文章目录 一、前言二、元素缓冲对象三、完整代码四、绘制模式 一、前言 通过跟着LearnOpenGL学习2–三角形绘制一文,我们已经知道了怎么配置渲染管线,来绘制三角形; OpenGL主要处理三角形,当我们需要绘制别的图形时,…...
c#笔记-结构
装箱 结构是值类型。值类型不能继承其他类型,也不能被其他类型继承。 所以它的方法都是确定的,没有虚方法需要在运行时进行动态绑定。 值类型没有对象头,方法调用由编译器直接确定。 但是,如果使用引用类型变量(如接…...
Es分布式搜索引擎
目录 一、什么是ES? 二、什么是elk? 三、什么是倒排索引? 四、正向索引和倒排索引的优缺点对比 五、mysql数据库和es的区别? 六、索引库(es中的数据库表)操作有哪些? 八、ES分片存储原理 …...
open3d 裁剪点云
目录 1. crop_point_cloud 2. crop 3. crop_mesh 1. crop_point_cloud 关键函数 chair vol.crop_point_cloud(pcd) # vol: SelectionPolygonVolume import open3d as o3dif __name__ "__main__":# 1. read pcdprint("Load a ply point cloud, crop it…...
如何对第三方相同请求进行筛选过滤
文章目录 问题背景处理思路注意事项代码实现 问题背景 公司内部多个系统共用一套用户体系库,对外(钉钉)我们是两个客户身份(这里是根据系统来的),例如当第三方服务向我们发起用户同步请求:是一个更新用户操作,它会同时发送一个 d…...
Go RPC
目录 文章目录 Go RPCHTTP RPCTCP RPCJSON RPC Go RPC Go 标准包中已经提供了对 RPC 的支持,而且支持三个级别的 RPC:TCP、HTTP、JSONRPC。但 Go 的 RPC 包是独一无二的 RPC,它和传统的 RPC 系统不同,它只支持 Go 开发的服务器与…...
真正的智能不仅仅是一个技术问题
智能并不是单一的技术问题,而是一个包括技术、人类智慧、社会制度和文化等多个方面的综合体,常常涉及技术变革、系统演变、运行方式创新、组织适应。智能是指人类的思考、判断、决策和创造等高级认知能力,可以通过技术手段来实现增强和扩展。…...
【数据结构】复杂度包装泛型
目录 1.时间和空间复杂度 1.1时间复杂度 1.2空间复杂度 2.包装类 2.1基本数据类型和对应的包装类 2.2装箱和拆箱 //阿里巴巴面试题 3.泛型 3.1擦除机制 3.2泛型的上界 1.时间和空间复杂度 1.1时间复杂度 定义:一个算法所花费的时间与其语句的执行次数成…...
Ae:绘画面板
Ae菜单:窗口/绘画 Paint 快捷键:Ctrl 8 绘画工具(画笔工具、仿制图章工具及橡皮擦工具)仅能工作在图层面板上。在使用绘画工具之前,建议先在绘画 Paint面板中查看或进行相关设置。 说明: 如果要在绘画描边…...
常见的锁和zookeeper
zookeeper 本文由 简悦 SimpRead 转码, 原文地址 zhuanlan.zhihu.com 前言 只有光头才能变强。 文本已收录至我的 GitHub 仓库,欢迎 Star:https://github.com/ZhongFuCheng3y/3y 上次写了一篇 什么是消息队列?以后,本来…...
经验总结:(Redis NoSQL数据库快速入门)
一、Nosql概述 为什么使用Nosql 1、单机Mysql时代 90年代,一个网站的访问量一般不会太大,单个数据库完全够用。随着用户增多,网站出现以下问题 数据量增加到一定程度,单机数据库就放不下了数据的索引(B Tree),一个机…...
form表单与模板引擎
文章目录 一、form表单的基本使用1、什么是表单2、表单的组成部分3、 <form>标签的属性4、表单的同步提交及缺点(1) 什么是表单的同步提交(2) 表单同步提交的缺点(3) 如何解决表单同步提交的缺点 二、…...
医院检验信息管理系统源码(云LIS系统源码)JQuery、EasyUI
云LIS系统是一种医疗实验室信息管理系统,提供全面的实验室信息管理解决方案。它的主要功能包括样本管理、检测流程管理、报告管理、质量控制、数据分析和仪器管理等。 云LIS源码技术说明: 技术架构:Asp.NET CORE 3.1 MVC SQLserver Redis等…...
React 组件
文章目录 React 组件复合组件 React 组件 本节将讨论如何使用组件使得我们的应用更容易来管理。 接下来我们封装一个输出 “Hello World!” 的组件,组件名为 HelloMessage: React 实例 <!DOCTYPE html> <html> <head> &…...
硕士学位论文的几种常见节奏
摘要: 本文描述硕士学位论文的几种目录结构, 特别针对机器学习方向. 1. 基础版 XX算法及其在YY中的应用 针对情况: 只有一篇小论文支撑. 第 1 章: 引言 ( > 5页) 1.1 背景及意义 (应用背景、研究意义, 2 页) 1.2 研究进展及趋势 (算法方面, 2 页) 1.3 论文结构 (1 页) 第 …...
找兄弟单词
描述 定义一个单词的“兄弟单词”为:交换该单词字母顺序(注:可以交换任意次),而不添加、删除、修改原有的字母就能生成的单词。 兄弟单词要求和原来的单词不同。例如: ab 和 ba 是兄弟单词。 ab 和 ab 则不…...
python字典翻转教学
目录 第1关 创建大学英语四级单词字典 第2关 合并大学英语四六级词汇字典 第3关 查单词输出中文释义 第4关 删除字典中特定字母开头的单词 第5关 单词英汉记忆训练 第1关 创建大学英语四级单词字典 本关任务:编写一个能创建大学英语四级单词字典的小程序。 测…...
sentinel 随笔 3-降级处理
0. 像喝点东西,但不知道喝什么 先来段源码,看一下 我们在dashboard 录入的降级规则,都映射到哪些字段上 package com.alibaba.csp.sentinel.slots.block.degrade;public class DegradeRule extends AbstractRule {public DegradeRule(String…...
【JavaEE】-- HTTP
1. HTTP是什么? HTTP(全称为"超文本传输协议")是一种应用非常广泛的应用层协议,HTTP是基于TCP协议的一种应用层协议。 应用层协议:是计算机网络协议栈中最高层的协议,它定义了运行在不同主机上…...
MFC内存泄露
1、泄露代码示例 void X::SetApplicationBtn() {CMFCRibbonApplicationButton* pBtn GetApplicationButton();// 获取 Ribbon Bar 指针// 创建自定义按钮CCustomRibbonAppButton* pCustomButton new CCustomRibbonAppButton();pCustomButton->SetImage(IDB_BITMAP_Jdp26)…...
《用户共鸣指数(E)驱动品牌大模型种草:如何抢占大模型搜索结果情感高地》
在注意力分散、内容高度同质化的时代,情感连接已成为品牌破圈的关键通道。我们在服务大量品牌客户的过程中发现,消费者对内容的“有感”程度,正日益成为影响品牌传播效率与转化率的核心变量。在生成式AI驱动的内容生成与推荐环境中࿰…...
IT供电系统绝缘监测及故障定位解决方案
随着新能源的快速发展,光伏电站、储能系统及充电设备已广泛应用于现代能源网络。在光伏领域,IT供电系统凭借其持续供电性好、安全性高等优势成为光伏首选,但在长期运行中,例如老化、潮湿、隐裂、机械损伤等问题会影响光伏板绝缘层…...
第 86 场周赛:矩阵中的幻方、钥匙和房间、将数组拆分成斐波那契序列、猜猜这个单词
Q1、[中等] 矩阵中的幻方 1、题目描述 3 x 3 的幻方是一个填充有 从 1 到 9 的不同数字的 3 x 3 矩阵,其中每行,每列以及两条对角线上的各数之和都相等。 给定一个由整数组成的row x col 的 grid,其中有多少个 3 3 的 “幻方” 子矩阵&am…...
ip子接口配置及删除
配置永久生效的子接口,2个IP 都可以登录你这一台服务器。重启不失效。 永久的 [应用] vi /etc/sysconfig/network-scripts/ifcfg-eth0修改文件内内容 TYPE"Ethernet" BOOTPROTO"none" NAME"eth0" DEVICE"eth0" ONBOOT&q…...
在web-view 加载的本地及远程HTML中调用uniapp的API及网页和vue页面是如何通讯的?
uni-app 中 Web-view 与 Vue 页面的通讯机制详解 一、Web-view 简介 Web-view 是 uni-app 提供的一个重要组件,用于在原生应用中加载 HTML 页面: 支持加载本地 HTML 文件支持加载远程 HTML 页面实现 Web 与原生的双向通讯可用于嵌入第三方网页或 H5 应…...
MySQL 8.0 事务全面讲解
以下是一个结合两次回答的 MySQL 8.0 事务全面讲解,涵盖了事务的核心概念、操作示例、失败回滚、隔离级别、事务性 DDL 和 XA 事务等内容,并修正了查看隔离级别的命令。 MySQL 8.0 事务全面讲解 一、事务的核心概念(ACID) 事务是…...
Golang——6、指针和结构体
指针和结构体 1、指针1.1、指针地址和指针类型1.2、指针取值1.3、new和make 2、结构体2.1、type关键字的使用2.2、结构体的定义和初始化2.3、结构体方法和接收者2.4、给任意类型添加方法2.5、结构体的匿名字段2.6、嵌套结构体2.7、嵌套匿名结构体2.8、结构体的继承 3、结构体与…...
逻辑回归暴力训练预测金融欺诈
简述 「使用逻辑回归暴力预测金融欺诈,并不断增加特征维度持续测试」的做法,体现了一种逐步建模与迭代验证的实验思路,在金融欺诈检测中非常有价值,本文作为一篇回顾性记录了早年间公司给某行做反欺诈预测用到的技术和思路。百度…...
