当前位置: 首页 > news >正文

TensorFlow深度学习实战(7)——分类任务详解

TensorFlow深度学习实战(7)——分类任务详解

    • 0. 前言
    • 1. 分类任务
      • 1.1 分类任务简介
      • 1.2 分类与回归的区别
    • 2. 逻辑回归
    • 3. 使用 TensorFlow 实现逻辑回归
    • 小结
    • 系列链接

0. 前言

分类任务 (Classification Task) 是机器学习中的一种监督学习问题,其目的是将输入数据(特征向量)映射到离散的类别标签。广泛应用于如文本分类、图像识别、垃圾邮件检测、医学诊断等多种领域。

1. 分类任务

1.1 分类任务简介

分类任务的目标是通过训练数据学习一个模型,使得对于新的输入数据能够预测其所属的类别。输入数据是模型的自变量,通常是特征向量 x = [ x 1 , x 2 , … , x n ] ) {x} = [x_1, x_2, \dots, x_n]) x=[x1,x2,,xn]),其中 n n n 是特征的维度,每个特征可能是连续值(如温度、年龄)或离散值(如颜色、性别)。输出是一个类别标签,表示每个输入数据点的所属类别,对于二分类任务,输出标签通常为 01;而对于多分类任务,标签的数量可以是多个类别,例如 0123 等。
根据类别的数量不同,可以将分类任务归为不同类型:

  • 二分类 (Binary Classification):输出只有两个类别,例如“是”与“否”
  • 多分类 (Multiclass Classification):输出包含多个类别标签,适用于每个样本属于多个可能类别中的一个的任务,例如“猫”、“狗”、“狮子”、“大象”等
  • 多标签分类 (Multilabel Classification):与传统的单一类别分类不同,每个样本可以同时属于多个类别

1.2 分类与回归的区别

回归和分类任务之间的区别:

  • 在分类任务中,数据被分成不同的类别,而在回归中,目标是根据给定的数据得到一个连续值。例如,识别手写数字的任务属于分类任务,所有的手写数字都属于 09 之间的某个数字;而根据不同的输入变量预测房屋价格则属于回归任务
  • 在分类任务中,模型的目标是找到分隔不同类别的决策边界;而在回归任务中,模型的目标是逼近一个适合输入输出关系的函数。

分类和回归任务的不同之处如下图所示。在分类中,我们需要找到分隔类别的线(或平面,或超平面)。在回归中,目标是找到一条(或一个平面,或一个超平面)拟合给定输入与输出关系的线。

分类与回归

2. 逻辑回归

逻辑回归 (Logistic regression) 用于确定事件发生的概率。通常,事件表示为分类的因变量。事件发生的概率使用 sigmoid (或logit )函数表示:
P ^ ( Y ^ = 1 ∣ X = x ) = 1 1 + e − ( b + W T x ) \hat P(\hat Y=1|X=x)=\frac 1{1+e^{-(b+W^Tx)}} P^(Y^=1∣X=x)=1+e(b+WTx)1
目标是估计权重 W = { w 1 , w 2 , . . . , w n } W=\{w_1,w_2,...,w_n\} W={w1,w2,...,wn} 和偏置项 b b b。在逻辑回归中,系数可以使用最大似然估计或随机梯度下降来估计。如果 p p p 是输入数据样本的总数,损失通常定义为交叉熵项:
l o s s = ∑ i = 1 p Y i l o g ( Y ^ i ) + ( 1 − Y i ) l o g ( 1 − Y ^ i ) loss=\sum_{i=1}^pY_ilog(\hat Y_i)+(1-Y_i)log(1-\hat Y_i) loss=i=1pYilog(Y^i)+(1Yi)log(1Y^i)
逻辑回归用于分类问题。例如,在分析医疗数据时,我们可以使用逻辑回归来分类一个人是否患有癌症。如果输出的分类变量具有两个或多个,可以使用多分类逻辑回归。对于多分类逻辑回归,交叉熵损失函数可以改写为:
l o s s = ∑ i = 1 p ∑ j = 1 k Y i j l o g Y ^ i j loss=\sum_{i=1}^p\sum_{j=1}^kY_{ij}log\hat Y_{ij} loss=i=1pj=1kYijlogY^ij
其中 k k k 是类别总数。了解了逻辑回归的原理后,接下来,将其应用于具体实践中。

3. 使用 TensorFlow 实现逻辑回归

接下来,使用 TensorFlow 实现逻辑回归,对 MNIST 手写数字进行分类。MNIST 数据集包含手写数字的图像,每个图像都有一个标签值(介于 09 之间)标注图像中的数字值。因此,属于多类别分类问题。
为了实现逻辑回归,构建一个仅包含一个全连接层的模型。输出中的每个类别由一个神经元表示,由于我们有 10 个类别,输出层的神经元数为 10。逻辑回归中使用的概率函数类似于 sigmoid 激活函数,因此,模型使用 sigmoid 激活。接下来,构建模型。

(1) 首先,导入所需库。由于全连接层接收的输入为一维数据,因此使用 Flatten 层,用于将 MNIST 数据集中的 28 x 28 二维输入图像调整为一个包含 784 个元素的一维数组:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow.keras as K
from tensorflow.keras.layers import Dense, Flatten

(2)tensorflow.keras 数据集中获取 MNIST 输入数据:

((train_data, train_labels),(test_data, test_labels)) = tf.keras.datasets.mnist.load_data()

(3) 对数据进行预处理。对图像进行归一化,MNIST 数据集的图像是灰度图像,每个像素的强度值介于 0255 之间,将其除以 255,使数值范围在 01 之间:

train_data = train_data/np.float32(255)
train_labels = train_labels.astype(np.int32)  
test_data = test_data/np.float32(255)
test_labels = test_labels.astype(np.int32)

(4) 定义模型,模型只有一个具有 10 个神经元的 Dense 层,输入大小为 784,从模型摘要的输出中可以看到,只有 Dense 层具有可训练的参数:

model = K.Sequential([# Dense(64,  activation='relu'),# Dense(32,  activation='relu'),Flatten(input_shape=(28, 28)),Dense(10, activation='sigmoid')
])
print(model.summary())

模型架构

(5) 因为标签是整数值,因此使用 SparseCategoricalCrossentropy 损失函数,设置 logits 参数为 True。选择 Adam 优化器,此外,定义准确率作为在训练过程中需要记录的指标。模型训练 50epochs,使用 80:20 的比例拆分训练-验证集:

model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
history = model.fit(x=train_data,y=train_labels, epochs=50, verbose=1, validation_split=0.2)

(6) 绘制损失曲线观察模型性能表现。可以看到随着 epoch 的增加,训练损失降低的同时,验证损失逐渐增加,因此模型出现过拟合,可以通过添加隐藏层来改善模型性能:

plt.plot(history.history['loss'], label='loss')
plt.plot(history.history['val_loss'], label='val_loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

训练过程监测

(7) 为了更好地理解结果,构建两个实用函数,用于可视化手写数字以及模型输出的 10 个神经元的概率:

predictions = model.predict(test_data)def plot_image(i, predictions_array, true_label, img):true_label, img = true_label[i], img[i]plt.grid(False)plt.xticks([])plt.yticks([])plt.imshow(img, cmap=plt.cm.binary)predicted_label = np.argmax(predictions_array)if predicted_label == true_label:color = 'blue'else:color = 'red'plt.xlabel("Pred {} Conf: {:2.0f}% True ({})".format(predicted_label,100*np.max(predictions_array),true_label),color=color)def plot_value_array(i, predictions_array, true_label):true_label = true_label[i]plt.grid(False)plt.xticks(range(10))plt.yticks([])thisplot = plt.bar(range(10), predictions_array, color="#777777")plt.ylim([0, 1])predicted_label = np.argmax(predictions_array)thisplot[predicted_label].set_color('red')thisplot[true_label].set_color('blue')

(8) 绘制预测结果:

i = 56
plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plot_image(i, predictions[i], test_labels, test_data)
plt.subplot(1,2,2)
plot_value_array(i, predictions[i],  test_labels)
plt.show()

左侧的图像是手写数字图像,图像下方显示了预测的标签、预测的置信度以及真实标签。右侧的图像显示了 10 个神经元输出的概率(逻辑输出),可以看到代表数字 4 的神经元具有最高的概率:

训练结果

(9) 为了保持逻辑回归的特性,以上代码仅使用了一个包含 sigmoid 激活函数的 Dense 层。为了获得更好的性能,可以添加 Dense 层并使用 softmax 作为最终的激活函数,以下模型在验证数据集上能够达到 97% 的准确率:

better_model = K.Sequential([Flatten(input_shape=(28, 28)),Dense(128,  activation='relu'),#Dense(64,  activation='relu'),Dense(10, activation='softmax')
])
better_model.summary()better_model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])history = better_model.fit(x=train_data,y=train_labels, epochs=10, verbose=1, validation_split=0.2)plt.plot(history.history['loss'], label='loss')
plt.plot(history.history['val_loss'], label='val_loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()predictions = better_model.predict(test_data)
i = 0
plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plot_image(i, predictions[i], test_labels, test_data)
plt.subplot(1,2,2)
plot_value_array(i, predictions[i],  test_labels)
plt.show()

模型预测

我们可以尝试添加更多隐藏层,或者修改隐藏层中神经元的数量,或者修改优化器,以更好地理解这些参数对模型性能的影响。

小结

分类任务是机器学习中最常见的任务之一,广泛应用于各个领域。成功的分类任务不仅需要选择合适的算法,还需要对数据进行深入的预处理和特征工程。在本节中,我们首先介绍了分类任务及其与回归任务的区别,然后介绍了用于分类任务的逻辑回归技术,并使用 TensorFlow 实现了逻辑回归模型。

系列链接

TensorFlow深度学习实战(1)——神经网络与模型训练过程详解
TensorFlow深度学习实战(2)——使用TensorFlow构建神经网络
TensorFlow深度学习实战(3)——深度学习中常用激活函数详解
TensorFlow深度学习实战(4)——正则化技术详解
TensorFlow深度学习实战(5)——神经网络性能优化技术详解
TensorFlow深度学习实战(6)——回归分析详解

相关文章:

TensorFlow深度学习实战(7)——分类任务详解

TensorFlow深度学习实战(7)——分类任务详解 0. 前言1. 分类任务1.1 分类任务简介1.2 分类与回归的区别 2. 逻辑回归3. 使用 TensorFlow 实现逻辑回归小结系列链接 0. 前言 分类任务 (Classification Task) 是机器学习中的一种监督学习问题,…...

动态规划问题——青蛙跳台阶案例分析

问题描述: 一只青蛙要跳上n级台阶,它每次可以跳 1级或者2级。问:青蛙有多少种不同的跳法可以跳完这些台阶? 举个例子: 假设台阶数 n 3 ,我们来看看青蛙有多少种跳法。 可能的跳法: 1. 跳1级…...

element-ui使用el-table,保留字段前的空白

项目名称项目编号1、XXXXX1111111111111111111 1.1 XXXXX11111111111111222222222 如上表格中&#xff0c;实现项目名称字段1.1前空白的效果。 从JAVA返回的数据带有空白&#xff0c;即数据库中插入的数据带有空白。 原先写法&#xff1a; <el-table><el-tabl…...

kamailio中路由模块汇总

功能模块描述请求路由 (request_route)主要处理进入的SIP请求&#xff0c;包含初步检查、NAT检测、CANCEL请求处理、重传处理等。处理通过REQINIT、NATDETECT、RELAY等子模块的调用。CANCEL处理对CANCEL请求进行处理&#xff0c;包括更新对话状态并检查事务。如果事务检查通过&…...

如何使用 DeepSeek 搭建本地知识库

使用 DeepSeek 搭建本地知识库可以帮助您高效管理和检索本地文档、数据或知识资源。以下是详细的步骤指南&#xff1a; 1. 准备工作 (1) 安装 DeepSeek 确保您的系统已安装 Python 3.8 或更高版本。使用 pip 安装 DeepSeek&#xff1a; bash pip install deepseek (2) 准备…...

网络HTTP详细讲解

学习目标 什么是HTTPHTTP的请求和响应常见的HTTP状态码HTTP的安全性 什么是HTTP&#xff1f;HTTP的请求和响应&#xff0c;常见的HTTP状态码&#xff0c;HTTP的安全性 什么是HTTP HTTP&#xff08;HyperText Transfer Protocol&#xff0c;超文本传输协议&#xff09;是一种用…...

《Origin画百图》之边际分布曲线图

《Origin画百图》第六集——边际分布曲线图 入门操作可看《30秒&#xff0c;带你入门Origin》 边际分布曲线图&#xff0c;其中包含散点图形&#xff0c;而在图的边际有着分布曲线图。在比较数据以查看多个变量之间是否存在关系时非常有用。 1.数据准备&#xff1a;为多列XY数…...

【Milvus】向量数据库pymilvus使用教程

以下是根据 Milvus 官方文档整理的详细 PyMilvus 使用教程&#xff0c;基于 Milvus 2.5.x 版本&#xff1a; PyMilvus 使用教程 目录 安装与环境准备连接 Milvus 服务数据模型基础概念创建集合&#xff08;Collection&#xff09;插入数据创建索引向量搜索删除操作完整示例注…...

React 生命周期函数详解

React 组件在其生命周期中有多个阶段&#xff0c;每个阶段都有特定的生命周期函数&#xff08;Lifecycle Methods&#xff09;。这些函数允许你在组件的不同阶段执行特定的操作。以下是 React 组件生命周期的主要阶段及其对应的生命周期函数&#xff0c;并结合了 React 16.3 的…...

第 26 场 蓝桥入门赛

2.对联【算法赛】 - 蓝桥云课 问题描述 大年三十&#xff0c;小蓝和爷爷一起贴对联。爷爷拿出了两副对联&#xff0c;每副对联都由 N 个“福”字组成&#xff0c;每个“福”字要么是正的&#xff08;用 1 表示&#xff09;&#xff0c;要么是倒的&#xff08;用 0 表示&#…...

组合(力扣77)

从这道题开始&#xff0c;我们正式进入回溯算法的学习。之前在二叉树中只是接触到了一丢丢&#xff0c;而这里我们将使用回溯算法解决很多经典问题。 那么这道题是如何使用回溯算法的呢&#xff1f;在讲回溯之前&#xff0c;先说明一下此题是如何递归的。毕竟回溯递归不分家&a…...

网络工程师 (22)网络协议

前言 网络协议是计算机网络中进行数据交换而建立的规则、标准或约定的集合&#xff0c;它规定了通信时信息必须采用的格式和这些格式的意义。 一、基本要素 语法&#xff1a;规定信息格式&#xff0c;包括数据及控制信息的格式、编码及信号电平等。这是协议的基础&#xff0c;确…...

Linux之文件IO前世今生

在 Linux之文件系统前世今生&#xff08;一&#xff09; VFS中&#xff0c;我们提到了文件的读写&#xff0c;并给出了简要的读写示意图&#xff0c;本文将分析文件I/O的细节。 一、Buffered I/O&#xff08;缓存I/O&#xff09;& Directed I/O&#xff08;直接I/O&#…...

如何在Windows中配置MySQL?

MySQL是一个广泛使用的开源关系型数据库管理系统&#xff0c;它支持多种操作系统平台&#xff0c;其中包括Windows。无论是开发者进行本地开发&#xff0c;还是管理员为应用程序配置数据库&#xff0c;MySQL都是一个非常流行的选择。本篇文章将详细介绍如何在Windows操作系统中…...

Kafka 入门与实战

一、Kafka 基础 1.1 创建topic kafka-topics.bat --bootstrap-server localhost:9092 --topic test --create 1.2 查看消费者偏移量位置 kafka-consumer-groups.bat --bootstrap-server localhost:9092 --describe --group test 1.3 消息的生产与发送 #生产者 kafka-cons…...

数学知识学习1

1、数论 1质数判定 i<n/i优化O(sqrt(n)) bool is_prime(int n){if(n<2)return false;for(int i2;i<n/i;i){if(n%i0)return false;} true; } 分解质因数 i<n/i优化O(sqrt(n)) // 定义一个函数 divide&#xff0c;接收一个整数 n 作为参数&#xff0c;用于分解质…...

【AI日记】25.02.08

【AI论文解读】【AI知识点】【AI小项目】【AI战略思考】【AI日记】【读书与思考】【AI应用】 探索 AI 应用探索周二有个面试&#xff0c;明后天打算好好准备一下&#xff0c;我打算主要研究下 AI 如何在该行业赋能和应用&#xff0c;以及该行业未来的发展前景和公司痛点&#…...

Lecture8 | LPV VXGI SSAO SSDO

Review: Lecture 7 | Lecture 8 LPV (Light Propagation Volumes) Light Propagation Volumes(LPV)-孤岛惊魂CryEngine引进的技术 LPV做GI快|好 大体步骤&#xff1a; Step1.Generation of Radiance Point Set Scene Representation 生成辐射点集的场景表示&#xff1a;辐射…...

Java中实现定时锁屏的功能(可以指定时间执行)

Java中实现定时锁屏的功能&#xff08;可以指定时间执行&#xff09; 要在Java中实现定时锁屏的功能&#xff0c;可以使用java.util.Timer或java.util.concurrent.ScheduledExecutorService来调度任务&#xff0c;并通过调用操作系统的命令来执行锁屏。下面我将给出一个基本的…...

Java集合List详解(带脑图)

允许重复元素&#xff0c;有序。常见的实现类有 ArrayList、LinkedList、Vector。 ArrayList ArrayList 是在 Java 编程中常用的集合类之一&#xff0c;它提供了便捷的数组操作&#xff0c;并在动态性、灵活性和性能方面取得了平衡。如果需要频繁在中间插入和删除元素&#xf…...

Zustand 状态管理库:极简而强大的解决方案

Zustand 是一个轻量级、快速和可扩展的状态管理库&#xff0c;特别适合 React 应用。它以简洁的 API 和高效的性能解决了 Redux 等状态管理方案中的繁琐问题。 核心优势对比 基本使用指南 1. 创建 Store // store.js import create from zustandconst useStore create((set)…...

FFmpeg 低延迟同屏方案

引言 在实时互动需求激增的当下&#xff0c;无论是在线教育中的师生同屏演示、远程办公的屏幕共享协作&#xff0c;还是游戏直播的画面实时传输&#xff0c;低延迟同屏已成为保障用户体验的核心指标。FFmpeg 作为一款功能强大的多媒体框架&#xff0c;凭借其灵活的编解码、数据…...

全球首个30米分辨率湿地数据集(2000—2022)

数据简介 今天我们分享的数据是全球30米分辨率湿地数据集&#xff0c;包含8种湿地亚类&#xff0c;该数据以0.5X0.5的瓦片存储&#xff0c;我们整理了所有属于中国的瓦片名称与其对应省份&#xff0c;方便大家研究使用。 该数据集作为全球首个30米分辨率、覆盖2000–2022年时间…...

令牌桶 滑动窗口->限流 分布式信号量->限并发的原理 lua脚本分析介绍

文章目录 前言限流限制并发的实际理解限流令牌桶代码实现结果分析令牌桶lua的模拟实现原理总结&#xff1a; 滑动窗口代码实现结果分析lua脚本原理解析 限并发分布式信号量代码实现结果分析lua脚本实现原理 双注解去实现限流 并发结果分析&#xff1a; 实际业务去理解体会统一注…...

ElasticSearch搜索引擎之倒排索引及其底层算法

文章目录 一、搜索引擎1、什么是搜索引擎?2、搜索引擎的分类3、常用的搜索引擎4、搜索引擎的特点二、倒排索引1、简介2、为什么倒排索引不用B+树1.创建时间长,文件大。2.其次,树深,IO次数可怕。3.索引可能会失效。4.精准度差。三. 倒排索引四、算法1、Term Index的算法2、 …...

C# SqlSugar:依赖注入与仓储模式实践

C# SqlSugar&#xff1a;依赖注入与仓储模式实践 在 C# 的应用开发中&#xff0c;数据库操作是必不可少的环节。为了让数据访问层更加简洁、高效且易于维护&#xff0c;许多开发者会选择成熟的 ORM&#xff08;对象关系映射&#xff09;框架&#xff0c;SqlSugar 就是其中备受…...

成都鼎讯硬核科技!雷达目标与干扰模拟器,以卓越性能制胜电磁频谱战

在现代战争中&#xff0c;电磁频谱已成为继陆、海、空、天之后的 “第五维战场”&#xff0c;雷达作为电磁频谱领域的关键装备&#xff0c;其干扰与抗干扰能力的较量&#xff0c;直接影响着战争的胜负走向。由成都鼎讯科技匠心打造的雷达目标与干扰模拟器&#xff0c;凭借数字射…...

佰力博科技与您探讨热释电测量的几种方法

热释电的测量主要涉及热释电系数的测定&#xff0c;这是表征热释电材料性能的重要参数。热释电系数的测量方法主要包括静态法、动态法和积分电荷法。其中&#xff0c;积分电荷法最为常用&#xff0c;其原理是通过测量在电容器上积累的热释电电荷&#xff0c;从而确定热释电系数…...

Git 3天2K星标:Datawhale 的 Happy-LLM 项目介绍(附教程)

引言 在人工智能飞速发展的今天&#xff0c;大语言模型&#xff08;Large Language Models, LLMs&#xff09;已成为技术领域的焦点。从智能写作到代码生成&#xff0c;LLM 的应用场景不断扩展&#xff0c;深刻改变了我们的工作和生活方式。然而&#xff0c;理解这些模型的内部…...

脑机新手指南(七):OpenBCI_GUI:从环境搭建到数据可视化(上)

一、OpenBCI_GUI 项目概述 &#xff08;一&#xff09;项目背景与目标 OpenBCI 是一个开源的脑电信号采集硬件平台&#xff0c;其配套的 OpenBCI_GUI 则是专为该硬件设计的图形化界面工具。对于研究人员、开发者和学生而言&#xff0c;首次接触 OpenBCI 设备时&#xff0c;往…...