TensorFlow 简单的二分类神经网络的训练和应用流程

展示了一个简单的二分类神经网络的训练和应用流程。主要步骤包括:
1. 数据准备与预处理
2. 构建模型
3. 编译模型
4. 训练模型
5. 评估模型
6. 模型应用与部署
加载和应用已训练的模型
1. 数据准备与预处理
在本例中,数据准备是通过两个 Numpy 数组来完成的:
x:输入特征,形状为(8, 2),包含 8 个数据点,每个数据点有 2 个特征。y:标签,形状为(8,),包含对应的 0 或 1 标签,表示每个输入点的类别。
x = np.array([[1, 1], [1, -1], [-1, 1], [-1, -1], [0.7, 0.7], [0.7, -0.7], [-0.7, -0.7], [-0.7, 0.7]])
y = np.array([1, 1, 1, 1, 0, 0, 0, 0])
2. 构建模型
使用 Keras 的 Sequential 模型来构建神经网络。此模型包含两个全连接层(Dense 层):
- 第一个
Dense层有 3 个单位,激活函数是 Sigmoid。 - 第二个
Dense层有 1 个单位,激活函数是 Sigmoid,输出层的激活函数将模型输出的值映射到 0 到 1 之间,适合二分类任务。
l1 = tf.keras.layers.Dense(units=3, activation='sigmoid')
l2 = tf.keras.layers.Dense(units=1, activation='sigmoid')
model = tf.keras.Sequential([l1, l2])
3. 编译模型
在编译阶段,我们选择了优化器、损失函数和评估指标:
- 优化器:
SGD(随机梯度下降),学习率设置为 0.9。 - 损失函数:
binary_crossentropy,适用于二分类任务。 - 评估指标:
accuracy,表示训练过程中对分类准确率的衡量。
sgd = tf.keras.optimizers.SGD(learning_rate=0.9)
model.compile(optimizer=sgd, loss='binary_crossentropy', metrics=['accuracy'])
4. 训练模型
通过 model.fit() 函数来训练模型。我们传入训练数据 x 和标签 y,并设置训练的 epoch 数量为 2000。
model.fit(x, y, epochs=2000)
5. 评估模型
在此示例中,评估部分通过训练后的 model 来进行,并没有显式写出 evaluate() 函数。评估通常是在训练之后,通过测试集或验证集对模型性能进行评估,具体可以使用 model.evaluate() 来查看损失和准确度。
6. 模型应用与部署
训练完成后,我们保存了训练好的模型。保存后的模型可以被加载和应用于新的数据集。
model.save('my_model.keras') # 保存模型
7.加载和应用已训练的模型
加载保存的模型,并用其对新数据进行预测。model.predict() 方法返回的是预测的概率值,我们通过设置阈值(如 0.9)将其转换为类别(0 或 1)。
model = tf.keras.models.load_model('my_model.keras') # 加载模型
nx = np.array([[2, 2], [0.1, 0.1], [1.1, 1.2], [0.3, 0.3]]) # 新的输入数据
predictions = model.predict(nx) # 获取预测结果
print(predictions) # 输出概率# 将概率转化为类别
predicted_classes = (predictions > 0.9).astype(int)
print(predicted_classes) # 输出最终的类别预测
8.完整代码
test.py 训练模型
import tensorflow as tf
import numpy as np
# 创建int32类型的0维张量,即标量
l1=tf.keras.layers.Dense(units=3,activation='sigmoid')
l2=tf.keras.layers.Dense(units=1,activation='sigmoid')
model=tf.keras.Sequential([l1,l2])
sgd = tf.keras.optimizers.SGD(learning_rate=0.9)
model.compile(optimizer=sgd, loss='binary_crossentropy', metrics=['accuracy'])
x=np.array([[1,1],[1,-1],[-1,1],[-1,-1],[0.7,0.7],[0.7,-0.7],[-0.7,-0.7],[-0.7,0.7]])
y=np.array([1,1,1,1,0,0,0,0])
model.fit(x,y,epochs=2000)
# 保存训练好的模型(Keras 格式)
model.save('my_model.keras')
test2.py加载模型并进行预测:
import tensorflow as tf
import numpy as np# 加载训练好的模型
model = tf.keras.models.load_model('my_model.keras')# 预测数据
nx = np.array([[2, 2], [0.1, 0.1], [1.1, 1.2], [0.3, 0.3]])# 获取预测结果
predictions = model.predict(nx)# 输出预测结果
print(predictions)# 如果需要将概率转化为类别(0或1)
predicted_classes = (predictions > 0.9).astype(int)# 输出最终的类别预测
print(predicted_classes)
9.视频分享
初识TensorFlow
https://v.douyin.com/ifG2mmLH/
复制此链接,打开Dou音搜索,直接观看视频!
相关文章:
TensorFlow 简单的二分类神经网络的训练和应用流程
展示了一个简单的二分类神经网络的训练和应用流程。主要步骤包括: 1. 数据准备与预处理 2. 构建模型 3. 编译模型 4. 训练模型 5. 评估模型 6. 模型应用与部署 加载和应用已训练的模型 1. 数据准备与预处理 在本例中,数据准备是通过两个 Numpy 数…...
docker安装Redis:docker离线安装Redis、docker在线安装Redis、Redis镜像下载、Redis配置、Redis命令
一、镜像下载 1、在线下载 在一台能连外网的linux上执行docker镜像拉取命令 docker pull redis:7.4.0 2、离线包下载 两种方式: 方式一: -)在一台能连外网的linux上安装docker执行第一步的命令下载镜像 -)导出 # 导出镜像…...
Retrieval-Augmented Generation for Large Language Models: A Survey——(1)Overview
Retrieval-Augmented Generation for Large Language Models: A Survey——(1)Overview 文章目录 Retrieval-Augmented Generation for Large Language Models: A Survey——(1)Overview1. Introduction&Abstract1. LLM面临的问题2. RAG核心三要素3. RAG taxonomy 2. Overv…...
LabVIEW透镜多参数自动检测系统
在现代制造业中,提升产品质量检测的自动化水平是提高生产效率和准确性的关键。本文介绍了一个基于LabVIEW的透镜多参数自动检测系统,该系统能够在单一工位上完成透镜的多项质量参数检测,并实现透镜的自动搬运与分选,极大地提升了检…...
什么是Maxscript?为什么要学习Maxscript?
MAXScript是Autodesk 3ds Max的内置脚本语言,它是一种与3dsMax对话并使3dsMax执行某些操作的编程语言。它是一种脚本语言,这意味着您不需要编译代码即可运行。通过使用一系列基于文本的命令而不是使用UI操作,您可以完成许多使用UI操作无法完成的任务。 Maxscript是一种专有…...
Redis|前言
文章目录 什么是 Redis?Redis 主流功能与应用 什么是 Redis? Redis,Remote Dictionary Server(远程字典服务器)。Redis 是完全开源的,使用 ANSIC 语言编写,遵守 BSD 协议,是一个高性…...
128周二复盘(164)学习任天堂
1.设计相关 研究历史上某些武器数值,对一些设定进行参数修改。兼顾真实性,合理性,娱乐性。 学习宫本茂游戏思想,简单有趣-重玩性,风格化个性化-反拟真。对堆难度与内容的反思。 后续将学习岩田聪以及别的任天堂名人的…...
LeetCode:63. 不同路径 II
跟着carl学算法,本系列博客仅做个人记录,建议大家都去看carl本人的博客,写的真的很好的! 代码随想录 LeetCode:63. 不同路径 II 给定一个 m x n 的整数数组 grid。一个机器人初始位于 左上角(即 grid[0][0]…...
VUE之组件通信(一)
1、props 概述:props是使用频率最高的一种通信方式,常用与:父<——>子。 若 父传子:属性值是非函数。若 子传父:属性值是函数。 父组件: <template><div class"father">&l…...
ESP32-S3模组上跑通esp32-camera(41)
接前一篇文章:ESP32-S3模组上跑通esp32-camera(40) 一、OV5640初始化 2. 相机初始化及图像传感器配置 上一回继续对reset函数的后一段代码进行解析。为了便于理解和回顾,再次贴出reset函数源码,在components\esp32-camera\sensors\ov5640.c中,如下: static int reset…...
本地部署DeepSeek R1:打造专属私人AI助手指南
在当今人工智能蓬勃发展的浪潮中,DeepSeek R1模型的本地部署为用户带来了全新的体验。它不仅能够保障数据隐私,还具备与商业AI模型相媲美的出色性能。随着计算能力的不断提升以及开源AI社区的日益壮大,用户如今可以在本地运行高性能AI模型&am…...
Redis-布隆过滤器
文章目录 布隆过滤器的特点:实践布隆过滤器应用 布隆过滤器的特点: 就可以把布隆过滤器理解为一个set集合,我们可以通过add往里面添加元素,通过contains来判断是否包含某个元素。 布隆过滤器是一个很长的二进制向量和一系列随机映射函数。 可以用来检索…...
OpenCV 版本不兼容导致的问题
问题和解决方案 今天运行如下代码,发生了意外的错误,代码如下,其中输入的 frame 来自于 OpenCV 开启数据流的读取 """ cap cv2.VideoCapture(RTSP_URL) print("链接视频流完成") while True:ret, frame cap.rea…...
【视频+图文详解】HTML基础3-html常用标签
图文教程 html常用标签 常用标签 1. 文档结构 <!DOCTYPE html>:声明HTML文档类型。<html>:定义HTML文档的根元素。<head>:定义文档头部,包含元数据。<title>:设置网页标题,浏览…...
【B站保姆级视频教程:Jetson配置YOLOv11环境(五)Miniconda安装与配置】
Jetson配置YOLOv11环境(5)Miniconda安装与配置 文章目录 0. Anaconda vs Miniconda in Jetson1. 下载Miniconda32. 安装Miniconda33. 换源3.1 conda 换源3.2 pip 换源 4. 创建环境5. 设置默认启动环境 0. Anaconda vs Miniconda in Jetson Jetson 设备资…...
【PLL】杂散生成和调制
时钟生成 --》 数字系统 --》峰值抖动频率生成 --》无线系统 --》 频谱纯度、 周期信号的相位不确定性 随机抖动(random jitter, RJ)确定性抖动(deterministic jitter,DJ) 时域频域随机抖动积分相位噪声确定性抖动边带 杂散生成和…...
游戏引擎 Unity - Unity 启动(下载 Unity Editor、生成 Unity Personal Edition 许可证)
Unity Unity 首次发布于 2005 年,属于 Unity Technologies Unity 使用的开发技术有:C# Unity 的适用平台:PC、主机、移动设备、VR / AR、Web 等 Unity 的适用领域:开发中等画质中小型项目 Unity 适合初学者或需要快速上手的开…...
侯捷 C++ 课程学习笔记:深入理解 C++ 核心技术与实战应用
目录 引言 第一章:C 基础回顾 1.1 C 的历史与发展 1.2 C 的核心特性 1.3 C 的编译与执行 第二章:面向对象编程 2.1 类与对象 2.2 构造函数与析构函数 2.3 继承与多态 第三章:泛型编程与模板 3.1 函数模板 3.2 类模板 3.3 STL 容器…...
Java的Integer缓存池
Java的Integer缓冲池? Integer 缓存池主要为了提升性能和节省内存。根据实践发现大部分的数据操作都集中在值比较小的范围,因此缓存这些对象可以减少内存分配和垃圾回收的负担,提升性能。 在-128到 127范围内的 Integer 对象会被缓存和复用…...
【C++动态规划 离散化】1626. 无矛盾的最佳球队|2027
本文涉及知识点 C动态规划 离散化 LeetCode1626. 无矛盾的最佳球队 假设你是球队的经理。对于即将到来的锦标赛,你想组合一支总体得分最高的球队。球队的得分是球队中所有球员的分数 总和 。 然而,球队中的矛盾会限制球员的发挥,所以必须选…...
python 判断复杂包含
目录 python 判断复杂包含 a和b都是拍好序的: python 判断复杂包含 a[10,13,15] b[[9,11],[11,13],[13,16]] b的子项是区间,返回b中子区间包含a其中元素的子项 if __name__ __main__:a [10, 11, 15]b [[9, 11], [11, 13], [13, 16]]# 筛选出包含…...
Teleporters( Educational Codeforces Round 126 (Rated for Div. 2) )
Teleporters( Educational Codeforces Round 126 (Rated for Div. 2) ) There are n 1 n1 n1 teleporters on a straight line, located in points 0 0 0, a 1 a_1 a1, a 2 a_2 a2, a 3 a_3 a3, …, a n a_n an. It’s possible to tele…...
css-设置元素的溢出行为为可见overflow: visible;
1.前言 overflow 属性用于设置当元素的内容溢出其框时如何处理。 2. overflow overflow 属性的一些常见值: 1 visible:默认值。内容不会被剪裁,会溢出元素的框。 2 hidden:内容会被剪裁,不会显示溢出的部分。 3 sc…...
python-leetcode-从中序与后序遍历序列构造二叉树
106. 从中序与后序遍历序列构造二叉树 - 力扣(LeetCode) # Definition for a binary tree node. # class TreeNode: # def __init__(self, val0, leftNone, rightNone): # self.val val # self.left left # self.right r…...
绝对值线性化
函数中的绝对值线性化有多种方法,包括我之前的一篇博文. 前几天在小红书刷到一个帖子,一位网友提供了另外一种巧妙的方式,记录如下。 假如有一个绝对值表达式: y ∣ a x − b ∣ (1) y|ax-b|\tag{1} y∣ax−b∣(1) 令&#x…...
Java实战:图像浏览器
文章目录 1. 实战概述2. 知识准备3. 实现步骤3.1 创建Java项目3.2 创建图像浏览器类3.2.1 声明变量与常量3.2.2 创建构造方法3.2.3 创建初始化界面方法3.2.4 创建处理事件方法3.2.5 创建主方法3.2.6 查看完整代码 3.3 运行程序,查看结果 4. 实战小结5. 扩展练习 1. …...
SARIMA介绍
SARIMA模型,即季节性自回归积分移动平均模型(Seasonal Autoregressive Integrated Moving Average Model),是一种用于处理和预测具有明显季节性变化的时间序列数据的统计模型。它是ARIMA模型的一种扩展,通过引入额外的…...
I.MX6ULL 中断介绍上
i.MX6ULL是NXP(原Freescale)推出的一款基于ARM Cortex-A7内核的微处理器,广泛应用于嵌入式系统。在i.MX6ULL中,中断(Interrupt)是一种重要的机制,用于处理外部或内部事件,允许微处理…...
Spring Boot WebMvcConfigurer:定制你的 Web 应用
在构建基于Spring Boot的Web应用程序时,WebMvcConfigurer接口扮演着至关重要的角色。它允许开发者以一种简洁且非侵入的方式自定义Spring MVC的功能,而无需直接扩展框架的核心组件。本文将深入探讨WebMvcConfigurer的作用、如何实现其方法以及在实际项目…...
(即插即用模块-特征处理部分) 十九、(NeurIPS 2023) Prompt Block 提示生成 / 交互模块
文章目录 1、Prompt Block2、代码实现 paper:PromptIR: Prompting for All-in-One Blind Image Restoration Code:https://github.com/va1shn9v/PromptIR 1、Prompt Block 在解决现有图像恢复模型时,现有研究存在一些局限性: 现有…...
