【机器学习实战】kaggle 欺诈检测---使用生成对抗网络(GAN)解决欺诈数据中正负样本极度不平衡问题
【机器学习实战】kaggle 欺诈检测---如何解决欺诈数据中正负样本极度不平衡问题
https://blog.csdn.net/2302_79308082/article/details/145177242
本篇文章是基于上次文章中提到的对抗生成网络,通过对抗生成网络生成少数类样本,平衡欺诈数据中正类样本极少的问题。
本人主页:机器学习小小白
机器学习专栏:机器学习实战
PyTorch入门专栏:PyTorch入门
深度学习实战:深度学习
ok,话不多说,我们进入正题吧
1. 引言
生成对抗网络(Generative Adversarial Networks,简称GAN)是由Ian Goodfellow等人于2014年提出的一种深度学习模型。它在计算机视觉、自然语言处理、音频生成等领域得到了广泛应用。GAN的核心思想是通过两个神经网络之间的博弈关系来生成新的、仿真的数据。自从GAN提出以来,它已经成为生成模型领域的突破性进展,深刻改变了生成式模型的研究和应用。
2. GAN的基本原理
生成对抗网络的结构包括两个主要部分:生成器(Generator)和判别器(Discriminator)。这两个网络分别充当“对手”,并在训练过程中互相博弈:
-
生成器(Generator):该网络的目的是通过学习数据分布来生成尽可能接近真实数据的虚假样本。生成器从一个随机的噪声(通常是高维的向量)出发,逐步生成样本。
-
判别器(Discriminator):该网络的任务是判断一个样本是真实的(来自训练数据)还是虚假的(来自生成器)。判别器输出一个概率值,表示输入样本为真实数据的概率。
3. GAN的训练过程
GAN的训练过程是一个“博弈”过程,生成器和判别器不断互相对抗,从而提升各自的性能。这个过程可以通过以下的数学公式来表示:
- 判别器的目标:判别器的目标是最大化其对于真实数据的判断概率(即预测为1的概率),同时最小化对生成数据的错误分类(即预测为0的概率)。可以通过以下的交叉熵损失函数表示:
其中:
-
是从真实数据分布中采样的数据。
-
是生成器生成的样本,
是从潜在空间中采样的噪声。
-
是判别器对样本
的判别输出,表示其为真实数据的概率。
-
生成器的目标:生成器的目标是使判别器无法区分生成数据与真实数据,因此它通过最大化判别器对生成数据为真实的概率来进行训练:
-
-
其中:
是生成器生成的虚假样本,
是判别器对生成样本的输出,表示其为真实数据的概率。
在训练过程中,生成器和判别器会交替优化这两个损失函数。理想的结果是生成器能够生成与真实数据分布相似的样本,而判别器则无法有效地区分生成数据与真实数据。
4. GAN的应用
GAN具有强大的生成能力,广泛应用于多个领域,以下是一些典型的应用场景:
-
图像生成:GAN可以用于生成高度逼真的图像,如人脸、风景或艺术作品。典型的例子包括DeepArt和StyleGAN,后者能够生成几乎无法与真实人脸区分的图像。
-
图像到图像的转换:例如,利用GAN进行图像风格转换(如将照片转化为油画风格)、超分辨率重建(如提高图像的分辨率)、图像修复(如填补丢失部分)等任务。
-
文本生成:结合自然语言处理技术,GAN也可用于生成文本数据,如诗歌、故事生成等,尤其是文本生成和对话系统中的对抗训练。
-
音频生成:GAN被广泛应用于音频生成,如音乐生成、语音合成等。
-
数据增强:GAN可以用于数据增强,特别是在医疗图像领域,生成具有一定变异的图像样本,以增强训练数据集。
-
模型训练中的对抗样本生成:GAN可以生成对抗样本,即通过对训练数据进行微小扰动,生成能够误导模型的样本,这对提升模型的鲁棒性非常重要。
5. GAN的变种
GAN作为一种框架,已经发展出了多种变种,以满足不同应用的需求。以下是几种常见的GAN变种:
-
CGAN(Conditional GAN):在生成器和判别器中都加入了条件变量,使得生成的样本可以根据某些条件(如标签信息)进行控制。
-
WGAN(Wasserstein GAN):解决了传统GAN在训练过程中可能出现的梯度消失和模式崩溃问题。WGAN使用了Wasserstein距离作为生成器和判别器的损失函数。
-
DCGAN(Deep Convolutional GAN):使用卷积神经网络(CNN)来构建生成器和判别器,增强了GAN在图像生成任务中的表现。
-
CycleGAN:用于无监督学习场景,特别是在图像到图像的转换中,例如将一张照片转换成另一种风格(如马到斑马转换)。
6. 使用生成对抗网络(GAN)生成欺诈数据中少数类数据
1. 数据预处理与特征提取
import pandas as pd
import numpy as nptrain_df = pd.read_csv('/kaggle/input/credit-card-fraud-prediction/train.csv')
test_df = pd.read_csv('/kaggle/input/credit-card-fraud-prediction/test.csv')def time_feature(df):df['Time'] = pd.to_datetime(df['Time'], unit='s') # 将时间戳转为 datetime 格式# 提取时间特征df['hour'] = df['Time'].dt.hourdf['minute'] = df['Time'].dt.minute return df train_df = time_feature(train_df)
test_df = time_feature(test_df)
在欺诈检测任务中,时间特征(如交易发生的小时和分钟)通常是重要的,因为欺诈交易往往具有不同的时间模式。例如,欺诈交易可能集中在某些特定的时间段。
- 这里我们通过
pd.to_datetime()
将Time
列从Unix时间戳格式转换为日期时间格式。然后,我们提取了小时和分钟作为新的特征,用于训练模型。
train_feature = train_df.drop(columns=['id','IsFraud','Time'])
test_feature = test_df.drop(columns=['id','Time'])label = train_df['IsFraud']
train_feature
是用于训练的特征数据,删除了 id
, IsFraud
和 Time
列。IsFraud
是标签列,表示交易是否为欺诈交易;而 id
和 Time
列不包含有用的特征信息,因此可以去掉。
2. 标准化数据
from sklearn.preprocessing import StandardScaler# 标准化特征数据
scaler = StandardScaler()
train_feature_scaled = scaler.fit_transform(train_feature)
-
标准化(Standardization)是机器学习中常用的预处理步骤。它通过减去均值并除以标准差,使特征数据具有零均值和单位方差。标准化能够加速模型的收敛过程,尤其是在使用像神经网络这样的梯度优化模型时。
-
这里使用
StandardScaler
来对训练数据进行标准化,以确保所有特征在同一个量级。
3. 生成器与判别器的构建
生成器(Generator)
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization, Inputdef build_generator(latent_dim, input_dim):model = Sequential()model.add(Input(shape=(latent_dim,))) # 使用 Input 层来指定输入维度model.add(Dense(256))model.add(LeakyReLU(0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(input_dim, activation='tanh')) # 输出层与原数据同维度return model
生成器(Generator)是GAN的核心部分,它通过接收随机噪声向量(潜在空间中的点),然后经过一系列的全连接层和激活函数,生成与原始数据分布相似的虚假数据。
- 在此,我们使用了
LeakyReLU
激活函数,它允许梯度通过负半轴流动,解决了传统ReLU可能出现的“死神经元”问题。BatchNormalization
用于加速网络的训练,并帮助改善模型的稳定性。
判别器(Discriminator)
def build_discriminator(input_dim):model = Sequential()model.add(Input(shape=(input_dim,))) # 使用 Input 层来指定输入维度model.add(Dense(1024))model.add(LeakyReLU(0.2))model.add(Dense(512))model.add(LeakyReLU(0.2))model.add(Dense(256))model.add(LeakyReLU(0.2))model.add(Dense(1, activation='sigmoid')) # 输出真假判定return model
判别器(Discriminator)的任务是判断输入数据是真实的还是由生成器生成的。它是一个二分类模型,输出是一个概率值,表示输入数据为真实的概率。
- 这里使用
sigmoid
激活函数,输出一个概率值。判别器学习将真实数据和生成数据区分开来。
4. GAN模型的组合与训练
def build_gan(generator, discriminator):discriminator.trainable = False # 在训练GAN时冻结判别器model = Sequential()model.add(generator)model.add(discriminator)return model# 定义优化器
optimizer = Adam()# 定义输入维度和潜在维度
latent_dim = 100 # 随机噪声的维度
input_dim = 31 # 输入数据的维度,例如欺诈检测数据的特征数# 创建并编译模型
generator = build_generator(latent_dim, input_dim)
discriminator = build_discriminator(input_dim)
gan = build_gan(generator, discriminator)# 编译判别器和GAN模型
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
gan.compile(loss='binary_crossentropy', optimizer=optimizer)
-
生成对抗训练(Adversarial Training)是GAN的关键。生成器和判别器在一个博弈过程中互相优化。在训练过程中,生成器通过“欺骗”判别器来优化其生成数据的能力,而判别器则不断学习区分真实和生成数据。
-
在训练过程中,我们冻结判别器的参数,只训练生成器,这样可以避免在训练生成器时更新判别器的权重。
5. GAN训练函数
def train_gan(generator, discriminator, gan, fraud_data_scaled, epochs=10000, batch_size=64):valid = np.ones((batch_size, 1)) # 真数据标签fake = np.zeros((batch_size, 1)) # 假数据标签for epoch in range(epochs):# 随机选择真实欺诈数据idx = np.random.randint(0, fraud_data_scaled.shape[0], batch_size)real_data = fraud_data_scaled[idx]# 生成虚拟数据noise = np.random.normal(0, 1, (batch_size, latent_dim))generated_data = generator.predict(noise)# 训练判别器d_loss_real = discriminator.train_on_batch(real_data, valid)d_loss_fake = discriminator.train_on_batch(generated_data, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, latent_dim))g_loss = gan.train_on_batch(noise, valid)# 输出训练过程的损失if epoch % 1000 == 0:print(f'{epoch}/{epochs} [D loss: {d_loss[0]}] [G loss: {g_loss}]')
-
训练过程:在每个训练周期中,首先更新判别器的权重(通过训练它区分真实数据和生成数据),然后训练生成器(通过训练它欺骗判别器)。
-
损失函数:我们使用了
binary_crossentropy
损失函数,它用于二分类任务。在判别器的训练中,我们分别计算真实数据和生成数据的损失,然后平均得到判别器的总损失。生成器的损失则是通过GAN模型进行计算的。
6. 生成虚拟数据
def generate_fake_data(generator, num_samples):noise = np.random.normal(0, 1, (num_samples, latent_dim)) # 随机噪声generated_data = generator.predict(noise) # 生成虚拟数据# 将生成的数据转换回原始空间generated_data_original = scaler.inverse_transform(generated_data)# 获取原始负样本数据的列名(去除 'id', 'IsFraud', 'Time' 列)feature_columns = [col for col in train_df.columns if col not in ['id', 'IsFraud', 'Time']]# 将生成的数据与原始负样本数据(即非欺诈数据)结合,作为新的训练数据augmented_data = np.concatenate([train_df[train_df['IsFraud'] == 0].drop(columns=['id', 'IsFraud', 'Time']),generated_data_original], axis=0)augmented_label = np.concatenate([np.zeros(train_df[train_df['IsFraud'] == 0].shape[0]), np.ones(generated_data_original.shape[0])], axis=0)# 创建包含生成数据和标签的 DataFrameaugmented_df = pd.DataFrame(augmented_data, columns=feature_columns)augmented_df['IsFraud'] = augmented_labelreturn augmented_df
在这个函数中,我们使用训练好的生成器来生成新的虚拟欺诈数据,并将它们与真实的非欺诈数据结合,以增强数据集。然后,我们通过逆标准化将生成的数据转换回原始数据空间。
本次例子为了缩短训练时间,只生成了100条虚拟的正样本数据。
相关文章:
【机器学习实战】kaggle 欺诈检测---使用生成对抗网络(GAN)解决欺诈数据中正负样本极度不平衡问题
【机器学习实战】kaggle 欺诈检测---如何解决欺诈数据中正负样本极度不平衡问题https://blog.csdn.net/2302_79308082/article/details/145177242 本篇文章是基于上次文章中提到的对抗生成网络,通过对抗生成网络生成少数类样本,平衡欺诈数据中正类样本极…...
android wifi framework与wpa_supplicant的交互
android frmework直接与wpa_supplicant进行交互,使用aidl或者hidl 二、事件 framework注册事件的地方: packages/modules/Wifi/service/java/com/android/server/wifi/SupplicantStaIfaceCallbackImpl.java class SupplicantStaIfaceCallbackImpl exte…...

初学stm32 --- flash模仿eeprom
目录 STM32内部FLASH简介 内部FLASH构成(F1) FLASH读写过程(F1) 闪存的读取 闪存的写入 内部FLASH构成(F4 / F7 / H7) FLASH读写过程(F4 / F7 / H7) 闪存的读取 闪存的写入 …...

使用C语言实现栈的插入、删除和排序操作
栈是一种后进先出(LIFO, Last In First Out)的数据结构,这意味着最后插入的元素最先被删除。在C语言中,我们可以通过数组或链表来实现栈。本文将使用数组来实现一个简单的栈,并提供插入(push)、删除(pop)以及排序(这里采用一种简单的排序方法,例如冒泡排序)的操作示…...

C语言程序环境和预处理详解
本章重点: 程序的翻译环境 程序的执行环境 详解:C语言程序的编译链接 预定义符号介绍 预处理指令 #define 宏和函数的对比 预处理操作符#和##的介绍 命令定义 预处理指令 #include 预处理指令 #undef 条件编译 程序的翻译环境和执行环…...

基于机器学习随机森林算法的个人职业预测研究
1.背景调研 随着信息技术的飞速发展,特别是大数据和云计算技术的广泛应用,各行各业都积累了大量的数据。这些数据中蕴含着丰富的信息和模式,为利用机器学习进行职业预测提供了可能。机器学习算法的不断进步,如深度学习、强化学习等…...

三种文本相似计算方法:规则、向量与大模型裁判
文本相似计算 项目背景 目前有众多工作需要评估字符串之间的相似(相关)程度: 比如,RAG 智能问答系统文本召回阶段需要计算用户文本与文本库内文本的相似分数,返回前TopK个候选文本。 在评估大模型生成的文本阶段,也需要评估…...
Python语言的计算机基础
Python语言的计算机基础 绪论 在当今信息技术飞速发展的时代,编程已经成为了一种必备技能。Python凭借其简洁、易读和强大的功能,逐渐成为初学者学习编程的首选语言。本文将以Python语言为基础,探讨计算机科学的基本概念,并帮助…...

Dify应用-工作流
目录 DIFY 工作流参考 DIFY 工作流 2025-1-15 老规矩感谢参考文章的作者,避免走弯路。 2025-1-15 方便容易上手 在dify的一个桌面上,添加多个节点来完成一个任务。 每个工作流必须有一个开始和结束节点。 节点之间用线连接即可。 每个节点可以有输入和输出 输出类型有,字符串,…...
02.02、返回倒数第 k 个节点
02.02、[简单] 返回倒数第 k 个节点 1、题目描述 实现一种算法,找出单向链表中倒数第 k 个节点。返回该节点的值。 2、题解思路 本题的关键在于使用双指针法,通过两个指针(fast 和 slow),让 fast 指针比 slow 指针…...

Linux手写FrameBuffer任意引脚驱动spi屏幕
一、硬件设备 开发板:香橙派 5Plus,cpu:RK3588,带有 40pin 外接引脚。 屏幕:SPI 协议 0.96 寸 OLED。 二、需求 主要是想给板子增加一个可视化的监视器,并且主页面可调。 平时跑个模型或者服务,…...

怎么修复损坏的U盘?而且不用格式化的方式!
当你插入U盘时,若电脑弹出“需要格式化才能使用”提示,且无法打开或读取其中的数据,说明U盘极有可能已经损坏。除此之外,若电脑在连接U盘后显示以下信息,也可能意味着U盘出现问题,需要修复损坏的U盘&#x…...
语音技术在播客领域的应用(2)
播客是以语音为主,各种基于AI 的语音技术在播客领域十分重要。 语音转文本 Whisper Whisper 是OpenAI 推出的开源语音辨识工具,可以把音档转成文字,支援超过50 种语言。这款工具是基于68 万小时的训练资料,其中包含11.7 万小时的…...

【Linux】应用层自定义协议与序列化
🌈 个人主页:Zfox_ 🔥 系列专栏:Linux 目录 一:🔥 应用层 🦋 再谈 "协议"🦋 网络版计算器🦋 序列化 和 反序列化 二:🔥 重新理解 read、…...

深度学习中的张量 - 使用PyTorch进行广播和元素级操作
深度学习中的张量 - 使用PyTorch进行广播和元素级操作 元素级是什么意思? 元素级操作在神经网络编程中与张量的使用非常常见。让我们从一个元素级操作的定义开始这次讨论。 一个_元素级_操作是在两个张量之间进行的操作,它作用于各自张量中的相应元素…...
gitignore忽略已经提交过的
已经在.gitignore文件中添加了过滤规则来忽略bin和obj等文件夹,但这些文件夹仍然出现在提交中,可能是因为这些文件夹在添加.gitignore规则之前已经被提交到Git仓库中了。要解决这个问题,您需要从Git的索引中移除这些文件夹,并确保…...
h5使用video播放时关掉vant弹窗视频声音还在后台播放
现象: 1、点击遮罩弹窗关闭,弹窗的视频已经用v-if销毁,但是后台会自己从头开始播放视频声音。但是此时已经没有视频dom 2、定时器在打开弹窗后3秒自动关闭弹窗,则正常没有问题。 原来的代码: //页面 <a click&quo…...

Widows搭建sqli-labs
使用ms17_010渗透win7 ms17_010针对windows445端口(共享文件), 现有一台win7虚拟机IP 192.168.80.129, 开放445端口, 使用msf渗透该虚拟机 auxiliary 使用auxiliary判断目标主机是否适用smb17_010漏洞 这里发现80网段, 有一台主机适用 exploit 使用search ms17_010 type:expl…...
为AI聊天工具添加一个知识系统 之46 蒙板程序设计(第一版):Facet六边形【意识形态:操纵】
本文要点 要点 (原先标题冒号后只有 “Facet”后改为“Face六边形【意识形态】” ,是 事后想到的,本文并未明确提出。备忘在这里作为后续的“后期制作”的备忘) 前面讨论的(“之41 纯粹的思维”)中 说到,“意识”三…...
ASP.NET Core WebApi接口IP限流实践技术指南
在当今的Web开发中,接口的安全性和稳定性至关重要。面对恶意请求或频繁访问,我们需要采取有效的措施来保护我们的WebApi接口。IP限流是一种常见的技术手段,通过对来自同一IP地址的请求进行频率控制,可以有效地防止恶意攻击和过度消…...

【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 初始化服务器)
服务端执行命令请求的过程 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 初始化服务器1. 初始化服务器状态结构初始化RedisServer变量 2. 加载相关系统配置和用户配置参数定制化配置参数案…...

Vue2 第一节_Vue2上手_插值表达式{{}}_访问数据和修改数据_Vue开发者工具
文章目录 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染2. 插值表达式{{}}3. 访问数据和修改数据4. vue响应式5. Vue开发者工具--方便调试 1.Vue2上手-如何创建一个Vue实例,进行初始化渲染 准备容器引包创建Vue实例 new Vue()指定配置项 ->渲染数据 准备一个容器,例如: …...

C# 类和继承(抽象类)
抽象类 抽象类是指设计为被继承的类。抽象类只能被用作其他类的基类。 不能创建抽象类的实例。抽象类使用abstract修饰符声明。 抽象类可以包含抽象成员或普通的非抽象成员。抽象类的成员可以是抽象成员和普通带 实现的成员的任意组合。抽象类自己可以派生自另一个抽象类。例…...
3403. 从盒子中找出字典序最大的字符串 I
3403. 从盒子中找出字典序最大的字符串 I 题目链接:3403. 从盒子中找出字典序最大的字符串 I 代码如下: class Solution { public:string answerString(string word, int numFriends) {if (numFriends 1) {return word;}string res;for (int i 0;i &…...
安卓基础(aar)
重新设置java21的环境,临时设置 $env:JAVA_HOME "D:\Android Studio\jbr" 查看当前环境变量 JAVA_HOME 的值 echo $env:JAVA_HOME 构建ARR文件 ./gradlew :private-lib:assembleRelease 目录是这样的: MyApp/ ├── app/ …...

【VLNs篇】07:NavRL—在动态环境中学习安全飞行
项目内容论文标题NavRL: 在动态环境中学习安全飞行 (NavRL: Learning Safe Flight in Dynamic Environments)核心问题解决无人机在包含静态和动态障碍物的复杂环境中进行安全、高效自主导航的挑战,克服传统方法和现有强化学习方法的局限性。核心算法基于近端策略优化…...

基于SpringBoot在线拍卖系统的设计和实现
摘 要 随着社会的发展,社会的各行各业都在利用信息化时代的优势。计算机的优势和普及使得各种信息系统的开发成为必需。 在线拍卖系统,主要的模块包括管理员;首页、个人中心、用户管理、商品类型管理、拍卖商品管理、历史竞拍管理、竞拍订单…...

[免费]微信小程序问卷调查系统(SpringBoot后端+Vue管理端)【论文+源码+SQL脚本】
大家好,我是java1234_小锋老师,看到一个不错的微信小程序问卷调查系统(SpringBoot后端Vue管理端)【论文源码SQL脚本】,分享下哈。 项目视频演示 【免费】微信小程序问卷调查系统(SpringBoot后端Vue管理端) Java毕业设计_哔哩哔哩_bilibili 项…...
MySQL JOIN 表过多的优化思路
当 MySQL 查询涉及大量表 JOIN 时,性能会显著下降。以下是优化思路和简易实现方法: 一、核心优化思路 减少 JOIN 数量 数据冗余:添加必要的冗余字段(如订单表直接存储用户名)合并表:将频繁关联的小表合并成…...

免费数学几何作图web平台
光锐软件免费数学工具,maths,数学制图,数学作图,几何作图,几何,AR开发,AR教育,增强现实,软件公司,XR,MR,VR,虚拟仿真,虚拟现实,混合现实,教育科技产品,职业模拟培训,高保真VR场景,结构互动课件,元宇宙http://xaglare.c…...