当前位置: 首页 > news >正文

神经网络的建立-TensorFlow2.x

要学习深度强化学习,就要学会使用神经网络,建立神经网络可以使用TensorFlow和pytorch,今天先学习以TensorFlow建立网络。
直接上代码

import tensorflow as tf# 定义神经网络模型
model = tf.keras.models.Sequential([tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),tf.keras.layers.Dropout(0.2),tf.keras.layers.Dense(10)
])# 编译模型
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])# 加载数据集
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255# 训练模型
model.fit(x_train, y_train, epochs=5)# 评估模型
model.evaluate(x_test, y_test)

然后解释一下代码里的具体步骤

定义神经网络模型
当使用 tf.keras.models.Sequential 创建神经网络时,可以按顺序添加多个层,每个层都会顺序连接在一起,构成整个神经网络模型。在这个例子中,我们添加了两个全连接层(Dense层)和一个Dropout层。这个模型的结构如下所示:

输入层:由 tf.keras.layers.Dense(128, activation=‘relu’, input_shape=(784,)) 创建,包含128个神经元。输入数据的形状是 (None, 784),其中 None 表示任意的批次大小。激活函数为 ReLU。
Dropout层:由 tf.keras.layers.Dropout(0.2) 创建,其作用是随机断开一定比例的输入神经元,以防止过拟合。
输出层:由 tf.keras.layers.Dense(10) 创建,包含10个神经元,对应于10个分类。激活函数为空,因为我们将使用 logits 值来进行计算,而不是经过 softmax 转换后的概率。
因为我们没有指定激活函数的名称,所以默认情况下 Dense 层使用线性激活函数。由于我们需要 logits 值来计算交叉熵损失函数,因此输出层没有指定激活函数。

总的来说,这个模型是一个具有 1 个输入层,1 个输出层和 1 个 Dropout 层的简单全连接神经网络。

编译模型

在 TensorFlow 中,使用 compile() 方法来配置模型的训练过程,其中包括选择优化器、损失函数和评估指标等。

optimizer=‘adam’:指定使用 Adam 优化器进行模型训练。Adam 是一种常用的自适应学习率优化算法,可以更快地收敛和更好地处理不同的学习率。
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True):指定使用交叉熵损失函数来计算模型在训练期间的误差。在这里我们使用的是 SparseCategoricalCrossentropy,它适用于多分类问题且标签为整数编码的情况。由于输出层没有使用 softmax 激活函数,所以设置参数 from_logits=True,表示我们将使用 logits 值来计算交叉熵损失函数。
metrics=[‘accuracy’]:指定使用准确率作为模型的评估指标,以在训练期间监视模型的性能。我们可以指定多个评估指标,例如 metrics=[‘accuracy’, ‘mse’],以同时监视模型的准确率和均方误差。
这些配置将应用于后续的模型训练中。

然后直接从Keras这个高级API加载数据集,这里讲一下x_train和x_test

x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255

这两行代码是将输入的训练数据 x_train 和测试数据 x_test 进行预处理,使它们具有相同的数据形状和数据类型,并进行了归一化处理。

x_train.reshape(60000, 784):将训练数据的形状从原来的 (60000, 28, 28) 重塑为 (60000, 784),其中 784 表示每个图像的像素数量,也就是将每个图像转换为一个长度为 784 的一维数组。
.astype(‘float32’):将数据类型转换为浮点型,因为在后续的归一化处理中需要进行除法操作。
/ 255:将像素值的范围从原来的 0 到 255 之间的整数转换为 0 到 1 之间的浮点数。这是一种常见的归一化方法,它可以使得输入数据的数值范围更加稳定,更容易被神经网络学习。
这样处理后,训练数据 x_train 和测试数据 x_test 都变成了形状为 (60000, 784) 和 (10000, 784) 的浮点数数组。这样的数据可以作为神经网络的输入,并可以更好地被模型学习。60000和10000分别表示训练和测试的数据量。

最后训练模型

model.fit(x_train, y_train, epochs=5) 是用来训练模型的代码。它将训练数据集 x_train 和训练标签 y_train 作为输入,并对模型进行多轮(5 轮)的训练。

具体来说,fit() 方法将对模型进行以下操作:

按照指定的轮数(即 epochs=5)对整个数据集进行多次迭代训练。
在每一轮训练中,将数据集划分为多个小批量数据,每个小批量包含一定数量的样本(默认情况下是 32 个样本)。
使用优化器(在 model.compile() 中指定)对模型进行优化,即更新模型的权重和偏置以最小化损失函数。
计算在每个小批量数据上的损失值和评估指标值,并在屏幕上输出模型的训练进度信息。
在每一轮训练结束后,使用测试数据集进行模型评估(在 model.evaluate() 中指定)。
通过反复迭代训练,模型的权重和偏置将不断被更新,以使得模型能够更好地适应数据集,并获得更好的性能。

评估模型model.evaluate(x_test, y_test)

相关文章:

神经网络的建立-TensorFlow2.x

要学习深度强化学习,就要学会使用神经网络,建立神经网络可以使用TensorFlow和pytorch,今天先学习以TensorFlow建立网络。 直接上代码 import tensorflow as tf# 定义神经网络模型 model tf.keras.models.Sequential([tf.keras.layers.Dense…...

python基于卷积神经网络实现自定义数据集训练与测试

注意: 如何更改图像尺寸在这篇文章中,修改完之后你就可以把你自己的数据集应用到网络。如果你的训练集与测试集也分别为30和5,并且样本类别也为3类,那么你只需要更改图像标签文件地址以及标签内容(如下面两图所示&…...

跟着LearnOpenGL学习3--四边形绘制

文章目录 一、前言二、元素缓冲对象三、完整代码四、绘制模式 一、前言 通过跟着LearnOpenGL学习2–三角形绘制一文,我们已经知道了怎么配置渲染管线,来绘制三角形; OpenGL主要处理三角形,当我们需要绘制别的图形时,…...

c#笔记-结构

装箱 结构是值类型。值类型不能继承其他类型,也不能被其他类型继承。 所以它的方法都是确定的,没有虚方法需要在运行时进行动态绑定。 值类型没有对象头,方法调用由编译器直接确定。 但是,如果使用引用类型变量(如接…...

Es分布式搜索引擎

目录 一、什么是ES? 二、什么是elk? 三、什么是倒排索引? 四、正向索引和倒排索引的优缺点对比 五、mysql数据库和es的区别? 六、索引库(es中的数据库表)操作有哪些? 八、ES分片存储原理 …...

open3d 裁剪点云

目录 1. crop_point_cloud 2. crop 3. crop_mesh 1. crop_point_cloud 关键函数 chair vol.crop_point_cloud(pcd) # vol: SelectionPolygonVolume import open3d as o3dif __name__ "__main__":# 1. read pcdprint("Load a ply point cloud, crop it…...

如何对第三方相同请求进行筛选过滤

文章目录 问题背景处理思路注意事项代码实现 问题背景 公司内部多个系统共用一套用户体系库,对外(钉钉)我们是两个客户身份(这里是根据系统来的),例如当第三方服务向我们发起用户同步请求:是一个更新用户操作,它会同时发送一个 d…...

Go RPC

目录 文章目录 Go RPCHTTP RPCTCP RPCJSON RPC Go RPC Go 标准包中已经提供了对 RPC 的支持,而且支持三个级别的 RPC:TCP、HTTP、JSONRPC。但 Go 的 RPC 包是独一无二的 RPC,它和传统的 RPC 系统不同,它只支持 Go 开发的服务器与…...

真正的智能不仅仅是一个技术问题

智能并不是单一的技术问题,而是一个包括技术、人类智慧、社会制度和文化等多个方面的综合体,常常涉及技术变革、系统演变、运行方式创新、组织适应。智能是指人类的思考、判断、决策和创造等高级认知能力,可以通过技术手段来实现增强和扩展。…...

【数据结构】复杂度包装泛型

目录 1.时间和空间复杂度 1.1时间复杂度 1.2空间复杂度 2.包装类 2.1基本数据类型和对应的包装类 2.2装箱和拆箱 //阿里巴巴面试题 3.泛型 3.1擦除机制 3.2泛型的上界 1.时间和空间复杂度 1.1时间复杂度 定义:一个算法所花费的时间与其语句的执行次数成…...

Ae:绘画面板

Ae菜单:窗口/绘画 Paint 快捷键:Ctrl 8 绘画工具(画笔工具、仿制图章工具及橡皮擦工具)仅能工作在图层面板上。在使用绘画工具之前,建议先在绘画 Paint面板中查看或进行相关设置。 说明: 如果要在绘画描边…...

常见的锁和zookeeper

zookeeper 本文由 简悦 SimpRead 转码, 原文地址 zhuanlan.zhihu.com 前言 只有光头才能变强。 文本已收录至我的 GitHub 仓库,欢迎 Star:https://github.com/ZhongFuCheng3y/3y 上次写了一篇 什么是消息队列?以后,本来…...

经验总结:(Redis NoSQL数据库快速入门)

一、Nosql概述 为什么使用Nosql 1、单机Mysql时代 90年代,一个网站的访问量一般不会太大,单个数据库完全够用。随着用户增多,网站出现以下问题 数据量增加到一定程度,单机数据库就放不下了数据的索引(B Tree),一个机…...

form表单与模板引擎

文章目录 一、form表单的基本使用1、什么是表单2、表单的组成部分3、 <form>标签的属性4、表单的同步提交及缺点&#xff08;1&#xff09; 什么是表单的同步提交&#xff08;2&#xff09; 表单同步提交的缺点&#xff08;3&#xff09; 如何解决表单同步提交的缺点 二、…...

医院检验信息管理系统源码(云LIS系统源码)JQuery、EasyUI

云LIS系统是一种医疗实验室信息管理系统&#xff0c;提供全面的实验室信息管理解决方案。它的主要功能包括样本管理、检测流程管理、报告管理、质量控制、数据分析和仪器管理等。 云LIS源码技术说明&#xff1a; 技术架构&#xff1a;Asp.NET CORE 3.1 MVC SQLserver Redis等…...

React 组件

文章目录 React 组件复合组件 React 组件 本节将讨论如何使用组件使得我们的应用更容易来管理。 接下来我们封装一个输出 “Hello World&#xff01;” 的组件&#xff0c;组件名为 HelloMessage&#xff1a; React 实例 <!DOCTYPE html> <html> <head> &…...

硕士学位论文的几种常见节奏

摘要: 本文描述硕士学位论文的几种目录结构, 特别针对机器学习方向. 1. 基础版 XX算法及其在YY中的应用 针对情况: 只有一篇小论文支撑. 第 1 章: 引言 ( > 5页) 1.1 背景及意义 (应用背景、研究意义, 2 页) 1.2 研究进展及趋势 (算法方面, 2 页) 1.3 论文结构 (1 页) 第 …...

找兄弟单词

描述 定义一个单词的“兄弟单词”为&#xff1a;交换该单词字母顺序&#xff08;注&#xff1a;可以交换任意次&#xff09;&#xff0c;而不添加、删除、修改原有的字母就能生成的单词。 兄弟单词要求和原来的单词不同。例如&#xff1a; ab 和 ba 是兄弟单词。 ab 和 ab 则不…...

python字典翻转教学

目录 第1关 创建大学英语四级单词字典 第2关 合并大学英语四六级词汇字典 第3关 查单词输出中文释义 第4关 删除字典中特定字母开头的单词 第5关 单词英汉记忆训练 第1关 创建大学英语四级单词字典 本关任务&#xff1a;编写一个能创建大学英语四级单词字典的小程序。 测…...

sentinel 随笔 3-降级处理

0. 像喝点东西&#xff0c;但不知道喝什么 先来段源码&#xff0c;看一下 我们在dashboard 录入的降级规则&#xff0c;都映射到哪些字段上 package com.alibaba.csp.sentinel.slots.block.degrade;public class DegradeRule extends AbstractRule {public DegradeRule(String…...

大型活动交通拥堵治理的视觉算法应用

大型活动下智慧交通的视觉分析应用 一、背景与挑战 大型活动&#xff08;如演唱会、马拉松赛事、高考中考等&#xff09;期间&#xff0c;城市交通面临瞬时人流车流激增、传统摄像头模糊、交通拥堵识别滞后等问题。以演唱会为例&#xff0c;暖城商圈曾因观众集中离场导致周边…...

Vue2 第一节_Vue2上手_插值表达式{{}}_访问数据和修改数据_Vue开发者工具

文章目录 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染2. 插值表达式{{}}3. 访问数据和修改数据4. vue响应式5. Vue开发者工具--方便调试 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染 准备容器引包创建Vue实例 new Vue()指定配置项 ->渲染数据 准备一个容器,例如: …...

CRMEB 框架中 PHP 上传扩展开发:涵盖本地上传及阿里云 OSS、腾讯云 COS、七牛云

目前已有本地上传、阿里云OSS上传、腾讯云COS上传、七牛云上传扩展 扩展入口文件 文件目录 crmeb\services\upload\Upload.php namespace crmeb\services\upload;use crmeb\basic\BaseManager; use think\facade\Config;/*** Class Upload* package crmeb\services\upload* …...

(转)什么是DockerCompose?它有什么作用?

一、什么是DockerCompose? DockerCompose可以基于Compose文件帮我们快速的部署分布式应用&#xff0c;而无需手动一个个创建和运行容器。 Compose文件是一个文本文件&#xff0c;通过指令定义集群中的每个容器如何运行。 DockerCompose就是把DockerFile转换成指令去运行。 …...

06 Deep learning神经网络编程基础 激活函数 --吴恩达

深度学习激活函数详解 一、核心作用 引入非线性:使神经网络可学习复杂模式控制输出范围:如Sigmoid将输出限制在(0,1)梯度传递:影响反向传播的稳定性二、常见类型及数学表达 Sigmoid σ ( x ) = 1 1 +...

2025年渗透测试面试题总结-腾讯[实习]科恩实验室-安全工程师(题目+回答)

安全领域各种资源&#xff0c;学习文档&#xff0c;以及工具分享、前沿信息分享、POC、EXP分享。不定期分享各种好玩的项目及好用的工具&#xff0c;欢迎关注。 目录 腾讯[实习]科恩实验室-安全工程师 一、网络与协议 1. TCP三次握手 2. SYN扫描原理 3. HTTPS证书机制 二…...

关于uniapp展示PDF的解决方案

在 UniApp 的 H5 环境中使用 pdf-vue3 组件可以实现完整的 PDF 预览功能。以下是详细实现步骤和注意事项&#xff1a; 一、安装依赖 安装 pdf-vue3 和 PDF.js 核心库&#xff1a; npm install pdf-vue3 pdfjs-dist二、基本使用示例 <template><view class"con…...

Spring AI Chat Memory 实战指南:Local 与 JDBC 存储集成

一个面向 Java 开发者的 Sring-Ai 示例工程项目&#xff0c;该项目是一个 Spring AI 快速入门的样例工程项目&#xff0c;旨在通过一些小的案例展示 Spring AI 框架的核心功能和使用方法。 项目采用模块化设计&#xff0c;每个模块都专注于特定的功能领域&#xff0c;便于学习和…...

MFE(微前端) Module Federation:Webpack.config.js文件中每个属性的含义解释

以Module Federation 插件详为例&#xff0c;Webpack.config.js它可能的配置和含义如下&#xff1a; 前言 Module Federation 的Webpack.config.js核心配置包括&#xff1a; name filename&#xff08;定义应用标识&#xff09; remotes&#xff08;引用远程模块&#xff0…...

高防服务器价格高原因分析

高防服务器的价格较高&#xff0c;主要是由于其特殊的防御机制、硬件配置、运营维护等多方面的综合成本。以下从技术、资源和服务三个维度详细解析高防服务器昂贵的原因&#xff1a; 一、硬件与技术投入 大带宽需求 DDoS攻击通过占用大量带宽资源瘫痪目标服务器&#xff0c;因此…...