【机器学习实战】kaggle 欺诈检测---使用生成对抗网络(GAN)解决欺诈数据中正负样本极度不平衡问题
【机器学习实战】kaggle 欺诈检测---如何解决欺诈数据中正负样本极度不平衡问题
https://blog.csdn.net/2302_79308082/article/details/145177242
本篇文章是基于上次文章中提到的对抗生成网络,通过对抗生成网络生成少数类样本,平衡欺诈数据中正类样本极少的问题。
本人主页:机器学习小小白
机器学习专栏:机器学习实战
PyTorch入门专栏:PyTorch入门
深度学习实战:深度学习
ok,话不多说,我们进入正题吧
1. 引言
生成对抗网络(Generative Adversarial Networks,简称GAN)是由Ian Goodfellow等人于2014年提出的一种深度学习模型。它在计算机视觉、自然语言处理、音频生成等领域得到了广泛应用。GAN的核心思想是通过两个神经网络之间的博弈关系来生成新的、仿真的数据。自从GAN提出以来,它已经成为生成模型领域的突破性进展,深刻改变了生成式模型的研究和应用。
2. GAN的基本原理
生成对抗网络的结构包括两个主要部分:生成器(Generator)和判别器(Discriminator)。这两个网络分别充当“对手”,并在训练过程中互相博弈:
-
生成器(Generator):该网络的目的是通过学习数据分布来生成尽可能接近真实数据的虚假样本。生成器从一个随机的噪声(通常是高维的向量)出发,逐步生成样本。
-
判别器(Discriminator):该网络的任务是判断一个样本是真实的(来自训练数据)还是虚假的(来自生成器)。判别器输出一个概率值,表示输入样本为真实数据的概率。
3. GAN的训练过程
GAN的训练过程是一个“博弈”过程,生成器和判别器不断互相对抗,从而提升各自的性能。这个过程可以通过以下的数学公式来表示:
- 判别器的目标:判别器的目标是最大化其对于真实数据的判断概率(即预测为1的概率),同时最小化对生成数据的错误分类(即预测为0的概率)。可以通过以下的交叉熵损失函数表示:
其中:
-
是从真实数据分布中采样的数据。
-
是生成器生成的样本,
是从潜在空间中采样的噪声。
-
是判别器对样本
的判别输出,表示其为真实数据的概率。
-
生成器的目标:生成器的目标是使判别器无法区分生成数据与真实数据,因此它通过最大化判别器对生成数据为真实的概率来进行训练:
-
-
其中:
是生成器生成的虚假样本,
是判别器对生成样本的输出,表示其为真实数据的概率。
在训练过程中,生成器和判别器会交替优化这两个损失函数。理想的结果是生成器能够生成与真实数据分布相似的样本,而判别器则无法有效地区分生成数据与真实数据。
4. GAN的应用
GAN具有强大的生成能力,广泛应用于多个领域,以下是一些典型的应用场景:
-
图像生成:GAN可以用于生成高度逼真的图像,如人脸、风景或艺术作品。典型的例子包括DeepArt和StyleGAN,后者能够生成几乎无法与真实人脸区分的图像。
-
图像到图像的转换:例如,利用GAN进行图像风格转换(如将照片转化为油画风格)、超分辨率重建(如提高图像的分辨率)、图像修复(如填补丢失部分)等任务。
-
文本生成:结合自然语言处理技术,GAN也可用于生成文本数据,如诗歌、故事生成等,尤其是文本生成和对话系统中的对抗训练。
-
音频生成:GAN被广泛应用于音频生成,如音乐生成、语音合成等。
-
数据增强:GAN可以用于数据增强,特别是在医疗图像领域,生成具有一定变异的图像样本,以增强训练数据集。
-
模型训练中的对抗样本生成:GAN可以生成对抗样本,即通过对训练数据进行微小扰动,生成能够误导模型的样本,这对提升模型的鲁棒性非常重要。
5. GAN的变种
GAN作为一种框架,已经发展出了多种变种,以满足不同应用的需求。以下是几种常见的GAN变种:
-
CGAN(Conditional GAN):在生成器和判别器中都加入了条件变量,使得生成的样本可以根据某些条件(如标签信息)进行控制。
-
WGAN(Wasserstein GAN):解决了传统GAN在训练过程中可能出现的梯度消失和模式崩溃问题。WGAN使用了Wasserstein距离作为生成器和判别器的损失函数。
-
DCGAN(Deep Convolutional GAN):使用卷积神经网络(CNN)来构建生成器和判别器,增强了GAN在图像生成任务中的表现。
-
CycleGAN:用于无监督学习场景,特别是在图像到图像的转换中,例如将一张照片转换成另一种风格(如马到斑马转换)。
6. 使用生成对抗网络(GAN)生成欺诈数据中少数类数据
1. 数据预处理与特征提取
import pandas as pd
import numpy as nptrain_df = pd.read_csv('/kaggle/input/credit-card-fraud-prediction/train.csv')
test_df = pd.read_csv('/kaggle/input/credit-card-fraud-prediction/test.csv')def time_feature(df):df['Time'] = pd.to_datetime(df['Time'], unit='s') # 将时间戳转为 datetime 格式# 提取时间特征df['hour'] = df['Time'].dt.hourdf['minute'] = df['Time'].dt.minute return df train_df = time_feature(train_df)
test_df = time_feature(test_df)
在欺诈检测任务中,时间特征(如交易发生的小时和分钟)通常是重要的,因为欺诈交易往往具有不同的时间模式。例如,欺诈交易可能集中在某些特定的时间段。
- 这里我们通过
pd.to_datetime()
将Time
列从Unix时间戳格式转换为日期时间格式。然后,我们提取了小时和分钟作为新的特征,用于训练模型。
train_feature = train_df.drop(columns=['id','IsFraud','Time'])
test_feature = test_df.drop(columns=['id','Time'])label = train_df['IsFraud']
train_feature
是用于训练的特征数据,删除了 id
, IsFraud
和 Time
列。IsFraud
是标签列,表示交易是否为欺诈交易;而 id
和 Time
列不包含有用的特征信息,因此可以去掉。
2. 标准化数据
from sklearn.preprocessing import StandardScaler# 标准化特征数据
scaler = StandardScaler()
train_feature_scaled = scaler.fit_transform(train_feature)
-
标准化(Standardization)是机器学习中常用的预处理步骤。它通过减去均值并除以标准差,使特征数据具有零均值和单位方差。标准化能够加速模型的收敛过程,尤其是在使用像神经网络这样的梯度优化模型时。
-
这里使用
StandardScaler
来对训练数据进行标准化,以确保所有特征在同一个量级。
3. 生成器与判别器的构建
生成器(Generator)
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization, Inputdef build_generator(latent_dim, input_dim):model = Sequential()model.add(Input(shape=(latent_dim,))) # 使用 Input 层来指定输入维度model.add(Dense(256))model.add(LeakyReLU(0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(input_dim, activation='tanh')) # 输出层与原数据同维度return model
生成器(Generator)是GAN的核心部分,它通过接收随机噪声向量(潜在空间中的点),然后经过一系列的全连接层和激活函数,生成与原始数据分布相似的虚假数据。
- 在此,我们使用了
LeakyReLU
激活函数,它允许梯度通过负半轴流动,解决了传统ReLU可能出现的“死神经元”问题。BatchNormalization
用于加速网络的训练,并帮助改善模型的稳定性。
判别器(Discriminator)
def build_discriminator(input_dim):model = Sequential()model.add(Input(shape=(input_dim,))) # 使用 Input 层来指定输入维度model.add(Dense(1024))model.add(LeakyReLU(0.2))model.add(Dense(512))model.add(LeakyReLU(0.2))model.add(Dense(256))model.add(LeakyReLU(0.2))model.add(Dense(1, activation='sigmoid')) # 输出真假判定return model
判别器(Discriminator)的任务是判断输入数据是真实的还是由生成器生成的。它是一个二分类模型,输出是一个概率值,表示输入数据为真实的概率。
- 这里使用
sigmoid
激活函数,输出一个概率值。判别器学习将真实数据和生成数据区分开来。
4. GAN模型的组合与训练
def build_gan(generator, discriminator):discriminator.trainable = False # 在训练GAN时冻结判别器model = Sequential()model.add(generator)model.add(discriminator)return model# 定义优化器
optimizer = Adam()# 定义输入维度和潜在维度
latent_dim = 100 # 随机噪声的维度
input_dim = 31 # 输入数据的维度,例如欺诈检测数据的特征数# 创建并编译模型
generator = build_generator(latent_dim, input_dim)
discriminator = build_discriminator(input_dim)
gan = build_gan(generator, discriminator)# 编译判别器和GAN模型
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
gan.compile(loss='binary_crossentropy', optimizer=optimizer)
-
生成对抗训练(Adversarial Training)是GAN的关键。生成器和判别器在一个博弈过程中互相优化。在训练过程中,生成器通过“欺骗”判别器来优化其生成数据的能力,而判别器则不断学习区分真实和生成数据。
-
在训练过程中,我们冻结判别器的参数,只训练生成器,这样可以避免在训练生成器时更新判别器的权重。
5. GAN训练函数
def train_gan(generator, discriminator, gan, fraud_data_scaled, epochs=10000, batch_size=64):valid = np.ones((batch_size, 1)) # 真数据标签fake = np.zeros((batch_size, 1)) # 假数据标签for epoch in range(epochs):# 随机选择真实欺诈数据idx = np.random.randint(0, fraud_data_scaled.shape[0], batch_size)real_data = fraud_data_scaled[idx]# 生成虚拟数据noise = np.random.normal(0, 1, (batch_size, latent_dim))generated_data = generator.predict(noise)# 训练判别器d_loss_real = discriminator.train_on_batch(real_data, valid)d_loss_fake = discriminator.train_on_batch(generated_data, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, latent_dim))g_loss = gan.train_on_batch(noise, valid)# 输出训练过程的损失if epoch % 1000 == 0:print(f'{epoch}/{epochs} [D loss: {d_loss[0]}] [G loss: {g_loss}]')
-
训练过程:在每个训练周期中,首先更新判别器的权重(通过训练它区分真实数据和生成数据),然后训练生成器(通过训练它欺骗判别器)。
-
损失函数:我们使用了
binary_crossentropy
损失函数,它用于二分类任务。在判别器的训练中,我们分别计算真实数据和生成数据的损失,然后平均得到判别器的总损失。生成器的损失则是通过GAN模型进行计算的。
6. 生成虚拟数据
def generate_fake_data(generator, num_samples):noise = np.random.normal(0, 1, (num_samples, latent_dim)) # 随机噪声generated_data = generator.predict(noise) # 生成虚拟数据# 将生成的数据转换回原始空间generated_data_original = scaler.inverse_transform(generated_data)# 获取原始负样本数据的列名(去除 'id', 'IsFraud', 'Time' 列)feature_columns = [col for col in train_df.columns if col not in ['id', 'IsFraud', 'Time']]# 将生成的数据与原始负样本数据(即非欺诈数据)结合,作为新的训练数据augmented_data = np.concatenate([train_df[train_df['IsFraud'] == 0].drop(columns=['id', 'IsFraud', 'Time']),generated_data_original], axis=0)augmented_label = np.concatenate([np.zeros(train_df[train_df['IsFraud'] == 0].shape[0]), np.ones(generated_data_original.shape[0])], axis=0)# 创建包含生成数据和标签的 DataFrameaugmented_df = pd.DataFrame(augmented_data, columns=feature_columns)augmented_df['IsFraud'] = augmented_labelreturn augmented_df
在这个函数中,我们使用训练好的生成器来生成新的虚拟欺诈数据,并将它们与真实的非欺诈数据结合,以增强数据集。然后,我们通过逆标准化将生成的数据转换回原始数据空间。
本次例子为了缩短训练时间,只生成了100条虚拟的正样本数据。
相关文章:
【机器学习实战】kaggle 欺诈检测---使用生成对抗网络(GAN)解决欺诈数据中正负样本极度不平衡问题
【机器学习实战】kaggle 欺诈检测---如何解决欺诈数据中正负样本极度不平衡问题https://blog.csdn.net/2302_79308082/article/details/145177242 本篇文章是基于上次文章中提到的对抗生成网络,通过对抗生成网络生成少数类样本,平衡欺诈数据中正类样本极…...
android wifi framework与wpa_supplicant的交互
android frmework直接与wpa_supplicant进行交互,使用aidl或者hidl 二、事件 framework注册事件的地方: packages/modules/Wifi/service/java/com/android/server/wifi/SupplicantStaIfaceCallbackImpl.java class SupplicantStaIfaceCallbackImpl exte…...

初学stm32 --- flash模仿eeprom
目录 STM32内部FLASH简介 内部FLASH构成(F1) FLASH读写过程(F1) 闪存的读取 闪存的写入 内部FLASH构成(F4 / F7 / H7) FLASH读写过程(F4 / F7 / H7) 闪存的读取 闪存的写入 …...

使用C语言实现栈的插入、删除和排序操作
栈是一种后进先出(LIFO, Last In First Out)的数据结构,这意味着最后插入的元素最先被删除。在C语言中,我们可以通过数组或链表来实现栈。本文将使用数组来实现一个简单的栈,并提供插入(push)、删除(pop)以及排序(这里采用一种简单的排序方法,例如冒泡排序)的操作示…...

C语言程序环境和预处理详解
本章重点: 程序的翻译环境 程序的执行环境 详解:C语言程序的编译链接 预定义符号介绍 预处理指令 #define 宏和函数的对比 预处理操作符#和##的介绍 命令定义 预处理指令 #include 预处理指令 #undef 条件编译 程序的翻译环境和执行环…...

基于机器学习随机森林算法的个人职业预测研究
1.背景调研 随着信息技术的飞速发展,特别是大数据和云计算技术的广泛应用,各行各业都积累了大量的数据。这些数据中蕴含着丰富的信息和模式,为利用机器学习进行职业预测提供了可能。机器学习算法的不断进步,如深度学习、强化学习等…...

三种文本相似计算方法:规则、向量与大模型裁判
文本相似计算 项目背景 目前有众多工作需要评估字符串之间的相似(相关)程度: 比如,RAG 智能问答系统文本召回阶段需要计算用户文本与文本库内文本的相似分数,返回前TopK个候选文本。 在评估大模型生成的文本阶段,也需要评估…...
Python语言的计算机基础
Python语言的计算机基础 绪论 在当今信息技术飞速发展的时代,编程已经成为了一种必备技能。Python凭借其简洁、易读和强大的功能,逐渐成为初学者学习编程的首选语言。本文将以Python语言为基础,探讨计算机科学的基本概念,并帮助…...

Dify应用-工作流
目录 DIFY 工作流参考 DIFY 工作流 2025-1-15 老规矩感谢参考文章的作者,避免走弯路。 2025-1-15 方便容易上手 在dify的一个桌面上,添加多个节点来完成一个任务。 每个工作流必须有一个开始和结束节点。 节点之间用线连接即可。 每个节点可以有输入和输出 输出类型有,字符串,…...
02.02、返回倒数第 k 个节点
02.02、[简单] 返回倒数第 k 个节点 1、题目描述 实现一种算法,找出单向链表中倒数第 k 个节点。返回该节点的值。 2、题解思路 本题的关键在于使用双指针法,通过两个指针(fast 和 slow),让 fast 指针比 slow 指针…...

Linux手写FrameBuffer任意引脚驱动spi屏幕
一、硬件设备 开发板:香橙派 5Plus,cpu:RK3588,带有 40pin 外接引脚。 屏幕:SPI 协议 0.96 寸 OLED。 二、需求 主要是想给板子增加一个可视化的监视器,并且主页面可调。 平时跑个模型或者服务,…...

怎么修复损坏的U盘?而且不用格式化的方式!
当你插入U盘时,若电脑弹出“需要格式化才能使用”提示,且无法打开或读取其中的数据,说明U盘极有可能已经损坏。除此之外,若电脑在连接U盘后显示以下信息,也可能意味着U盘出现问题,需要修复损坏的U盘&#x…...
语音技术在播客领域的应用(2)
播客是以语音为主,各种基于AI 的语音技术在播客领域十分重要。 语音转文本 Whisper Whisper 是OpenAI 推出的开源语音辨识工具,可以把音档转成文字,支援超过50 种语言。这款工具是基于68 万小时的训练资料,其中包含11.7 万小时的…...

【Linux】应用层自定义协议与序列化
🌈 个人主页:Zfox_ 🔥 系列专栏:Linux 目录 一:🔥 应用层 🦋 再谈 "协议"🦋 网络版计算器🦋 序列化 和 反序列化 二:🔥 重新理解 read、…...

深度学习中的张量 - 使用PyTorch进行广播和元素级操作
深度学习中的张量 - 使用PyTorch进行广播和元素级操作 元素级是什么意思? 元素级操作在神经网络编程中与张量的使用非常常见。让我们从一个元素级操作的定义开始这次讨论。 一个_元素级_操作是在两个张量之间进行的操作,它作用于各自张量中的相应元素…...
gitignore忽略已经提交过的
已经在.gitignore文件中添加了过滤规则来忽略bin和obj等文件夹,但这些文件夹仍然出现在提交中,可能是因为这些文件夹在添加.gitignore规则之前已经被提交到Git仓库中了。要解决这个问题,您需要从Git的索引中移除这些文件夹,并确保…...
h5使用video播放时关掉vant弹窗视频声音还在后台播放
现象: 1、点击遮罩弹窗关闭,弹窗的视频已经用v-if销毁,但是后台会自己从头开始播放视频声音。但是此时已经没有视频dom 2、定时器在打开弹窗后3秒自动关闭弹窗,则正常没有问题。 原来的代码: //页面 <a click&quo…...

Widows搭建sqli-labs
使用ms17_010渗透win7 ms17_010针对windows445端口(共享文件), 现有一台win7虚拟机IP 192.168.80.129, 开放445端口, 使用msf渗透该虚拟机 auxiliary 使用auxiliary判断目标主机是否适用smb17_010漏洞 这里发现80网段, 有一台主机适用 exploit 使用search ms17_010 type:expl…...
为AI聊天工具添加一个知识系统 之46 蒙板程序设计(第一版):Facet六边形【意识形态:操纵】
本文要点 要点 (原先标题冒号后只有 “Facet”后改为“Face六边形【意识形态】” ,是 事后想到的,本文并未明确提出。备忘在这里作为后续的“后期制作”的备忘) 前面讨论的(“之41 纯粹的思维”)中 说到,“意识”三…...
ASP.NET Core WebApi接口IP限流实践技术指南
在当今的Web开发中,接口的安全性和稳定性至关重要。面对恶意请求或频繁访问,我们需要采取有效的措施来保护我们的WebApi接口。IP限流是一种常见的技术手段,通过对来自同一IP地址的请求进行频率控制,可以有效地防止恶意攻击和过度消…...

css实现圆环展示百分比,根据值动态展示所占比例
代码如下 <view class""><view class"circle-chart"><view v-if"!!num" class"pie-item" :style"{background: conic-gradient(var(--one-color) 0%,#E9E6F1 ${num}%),}"></view><view v-else …...

Mybatis逆向工程,动态创建实体类、条件扩展类、Mapper接口、Mapper.xml映射文件
今天呢,博主的学习进度也是步入了Java Mybatis 框架,目前正在逐步杨帆旗航。 那么接下来就给大家出一期有关 Mybatis 逆向工程的教学,希望能对大家有所帮助,也特别欢迎大家指点不足之处,小生很乐意接受正确的建议&…...
java 实现excel文件转pdf | 无水印 | 无限制
文章目录 目录 文章目录 前言 1.项目远程仓库配置 2.pom文件引入相关依赖 3.代码破解 二、Excel转PDF 1.代码实现 2.Aspose.License.xml 授权文件 总结 前言 java处理excel转pdf一直没找到什么好用的免费jar包工具,自己手写的难度,恐怕高级程序员花费一年的事件,也…...

智能在线客服平台:数字化时代企业连接用户的 AI 中枢
随着互联网技术的飞速发展,消费者期望能够随时随地与企业进行交流。在线客服平台作为连接企业与客户的重要桥梁,不仅优化了客户体验,还提升了企业的服务效率和市场竞争力。本文将探讨在线客服平台的重要性、技术进展、实际应用,并…...

selenium学习实战【Python爬虫】
selenium学习实战【Python爬虫】 文章目录 selenium学习实战【Python爬虫】一、声明二、学习目标三、安装依赖3.1 安装selenium库3.2 安装浏览器驱动3.2.1 查看Edge版本3.2.2 驱动安装 四、代码讲解4.1 配置浏览器4.2 加载更多4.3 寻找内容4.4 完整代码 五、报告文件爬取5.1 提…...

Linux 内存管理实战精讲:核心原理与面试常考点全解析
Linux 内存管理实战精讲:核心原理与面试常考点全解析 Linux 内核内存管理是系统设计中最复杂但也最核心的模块之一。它不仅支撑着虚拟内存机制、物理内存分配、进程隔离与资源复用,还直接决定系统运行的性能与稳定性。无论你是嵌入式开发者、内核调试工…...
怎么让Comfyui导出的图像不包含工作流信息,
为了数据安全,让Comfyui导出的图像不包含工作流信息,导出的图像就不会拖到comfyui中加载出来工作流。 ComfyUI的目录下node.py 直接移除 pnginfo(推荐) 在 save_images 方法中,删除或注释掉所有与 metadata …...
【学习笔记】erase 删除顺序迭代器后迭代器失效的解决方案
目录 使用 erase 返回值继续迭代使用索引进行遍历 我们知道类似 vector 的顺序迭代器被删除后,迭代器会失效,因为顺序迭代器在内存中是连续存储的,元素删除后,后续元素会前移。 但一些场景中,我们又需要在执行删除操作…...
适应性Java用于现代 API:REST、GraphQL 和事件驱动
在快速发展的软件开发领域,REST、GraphQL 和事件驱动架构等新的 API 标准对于构建可扩展、高效的系统至关重要。Java 在现代 API 方面以其在企业应用中的稳定性而闻名,不断适应这些现代范式的需求。随着不断发展的生态系统,Java 在现代 API 方…...
提升移动端网页调试效率:WebDebugX 与常见工具组合实践
在日常移动端开发中,网页调试始终是一个高频但又极具挑战的环节。尤其在面对 iOS 与 Android 的混合技术栈、各种设备差异化行为时,开发者迫切需要一套高效、可靠且跨平台的调试方案。过去,我们或多或少使用过 Chrome DevTools、Remote Debug…...