当前位置: 首页 > news >正文

【深度学习_TensorFlow】过拟合

写在前面

过拟合与欠拟合


欠拟合: 是指在模型学习能力较弱,而数据复杂度较高的情况下,模型无法学习到数据集中的“一般规律”,因而导致泛化能力弱。此时,算法在训练集上表现一般,但在测试集上表现较差,泛化性能不佳。

过拟合: 是指模型在训练数据上表现很好,但在测试数据上表现不佳。这是由于模型过于复杂,记住了训练数据中的噪声和模式,而没有学到一般规则。本文将探讨过拟合问题以及其解决方法。

过拟合的解决方法


方法描述
增加训练数据
(More data)通过增加训练数据量,简单粗暴最有效,可以减少过拟合现象。
正则化
(regularization)在损失函数中添加一项,以惩罚模型的复杂度。常用的正则化方法包括L1正则化、L2正则化和dropout。
早停法
(Early stopping)在训练过程中,每次迭代后都会评估模型在验证集上的性能。如果性能在连续若干次迭代中没有提高,就停止训练。
数据增强
(Data augmentation)通过改变原有数据减少过拟合。例如,可以通过旋转、缩放等方式对图像数据进行增强。

写在中间

可以看到随着网络层数增加,模型变得复杂,过拟合现象变得愈发严重。接下来,我们将介绍一系列方法来帮助检测并抑制过拟合现象。

在这里插入图片描述

1. 交叉验证

增加数据集是最有效的方法,但是代价往往是昂贵的,所以要充分利用好现有的数据集。前面我们介绍了数据集需要划分为训练集和测试集,但我们为了挑选模型超参数和检测过拟合线性,一般需要将原来的训练集再次切分为新的训练集和验证集(validation set)。最终数据集被切分为 训练集、验证集、测试集。这三部分数据集的功能如下:

类别描述
训练集用于训练模型的参数,通过学习训练数据集来进行模型训练
验证集用于评估训练过程中的模型表现,调整模型的超参数
测试集用于评估最终训练好的模型在真实数据上的表现,测试验证模型的性能,评估模型的预测能力和泛化能力

验证集和测试集的区分

验证集使命:根据验证集的表现来调整模型的各种超参数的设置,提升模型的泛化能力。

测试集使命:就是检验模型的能力,其表现不能用来反馈模型的调整(就如你不能拿着期末考试原题来练习,否则期末高分就不能体现出你平时学习的真实状况),我们的办法就是从平常的练习题中抽取几道题组成验证集来检验你的能力。

这是一个将mnist手写数字识别测试集切分的例子,将6万张图像的前5万张划分为训练集,后1万张划分为验证集

(x, y), (x_test, y_test) = datasets.mnist.load_data()# 60k训练集切分为 50k训练集和 10k验证集
# (x_train, y_train), (x_val, y_val)
x_train, x_val = tf.split(x, num_or_size_splits=[50000, 10000])
y_train, y_val = tf.split(y, num_or_size_splits=[50000, 10000])# 训练集
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.map(preprocess).shuffle(10000).batch(128)# 验证集
val_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_db = val_db.map(preprocess).shuffle(10000).batch(128)# 测试集
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).shuffle(10000).batch(128)

但是这样切分还是有局限性的,训练集只能在前5万张图像中出现,验证集只能在后1万张图像中出现,是否有方法让验证集能使用前5万张的图像呢?还真有,这就是标题所提及的**交叉验证:**我们将训练集的6万张图像划成n份,每训练取其中的 n - 1 份作为训练集,取 1 份作为验证集,而非固定前5万张为训练集,后 1 万张为验证集。

# 创建一个范围为60000的索引数组
idx = tf.range(60000)
# 随机打乱索引数组
idx = tf.random.shuffle(idx)
# 使用索引数组从x和y中获取训练集的样本
# 训练集的样本数量为50000
x_train, y_train = tf.gather(x, idx[:50000]), tf.gather(y, idx[:50000])
# 使用索引数组从x和y中获取验证集的样本
# 验证集的样本数量为10000
x_val, y_val = tf.gather(x, idx[-10000:]), tf.gather(y, idx[-10000:])

也可以不用手动划分,在训练函数中增加参数,即可自动化操作

# 使用60k训练集对网络进行训练,设置训练周期数为6
# 设置验证集占比为0.1,即数据集中10%的数据将作为验证集
# 设置每个2个训练周期进行一次验证
network.fit(train_db, epochs=6, validation_split=0.1, validation_freq=2)

2. 正则化

之所以出现过拟合的现象,是因为模型太过复杂,这里通过限制网络参数的稀疏性来约束网络的实际容量,这种约束一般通过在损失函数上添加额外的参数稀疏性惩罚项实现,常用的正则化的方式有L0、L1、L2正则化,dropout正则化。


简单介绍

关于正则化的数学原理,本人也搞不明白,也就不滥竽充数了。

有关实现简单概括就是在损失函数中引入模型权重参数的L范数,使学习到的权重参数稀疏化。

正则化的方式可以手动实现,也可以调用API实现:其中手动实现主要在计算loss值的后面,调用API主要在创建层的时候。

import tensorflow as tf
from tensorflow.keras import layers, regularizers# 网络构建
# 这会在模型损失函数中加入权重参数的L2范数作为惩罚项,力度由0.001控制。
network = Sequential([layers.Dense(256, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dense(128, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dense(64, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dense(32, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dense(10)])
# 参数构建
network.build(input_shape=(None, 28*28))
# 模型展示
network.summary()
# 截取手动前向计算的代码
for step, (x, y) in enumerate(train_db):# 创建一个 GradientTape,用于记录计算过程with tf.GradientTape() as tape:x = tf.reshape(x, (-1, 28*28))  # [b, 28, 28] => [b, 784]out = network(x)  # [b, 784] => [b, 10]y_onehot = tf.one_hot(y, depth=10)  # [b] => [b, 10]# 使用交叉熵损失函数计算 lossloss = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot, out, from_logits=True))loss_regularization = []for p in network.trainable_variables:  # 重点在这里,遍历网络中所有的可训练参数(network.trainable_variables)。loss_regularization.append(tf.nn.l2_loss(p))  # # 对每个参数计算L2正则化项(tf.nn.l2_loss(p)),这会返回一个标量。loss_regularization = tf.reduce_sum(tf.stack(loss_regularization))  # # 将所有参数的L2正则化项求和,得到正则化损失loss_regularization。# 将损失函数定义为交叉熵损失和 L2 正则化损失的和loss = loss + 0.0001 * loss_regularization# 使用 tape 计算损失函数关于网络参数的梯度,并应用优化器进行反向传播更新参数grads = tape.gradient(loss, network.trainable_variables)optimizer.apply_gradients(zip(grads, network.trainable_variables))

正则化效果

在这里插入图片描述

Dropout


通过随机断开神经网络的连接,减少每次训练时实际参与计算的模型的参数量,但是在测试时,Dropout会恢复所有的连接,保证模型测试时获得最好的性能。

在这里插入图片描述

我们以层方式来实现以上功能

import tensorflow as tf
from tensorflow.keras import layers, regularizers# 网络构建network = Sequential([layers.Dense(256, activation='relu'),layers.Dropout(0.5),  # 有0.5的概率断开与下一层神经元的连接layers.Dense(128, activation='relu'),layers.Dropout(0.5),layers.Dense(64, activation='relu'),layers.Dense(32, activation='relu'),layers.Dense(10)])
# 参数构建
network.build(input_shape=(None, 28*28))
# 模型展示
network.summary()for step, (x, y) in enumerate(train_db):# 训练时with tf.GradientTape() as tape:out = network(x, training=True)# 测试时out = network(x, training=False)

3. Early stopping

早停法

那么如何选择合适的 Epoch 就提前停止训练(Early Stopping),避免出现过拟合现象呢?我们可以通过观察验证指标的变化,来预测最适合的 Epoch 可能的位置。具体地,对于分类问题,我们可以记录模型的验证准确率,并监控验证准确率的变化,当发现验证准确率连续𝑛个 Epoch 没有下降时,可以预测可能已经达到了最适合的 Epoch 附近,从而提前终止训练。

from tensorflow.keras.callbacks import EarlyStopping# 数据集读取···# 定义早停法回调函数
early_stopping = EarlyStopping(monitor='val_loss',  # 监视验证集losspatience=3,  # 当验证集loss在3个epoch内都没有改善则停止训练mode='min',  # 监测loss时一般设置为min,监测准确值时一般设置为maxverbose=1,  # 检测值改善时打印一条信息restore_best_weights=True  # 将权重恢复到最好的一个epoch)# 网络构建# 参数构建# 模型装配# 模型训练,添加参数network.fit(train_db, epochs=100,validation_data=val_db, validation_steps=10,callbacks=[early_stopping])

这里我们会对手写数字识别的代码再次进行修改,来使用上面提及的方法,你可以通过修改repeat的方式来复制数据集使训练数据增多,更改epochs的方式来增加训练次数,经过测试,这段代码在10 epochs 之后便达到了过拟合

import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, regularizers
from tensorflow.keras.callbacks import EarlyStopping
# 处理每一张图像
def preprocess(x, y):x = tf.cast(x, dtype=tf.float32) / 255.x = tf.reshape(x, [28 * 28])y = tf.cast(y, dtype=tf.int32)y = tf.one_hot(y, depth=10)return x, y# 数据集读取
(x, y), (x_test, y_test) = datasets.mnist.load_data()# # 固定切分
# # 60k训练集切分为 50k训练集和 10k验证集
# # (x_train, y_train), (x_val, y_val)
# x_train, x_val = tf.split(x, num_or_size_splits=[50000, 10000])
# y_train, y_val = tf.split(y, num_or_size_splits=[50000, 10000])# 交叉验证切分
idx = tf.range(60000)
idx = tf.random.shuffle(idx)
x_train, y_train = tf.gather(x, idx[:50000]), tf.gather(y, idx[:50000])
x_val, y_val = tf.gather(x, idx[-10000:]), tf.gather(y, idx[-10000:])# 训练集
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.map(preprocess).shuffle(10000).batch(128).repeat(10)# 验证集
val_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_db = val_db.map(preprocess).shuffle(10000).batch(128)# 测试集
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).shuffle(10000).batch(128)# 定义早停法回调函数
early_stopping = EarlyStopping(monitor='val_loss',  # 监视验证集losspatience=3,  # 当验证集loss在2个epoch内都没有改善则停止训练mode='min',  # 监测loss时一般设置为min,监测准确值时一般设置为maxverbose=1,  # 检测值改善时打印一条信息restore_best_weights=True  # 将权重恢复到最好的一个epoch)# 网络构建
# 正则化:在模型损失函数中加入权重参数的L2范数作为惩罚项,力度由0.001控制。
# Dropout:添加dropout层来随机断开连接
network = Sequential([layers.Dense(256, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dropout(0.5),layers.Dense(128, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dropout(0.5),layers.Dense(64, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dense(32, kernel_regularizer=regularizers.l2(0.001), activation='relu'),layers.Dense(10)])
# 参数构建
network.build(input_shape=(None, 28*28))
# 模型展示
network.summary()
# 模型装配
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),loss=tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
# 模型训练
network.fit(train_db, epochs=50,validation_data=val_db, validation_steps=10,callbacks=[early_stopping])
# 模型评估
print('模型评估:')
network.evaluate(test_db)

写在最后

👍🏻点赞,你的认可是我创作的动力!
⭐收藏,你的青睐是我努力的方向!
✏️评论,你的意见是我进步的财富!

相关文章:

【深度学习_TensorFlow】过拟合

写在前面 过拟合与欠拟合 欠拟合: 是指在模型学习能力较弱,而数据复杂度较高的情况下,模型无法学习到数据集中的“一般规律”,因而导致泛化能力弱。此时,算法在训练集上表现一般,但在测试集上表现较差&…...

uniapp授权小程序隐私弹窗效果demo(整理)

<template> <view class"dealBox"><view class"txtBox padding10"><!-- 查看协议 -->在您使用施工现场五星计划小程序之前&#xff0c;请仔细阅读<text class"goToPrivacy" click"handleOpenPrivacyContract&qu…...

c++学习之string实现

字符串 - C引用 (cplusplus.com)这里给出标准官方的string实现&#xff0c;可以看到设计还是较为复杂的&#xff0c;有成员函数&#xff0c;迭代器&#xff0c;修饰符&#xff0c;容量&#xff0c;元素访问&#xff0c;字符串操作等&#xff0c;将字符尽可能的需求都设计出来&a…...

kubevirt虚机创建svc通过NodePort的方式暴露端口

背景 存在kubevit存在的三个虚机&#xff1a; ubuntu-4tlg7 7d22h Running True ubuntu-7kgrk 7d22h Running True ubuntu-94kg2 7d22h Running True 网络没有做透传&#xff0c;pod也不是underlay网络想要通过NodePort方式暴露虚机22端口进行远程登录。 …...

Elasticsearch终端命令行用法大全

API作用使用场景curl localhost:9200/_cluster/health?pretty查看ES健康状态curl localhost:9200/_cluster/settings?pretty查看ES集群的设置其中persistent为永久设置&#xff0c;重启仍然有效&#xff1b;trainsient为临时设置&#xff0c;重启失效curl localhost:9200/_ca…...

nacos版本升级注意事项

背景&#xff1a;nacos版本升级&#xff0c;1.0.1升级到2.1.2&#xff0c;nacos主要用作配置中心 1 从官网下载新版本nacos压缩包 2 由于1.x到2.x版本数据结构发生变化&#xff0c;无法沿用旧的数据库&#xff0c;所以新建一个数据库实例&#xff0c;来保存具体的nacos配置信息…...

JavaScript作用域与作用域链

JavaScript作用域与作用域链 JavaScript的作用域和作用域链是理解这门语言的关键概念之一。作用域指的是变量和函数在程序中可被访问的范围。作用域链是由函数的嵌套关系决定的变量对象的链式结构。 静态作用域与动态作用域 JavaScript使用静态作用域&#xff0c;也称为词法…...

MQTT异常掉线原因

一、业务场景 我们在使用MQTT协议的时候&#xff0c;有些伙伴可能会遇到MQTT客户端频繁掉线、上线问题 二、原因分析及异常处理 1.原因&#xff1a;使用相同的clientID 方案&#xff1a;全局使用的clientID保证唯一性&#xff0c;可以采用UUID等方式 2.原因: 当前用户没有Top…...

重新理解百度智能云:写在大模型开放后的24小时

在这些回答背后共同折射出的一个现实是——大模型不再是一个单选题&#xff0c;而更是一个综合题。在这个新的时代帆船上&#xff0c;产品、服务、安全、开放等全部都需要成为必需品&#xff0c;甚至是从企业的落地层面来看&#xff0c;这些更是刚需品。 作者| 皮爷 出品|产…...

Stable Diffusion 提示词技巧

文章目录 背景介绍如何写好提示词提示词的语法正向提示词负向提示词 随着AI技术的不断发展&#xff0c;越来越多的新算法涌现出来&#xff0c;例如Stable Diffusion、Midjourney、Dall-E等。相较于传统算法如GAN和VAE&#xff0c;这些新算法在生成高分辨率、高质量的图片方面表…...

VS2019编译curl库

下载&#xff1a; curl-7.61.0.tar.gz 编译&#xff1a; 解压到一个文件下&#xff0c;然后右键以管理员权限运行buildconf.bat 编译x64的库使用的是x64 Native Tools Command Prompt for VS 2019 本机工具命令提示&#xff0c;如果想编译x86的库&#xff0c;可以选择x86 Nat…...

yolov5自定义模型训练三

经过11个小时cpu训练完如下 在runs/train/expx里存放训练的结果&#xff0c; 测试是否可以检测ok 网上找的这张识别效果不是很好&#xff0c;通过加大训练次数和数据集的话精度可以提升。 训练后的权重也可以用视频源来识别&#xff0c; python detect.py --source 0 # webca…...

服务器中了mkp勒索病毒该怎么办?勒索病毒解密,数据恢复

mkp勒索病毒算的上是一种比较常见的勒索病毒类型了。它的感染数量上也常年排在前几名的位置。所以接下来就由云天数据恢复中心的技术工程师来对mkp勒索病毒做一个分析&#xff0c;以及中招以后应该怎么办。 一&#xff0c;中了mkp勒索病毒的表现 桌面以及多个文件夹当中都有一封…...

Docker环境搭建Prometheus实验环境

环境&#xff1a; OS&#xff1a;Centos7 Docker: 20.10.9 - Community Centos部署Docker 【Kubernetes】Centos中安装Docker和Minikube_云服务器安装docker和minikube_DivingKitten的博客-CSDN博客 一、拉取Prometheus镜像 ## 拉取镜像 docker pull prom/prometheus ## 启动p…...

Python Qt学习(七)Listview

源代码&#xff1a; # -*- coding: utf-8 -*-# Form implementation generated from reading ui file qt_listview.ui # # Created by: PyQt5 UI code generator 5.15.9 # # WARNING: Any manual changes made to this file will be lost when pyuic5 is # run again. Do not…...

哈希表HashMap(基于vector和list)

C数据结构与算法实现&#xff08;目录&#xff09; 1 什么是HashMap&#xff1f; 我们这里要实现的HashMap接口不会超过标准库的版本&#xff08;是一个子集&#xff09;。 HashMap是一种键值对容器&#xff08;关联容器&#xff09;&#xff0c;又叫字典。 和其他容易一样…...

go中的函数

demo1:函数的几种定义方式 package mainimport ("errors""fmt" )/* 函数的用法 跟其他语言的区别&#xff1a;支持多个返回值*///函数定义方法1 func add(a, b int) int {return a b }//函数定义方法2 func add2(a, b int) (sun int) {sun a breturn s…...

小试 InsCode AI 创作助手

个人理解&#xff1a; 自ChatGPT新版现世&#xff0c;一直被视面替代人工工作的世大挑战&#xff0c;各类人工智能语言生成工目层出不穷&#xff0c;也在不断影响着我们日常的工作和生活 小试CSDN的InsCode AI&#xff1a; - 基本概念查询方便&#xff0c;与个人了解&…...

粉丝经验分享:13:00 开始的面试,13:06 就结束了,问题真是变态

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…...

SASS的@规则

1&#xff0c;import sass扩展了import导入&#xff0c;对于css&#xff0c;import导入在页面加载的时候去下载导入的外部文件&#xff0c;而sass的导入&#xff0c;在编译成css文件的时候就将外部的sass文件导入合并编译成一个css文件。 他支持同时导入多个文件&#xff1b;…...

【C++初阶】模拟实现优先级队列priority_queue

&#x1f466;个人主页&#xff1a;Weraphael ✍&#x1f3fb;作者简介&#xff1a;目前学习C和算法 ✈️专栏&#xff1a;C航路 &#x1f40b; 希望大家多多支持&#xff0c;咱一起进步&#xff01;&#x1f601; 如果文章对你有帮助的话 欢迎 评论&#x1f4ac; 点赞&#x1…...

如何为你的公司选择正确的AIGC解决方案?

如何为你的公司选择正确的AIGC解决方案&#xff1f; 摘要引言词汇解释&#xff08;详细版本&#xff09;详细介绍1. 确定需求2. 考虑技术能力3. 评估可行性4. 比较不同供应商 代码快及其注释注意事项知识总结 博主 默语带您 Go to New World. ✍ 个人主页—— 默语 的博客&…...

Windows下将nginx等可执行文件添加为服务

Windows下将nginx等可执行文件添加为服务 为什么将可执行文件添加为服务&#xff1f;将可执行文件添加为服务的步骤步骤 1&#xff1a;下载和安装 Nginx步骤 2&#xff1a;添加为服务方法一&#xff1a;使用 Windows 自带的 sc 命令方法二&#xff1a;使用 NSSM&#xff08;Non…...

视觉SLAM14讲笔记-第4讲-李群与李代数

李代数的引出&#xff1a; 在优化问题中去解一个旋转矩阵&#xff0c;可能会有一些阻碍&#xff0c;因为它对加法导数不是很友好&#xff08;旋转矩阵加上一个微小偏移量可能就不是一个旋转矩阵&#xff09;&#xff0c;因为旋转矩阵本身还有一些约束条件&#xff0c;那样再求…...

浅析Redis(1)

一.Redis的含义 Redis可以用来作数据库&#xff0c;缓存&#xff0c;流引擎&#xff0c;消息队列。redis只有在分布式系统中才能充分的发挥作用&#xff0c;如果是单机程序&#xff0c;直接通过变量来存储数据是更优的选择。那我们知道进程之间是有隔离性的&#xff0c;那么re…...

【每日一题】2337. 移动片段得到字符串

【每日一题】2337. 移动片段得到字符串 2337. 移动片段得到字符串题目描述解题思路 2337. 移动片段得到字符串 题目描述 给你两个字符串 start 和 target &#xff0c;长度均为 n 。每个字符串 仅 由字符 ‘L’、‘R’ 和 ‘_’ 组成&#xff0c;其中&#xff1a; 字符 ‘L’…...

MySQL 数据库常用命令大全(详细)

文章目录 1. MySQL命令2. MySQL基础命令3. MySQL命令简介4. MySQL常用命令4.1 MySQL准备篇4.1.1 启动和停止MySQL服务4.1.2 修改MySQL账户密码4.1.3 MySQL的登陆和退出4.1.4 查看MySQL版本 4.2 DDL篇&#xff08;数据定义&#xff09;4.2.1 查询数据库4.2.2 创建数据库4.2.3 使…...

中国移动加大布局长三角,打造算力产业新高地

8月27日&#xff0c;以“数实融合算启未来”为主题的2023长三角算力发展大会在苏州举办&#xff0c;大会启动了长三角算力调度枢纽&#xff0c;携手各界推动算力产业高质量发展。 会上&#xff0c;移动云作为第一批算力资源提供方&#xff0c;与苏州市公共算力服务平台签订算力…...

话费、加油卡、视频会员等充值接口如何对接?

现在很多商家企业等发现与用户保持粘性是越来越难了&#xff0c;大多数的用户活跃度都很差&#xff0c;到底该怎么做才能改善这种情况呢&#xff1f; 那么我们需要做的就是投其所好&#xff0c;在与用户保持粘性的app或者积分商城中投入大家感兴趣的物品或者虚拟产品&#xff…...

服务器重启MongoDB无法启动

文章目录 服务器重启MongoDB无法启动背景规划实施 总结 服务器重启MongoDB无法启动 背景 数据库服务器的CPU接近告警值了&#xff0c;需要添加CPU资源&#xff0c;于是乎就在恰当的时间对服务器进行关机&#xff0c;待添加完资源后开机&#xff0c;这样就完成了CPU资源的添加…...