TensorFlow深度学习实战(7)——分类任务详解
TensorFlow深度学习实战(7)——分类任务详解
- 0. 前言
- 1. 分类任务
- 1.1 分类任务简介
- 1.2 分类与回归的区别
- 2. 逻辑回归
- 3. 使用 TensorFlow 实现逻辑回归
- 小结
- 系列链接
0. 前言
分类任务 (Classification Task
) 是机器学习中的一种监督学习问题,其目的是将输入数据(特征向量)映射到离散的类别标签。广泛应用于如文本分类、图像识别、垃圾邮件检测、医学诊断等多种领域。
1. 分类任务
1.1 分类任务简介
分类任务的目标是通过训练数据学习一个模型,使得对于新的输入数据能够预测其所属的类别。输入数据是模型的自变量,通常是特征向量 x = [ x 1 , x 2 , … , x n ] ) {x} = [x_1, x_2, \dots, x_n]) x=[x1,x2,…,xn]),其中 n n n 是特征的维度,每个特征可能是连续值(如温度、年龄)或离散值(如颜色、性别)。输出是一个类别标签,表示每个输入数据点的所属类别,对于二分类任务,输出标签通常为 0
或 1
;而对于多分类任务,标签的数量可以是多个类别,例如 0
、1
、2
、3
等。
根据类别的数量不同,可以将分类任务归为不同类型:
- 二分类 (
Binary Classification
):输出只有两个类别,例如“是”与“否” - 多分类 (
Multiclass Classification
):输出包含多个类别标签,适用于每个样本属于多个可能类别中的一个的任务,例如“猫”、“狗”、“狮子”、“大象”等 - 多标签分类 (
Multilabel Classification
):与传统的单一类别分类不同,每个样本可以同时属于多个类别
1.2 分类与回归的区别
回归和分类任务之间的区别:
- 在分类任务中,数据被分成不同的类别,而在回归中,目标是根据给定的数据得到一个连续值。例如,识别手写数字的任务属于分类任务,所有的手写数字都属于
0
到9
之间的某个数字;而根据不同的输入变量预测房屋价格则属于回归任务 - 在分类任务中,模型的目标是找到分隔不同类别的决策边界;而在回归任务中,模型的目标是逼近一个适合输入输出关系的函数。
分类和回归任务的不同之处如下图所示。在分类中,我们需要找到分隔类别的线(或平面,或超平面)。在回归中,目标是找到一条(或一个平面,或一个超平面)拟合给定输入与输出关系的线。
2. 逻辑回归
逻辑回归 (Logistic regression
) 用于确定事件发生的概率。通常,事件表示为分类的因变量。事件发生的概率使用 sigmoid
(或logit
)函数表示:
P ^ ( Y ^ = 1 ∣ X = x ) = 1 1 + e − ( b + W T x ) \hat P(\hat Y=1|X=x)=\frac 1{1+e^{-(b+W^Tx)}} P^(Y^=1∣X=x)=1+e−(b+WTx)1
目标是估计权重 W = { w 1 , w 2 , . . . , w n } W=\{w_1,w_2,...,w_n\} W={w1,w2,...,wn} 和偏置项 b b b。在逻辑回归中,系数可以使用最大似然估计或随机梯度下降来估计。如果 p p p 是输入数据样本的总数,损失通常定义为交叉熵项:
l o s s = ∑ i = 1 p Y i l o g ( Y ^ i ) + ( 1 − Y i ) l o g ( 1 − Y ^ i ) loss=\sum_{i=1}^pY_ilog(\hat Y_i)+(1-Y_i)log(1-\hat Y_i) loss=i=1∑pYilog(Y^i)+(1−Yi)log(1−Y^i)
逻辑回归用于分类问题。例如,在分析医疗数据时,我们可以使用逻辑回归来分类一个人是否患有癌症。如果输出的分类变量具有两个或多个,可以使用多分类逻辑回归。对于多分类逻辑回归,交叉熵损失函数可以改写为:
l o s s = ∑ i = 1 p ∑ j = 1 k Y i j l o g Y ^ i j loss=\sum_{i=1}^p\sum_{j=1}^kY_{ij}log\hat Y_{ij} loss=i=1∑pj=1∑kYijlogY^ij
其中 k k k 是类别总数。了解了逻辑回归的原理后,接下来,将其应用于具体实践中。
3. 使用 TensorFlow 实现逻辑回归
接下来,使用 TensorFlow
实现逻辑回归,对 MNIST
手写数字进行分类。MNIST
数据集包含手写数字的图像,每个图像都有一个标签值(介于 0
到 9
之间)标注图像中的数字值。因此,属于多类别分类问题。
为了实现逻辑回归,构建一个仅包含一个全连接层的模型。输出中的每个类别由一个神经元表示,由于我们有 10
个类别,输出层的神经元数为 10
。逻辑回归中使用的概率函数类似于 sigmoid
激活函数,因此,模型使用 sigmoid
激活。接下来,构建模型。
(1) 首先,导入所需库。由于全连接层接收的输入为一维数据,因此使用 Flatten
层,用于将 MNIST
数据集中的 28 x 28
二维输入图像调整为一个包含 784
个元素的一维数组:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow.keras as K
from tensorflow.keras.layers import Dense, Flatten
(2) 从 tensorflow.keras
数据集中获取 MNIST
输入数据:
((train_data, train_labels),(test_data, test_labels)) = tf.keras.datasets.mnist.load_data()
(3) 对数据进行预处理。对图像进行归一化,MNIST
数据集的图像是灰度图像,每个像素的强度值介于 0
到 255
之间,将其除以 255
,使数值范围在 0
到 1
之间:
train_data = train_data/np.float32(255)
train_labels = train_labels.astype(np.int32)
test_data = test_data/np.float32(255)
test_labels = test_labels.astype(np.int32)
(4) 定义模型,模型只有一个具有 10
个神经元的 Dense
层,输入大小为 784
,从模型摘要的输出中可以看到,只有 Dense
层具有可训练的参数:
model = K.Sequential([# Dense(64, activation='relu'),# Dense(32, activation='relu'),Flatten(input_shape=(28, 28)),Dense(10, activation='sigmoid')
])
print(model.summary())
(5) 因为标签是整数值,因此使用 SparseCategoricalCrossentropy
损失函数,设置 logits
参数为 True
。选择 Adam
优化器,此外,定义准确率作为在训练过程中需要记录的指标。模型训练 50
个 epochs
,使用 80:20
的比例拆分训练-验证集:
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
history = model.fit(x=train_data,y=train_labels, epochs=50, verbose=1, validation_split=0.2)
(6) 绘制损失曲线观察模型性能表现。可以看到随着 epoch
的增加,训练损失降低的同时,验证损失逐渐增加,因此模型出现过拟合,可以通过添加隐藏层来改善模型性能:
plt.plot(history.history['loss'], label='loss')
plt.plot(history.history['val_loss'], label='val_loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()
(7) 为了更好地理解结果,构建两个实用函数,用于可视化手写数字以及模型输出的 10
个神经元的概率:
predictions = model.predict(test_data)def plot_image(i, predictions_array, true_label, img):true_label, img = true_label[i], img[i]plt.grid(False)plt.xticks([])plt.yticks([])plt.imshow(img, cmap=plt.cm.binary)predicted_label = np.argmax(predictions_array)if predicted_label == true_label:color = 'blue'else:color = 'red'plt.xlabel("Pred {} Conf: {:2.0f}% True ({})".format(predicted_label,100*np.max(predictions_array),true_label),color=color)def plot_value_array(i, predictions_array, true_label):true_label = true_label[i]plt.grid(False)plt.xticks(range(10))plt.yticks([])thisplot = plt.bar(range(10), predictions_array, color="#777777")plt.ylim([0, 1])predicted_label = np.argmax(predictions_array)thisplot[predicted_label].set_color('red')thisplot[true_label].set_color('blue')
(8) 绘制预测结果:
i = 56
plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plot_image(i, predictions[i], test_labels, test_data)
plt.subplot(1,2,2)
plot_value_array(i, predictions[i], test_labels)
plt.show()
左侧的图像是手写数字图像,图像下方显示了预测的标签、预测的置信度以及真实标签。右侧的图像显示了 10
个神经元输出的概率(逻辑输出),可以看到代表数字 4
的神经元具有最高的概率:
(9) 为了保持逻辑回归的特性,以上代码仅使用了一个包含 sigmoid
激活函数的 Dense
层。为了获得更好的性能,可以添加 Dense
层并使用 softmax
作为最终的激活函数,以下模型在验证数据集上能够达到 97%
的准确率:
better_model = K.Sequential([Flatten(input_shape=(28, 28)),Dense(128, activation='relu'),#Dense(64, activation='relu'),Dense(10, activation='softmax')
])
better_model.summary()better_model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])history = better_model.fit(x=train_data,y=train_labels, epochs=10, verbose=1, validation_split=0.2)plt.plot(history.history['loss'], label='loss')
plt.plot(history.history['val_loss'], label='val_loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()predictions = better_model.predict(test_data)
i = 0
plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plot_image(i, predictions[i], test_labels, test_data)
plt.subplot(1,2,2)
plot_value_array(i, predictions[i], test_labels)
plt.show()
我们可以尝试添加更多隐藏层,或者修改隐藏层中神经元的数量,或者修改优化器,以更好地理解这些参数对模型性能的影响。
小结
分类任务是机器学习中最常见的任务之一,广泛应用于各个领域。成功的分类任务不仅需要选择合适的算法,还需要对数据进行深入的预处理和特征工程。在本节中,我们首先介绍了分类任务及其与回归任务的区别,然后介绍了用于分类任务的逻辑回归技术,并使用 TensorFlow
实现了逻辑回归模型。
系列链接
TensorFlow深度学习实战(1)——神经网络与模型训练过程详解
TensorFlow深度学习实战(2)——使用TensorFlow构建神经网络
TensorFlow深度学习实战(3)——深度学习中常用激活函数详解
TensorFlow深度学习实战(4)——正则化技术详解
TensorFlow深度学习实战(5)——神经网络性能优化技术详解
TensorFlow深度学习实战(6)——回归分析详解
相关文章:

TensorFlow深度学习实战(7)——分类任务详解
TensorFlow深度学习实战(7)——分类任务详解 0. 前言1. 分类任务1.1 分类任务简介1.2 分类与回归的区别 2. 逻辑回归3. 使用 TensorFlow 实现逻辑回归小结系列链接 0. 前言 分类任务 (Classification Task) 是机器学习中的一种监督学习问题,…...

动态规划问题——青蛙跳台阶案例分析
问题描述: 一只青蛙要跳上n级台阶,它每次可以跳 1级或者2级。问:青蛙有多少种不同的跳法可以跳完这些台阶? 举个例子: 假设台阶数 n 3 ,我们来看看青蛙有多少种跳法。 可能的跳法: 1. 跳1级…...

element-ui使用el-table,保留字段前的空白
项目名称项目编号1、XXXXX1111111111111111111 1.1 XXXXX11111111111111222222222 如上表格中,实现项目名称字段1.1前空白的效果。 从JAVA返回的数据带有空白,即数据库中插入的数据带有空白。 原先写法: <el-table><el-tabl…...

kamailio中路由模块汇总
功能模块描述请求路由 (request_route)主要处理进入的SIP请求,包含初步检查、NAT检测、CANCEL请求处理、重传处理等。处理通过REQINIT、NATDETECT、RELAY等子模块的调用。CANCEL处理对CANCEL请求进行处理,包括更新对话状态并检查事务。如果事务检查通过&…...

如何使用 DeepSeek 搭建本地知识库
使用 DeepSeek 搭建本地知识库可以帮助您高效管理和检索本地文档、数据或知识资源。以下是详细的步骤指南: 1. 准备工作 (1) 安装 DeepSeek 确保您的系统已安装 Python 3.8 或更高版本。使用 pip 安装 DeepSeek: bash pip install deepseek (2) 准备…...

网络HTTP详细讲解
学习目标 什么是HTTPHTTP的请求和响应常见的HTTP状态码HTTP的安全性 什么是HTTP?HTTP的请求和响应,常见的HTTP状态码,HTTP的安全性 什么是HTTP HTTP(HyperText Transfer Protocol,超文本传输协议)是一种用…...

《Origin画百图》之边际分布曲线图
《Origin画百图》第六集——边际分布曲线图 入门操作可看《30秒,带你入门Origin》 边际分布曲线图,其中包含散点图形,而在图的边际有着分布曲线图。在比较数据以查看多个变量之间是否存在关系时非常有用。 1.数据准备:为多列XY数…...

【Milvus】向量数据库pymilvus使用教程
以下是根据 Milvus 官方文档整理的详细 PyMilvus 使用教程,基于 Milvus 2.5.x 版本: PyMilvus 使用教程 目录 安装与环境准备连接 Milvus 服务数据模型基础概念创建集合(Collection)插入数据创建索引向量搜索删除操作完整示例注…...

React 生命周期函数详解
React 组件在其生命周期中有多个阶段,每个阶段都有特定的生命周期函数(Lifecycle Methods)。这些函数允许你在组件的不同阶段执行特定的操作。以下是 React 组件生命周期的主要阶段及其对应的生命周期函数,并结合了 React 16.3 的…...

第 26 场 蓝桥入门赛
2.对联【算法赛】 - 蓝桥云课 问题描述 大年三十,小蓝和爷爷一起贴对联。爷爷拿出了两副对联,每副对联都由 N 个“福”字组成,每个“福”字要么是正的(用 1 表示),要么是倒的(用 0 表示&#…...

组合(力扣77)
从这道题开始,我们正式进入回溯算法的学习。之前在二叉树中只是接触到了一丢丢,而这里我们将使用回溯算法解决很多经典问题。 那么这道题是如何使用回溯算法的呢?在讲回溯之前,先说明一下此题是如何递归的。毕竟回溯递归不分家&a…...

网络工程师 (22)网络协议
前言 网络协议是计算机网络中进行数据交换而建立的规则、标准或约定的集合,它规定了通信时信息必须采用的格式和这些格式的意义。 一、基本要素 语法:规定信息格式,包括数据及控制信息的格式、编码及信号电平等。这是协议的基础,确…...

Linux之文件IO前世今生
在 Linux之文件系统前世今生(一) VFS中,我们提到了文件的读写,并给出了简要的读写示意图,本文将分析文件I/O的细节。 一、Buffered I/O(缓存I/O)& Directed I/O(直接I/O&#…...

如何在Windows中配置MySQL?
MySQL是一个广泛使用的开源关系型数据库管理系统,它支持多种操作系统平台,其中包括Windows。无论是开发者进行本地开发,还是管理员为应用程序配置数据库,MySQL都是一个非常流行的选择。本篇文章将详细介绍如何在Windows操作系统中…...

Kafka 入门与实战
一、Kafka 基础 1.1 创建topic kafka-topics.bat --bootstrap-server localhost:9092 --topic test --create 1.2 查看消费者偏移量位置 kafka-consumer-groups.bat --bootstrap-server localhost:9092 --describe --group test 1.3 消息的生产与发送 #生产者 kafka-cons…...

数学知识学习1
1、数论 1质数判定 i<n/i优化O(sqrt(n)) bool is_prime(int n){if(n<2)return false;for(int i2;i<n/i;i){if(n%i0)return false;} true; } 分解质因数 i<n/i优化O(sqrt(n)) // 定义一个函数 divide,接收一个整数 n 作为参数,用于分解质…...

【AI日记】25.02.08
【AI论文解读】【AI知识点】【AI小项目】【AI战略思考】【AI日记】【读书与思考】【AI应用】 探索 AI 应用探索周二有个面试,明后天打算好好准备一下,我打算主要研究下 AI 如何在该行业赋能和应用,以及该行业未来的发展前景和公司痛点&#…...

Lecture8 | LPV VXGI SSAO SSDO
Review: Lecture 7 | Lecture 8 LPV (Light Propagation Volumes) Light Propagation Volumes(LPV)-孤岛惊魂CryEngine引进的技术 LPV做GI快|好 大体步骤: Step1.Generation of Radiance Point Set Scene Representation 生成辐射点集的场景表示:辐射…...

Java中实现定时锁屏的功能(可以指定时间执行)
Java中实现定时锁屏的功能(可以指定时间执行) 要在Java中实现定时锁屏的功能,可以使用java.util.Timer或java.util.concurrent.ScheduledExecutorService来调度任务,并通过调用操作系统的命令来执行锁屏。下面我将给出一个基本的…...

Java集合List详解(带脑图)
允许重复元素,有序。常见的实现类有 ArrayList、LinkedList、Vector。 ArrayList ArrayList 是在 Java 编程中常用的集合类之一,它提供了便捷的数组操作,并在动态性、灵活性和性能方面取得了平衡。如果需要频繁在中间插入和删除元素…...

[实验日志] VS Code 连接服务器上的 Python 解释器进行远程调试
目录 0. 前言 1. 环境 2. 准备工作 2.1 安装VS Code 2.2 安装插件 2.3 配置远程服务器 2.4 修改设置 2.5 打开远程调试窗口 3. 调试代码 3.1 输密码 3.2 打开服务器文件夹 3.3 配置Python环境 3.4 调试Python代码 补充:使用调试控制台,查看…...

(14)gdb 笔记(7):以日志记录的方式来调试多进程多线程程序,linux 命令 tail -f 实时跟踪日志
(44)以日志记录的方式来调试多进程多线程程序 : 这是老师的日志文件,可以用来模仿的模板: (45)实时追踪日志的 tail -f 命令: (46) 多种调试方法结合起来用 …...

Sentinel的安装和做限流的使用
一、安装 Release v1.8.3 alibaba/Sentinel GitHubA powerful flow control component enabling reliability, resilience and monitoring for microservices. (面向云原生微服务的高可用流控防护组件) - Release v1.8.3 alibaba/Sentinelhttps://github.com/alibaba/Senti…...

四柱预测学
图表 后天八卦 十二地支不仅代表了时间,还代表了方位。具体来说: 子:代表正北方丑寅:合起来代表东北方卯:代表正东方辰巳:合起来代表东南方午:代表正南方未申:合起来代表西南方酉:代表正西方戌亥:合起来代表西北方四季-五行-六神…...

【个人开发】macbook m1 Lora微调qwen大模型
本项目参考网上各类教程整理而成,为个人学习记录。 项目github源码地址:Lora微调大模型 项目中微调模型为:qwen/Qwen1.5-4B-Chat。 去年新发布的Qwen/Qwen2.5-3B-Instruct同样也适用。 微调步骤 step0: 环境准备 conda create --name fin…...

sqli-labs靶场实录(二): Advanced Injections
sqli-labs靶场实录: Advanced Injections Less21Less22Less23探测注入点 Less24Less25联合注入使用符号替代 Less25aLess26逻辑符号绕过and/or过滤双写and/or绕过 Less26aLess27Less27aLess28Less28aLess29Less30Less31Less32(宽字节注入)Less33Less34Le…...

Linux系统 环境变量
环境变量 写在前面概念查看环境变量main函数的参数argc & argvenv bash环境变量 写在前面 对于环境变量,本篇主要介绍基本概念及三四个环境变量 —— PATH、HOME、PWD。其中 PATH 作为 “ 敲门砖 ”,我们会更详细讲解;理解环境变量的全局…...

机器学习-线性回归(最大似然估计)
机器学习任务可以分为两类: 一类是样本的特征向量 𝒙 和标签 𝑦 之间存在未知的函数关系𝑦 h(𝒙),另一类是条件概率𝑝(𝑦|𝒙)服从某个未知分布。最小二乘法是属于第一类,…...

【信息系统项目管理师-案例真题】2017上半年案例分析答案和详解
更多内容请见: 备考信息系统项目管理师-专栏介绍和目录 文章目录 试题一【问题1】8 分【问题2】4 分【问题3】8 分【问题4】5 分试题二【问题1】10 分【问题2】8 分【问题3】6 分【问题4】5 分试题三【问题1】5 分【问题2】7 分【问题3】6 分【问题4】3 分试题一 阅读下列说明…...

CSP晋级组比赛生成文件夹与文件通用代码Python
快速生成文件夹与文件的脚本 import sys import osmyfiles sys.argv[1::] for f in myfiles:os.mkdir(f)os.system(f"touch {f}/{f}.in")os.system(f"touch {f}/{f}.out")os.system(f"touch {f}/{f}.cpp")with open("template.cpp",…...