【深度学习实战】kaggle 自动驾驶的假场景分类
本次分享我在kaggle中参与竞赛的历程,这个版本是我的第一版,使用的是vgg。欢迎大家进行建议和交流。
概述
-
判断自动驾驶场景是真是假,训练神经网络或使用任何算法来分类驾驶场景的图像是真实的还是虚假的。
-
图像采用 RGB 格式并以 JPEG 格式压缩。
-
标签显示 (1) 真实和 (0) 虚假
-
二元分类
数据集描述
文件
train.csv - 训练集标签
Sample_submission.csv - 正确格式的示例提交文件
Train/- 训练图像
Test/ - 测试图像
模型思路
由于是要进行图像的二分类任务,因此考虑使用迁移学习,将vgg16中的卷积层和卷积层的参数完全迁移过来,不包括顶部的全连接层,自己设计适合该任务的头部结构,然后加以训练,绘制图像查看训练结果。
vgg16简介
VGG16 是由牛津大学视觉几何组(VGG)在2014年提出的卷积神经网络(CNN)。它由16个层组成,其中包含13个卷积层和3个全连接层。其特点是使用3x3的小卷积核和2x2的最大池化层,网络深度较深,有效提取图像特征。VGG16在图像分类任务中表现优异,尤其是在ImageNet挑战中取得了良好成绩。尽管计算量大、参数众多,但它因其简单而高效的结构,仍广泛应用于迁移学习和其他计算机视觉任务中。
源码+解析
- 第一步,导入所需的库。
import os
import cv2
import numpy as np
import pandas as pd
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Dropout, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.applications.vgg16 import preprocess_input
- 加载文件
# 路径和文件
data_file = '/kaggle/input/cidaut-ai-fake-scene-classification-2024/train.csv'
image_test = '/kaggle/input/cidaut-ai-fake-scene-classification-2024/Test/'
image_train = '/kaggle/input/cidaut-ai-fake-scene-classification-2024/Train/'# 加载标签数据
df = pd.read_csv(data_file)
df['image_path'] = df['image'].apply(lambda x: os.path.join(image_train, x))n_classes = df['label'].nunique()df.head() # 显示数据的前几行,检查路径和标签
输出
image label image_path
0 1.jpg editada /kaggle/input/cidaut-ai-fake-scene-classificat...
1 2.jpg real /kaggle/input/cidaut-ai-fake-scene-classificat...
2 3.jpg real /kaggle/input/cidaut-ai-fake-scene-classificat...
3 6.jpg editada /kaggle/input/cidaut-ai-fake-scene-classificat...
4 8.jpg real /kaggle/input/cidaut-ai-fake-scene-classificat...
原始train.csv文件只有前两列,image 和label 列,为了方便读取图像文件,新添加了一列image_path用来记录图像文件的具体路径。
# 初始化空列表 x 用于存储图像
x = []# 遍历每一行读取图像
for index, row in df.iterrows():image_path = row['image_path'] # 获取图像路径img = cv2.imread(image_path) # 使用 cv2 读取图像if img is not None:img_resized = cv2.resize(img, (256, 256)) # 调整图像尺寸为 (256, 256)x.append(img_resized) # 将读取的图像添加到列表 x 中else:print(f"图像 {row['image_path']} 读取失败") # 打印失败的路径# x 列表现在包含了所有读取的图像
print(f"总共有 {len(x)} 张图像被读取")
输出
总共有 720 张图像被读取
通过输出结果,可以看到图像被正确的读取了。并且将图像的大小调整为vgg所能用的256*256的尺寸,存放在变量x中。
- 第三步,进行数据处理
# 将图像转换为 NumPy 数组
x = np.array(x)# 标签映射并进行 one-hot 编码
y = df['label'].map({'real': 1, 'editada': 0})
y = np.array(y)
y = to_categorical(y, num_classes=2) # 二分类# 分割训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)# 检查转换后的结果
print(f"x_train.shape: {x_train.shape}")
print(f"y_train.shape: {y_train.shape}")
print(f"x_test.shape: {x_test.shape}")
print(f"y_test.shape: {y_test.shape}")
输出
x_train.shape: (576, 256, 256, 3)
y_train.shape: (576, 2)
x_test.shape: (144, 256, 256, 3)
y_test.shape: (144, 2)
这里是为了将原始的图像转换为numpy数组,并且将标签进行独热编码,(对分类的标签一定要进行独热编码,转换为矩阵形式),并且切分数据集。
- 第四步,设计模型结构
from tensorflow.keras.regularizers import l2
# 加载预训练的VGG16卷积基(不包括顶部的全连接层)
vgg16_model = VGG16(include_top=False, weights='imagenet', input_shape=(256, 256, 3))# 冻结VGG16的卷积层
for layer in vgg16_model.layers:layer.trainable = False# 创建一个新的模型
model_fine_tuning = Sequential()# 将VGG16的卷积基添加到新模型中
model_fine_tuning.add(vgg16_model) # 添加VGG16卷积基
model_fine_tuning.add(Flatten()) # 将卷积特征图展平# 添加新的全连接层并进行正则化
model_fine_tuning.add(Dense(512, activation='relu', kernel_regularizer=l2(0.01))) # L2正则化
model_fine_tuning.add(Dropout(0.3)) # Dropout层,减少过拟合
model_fine_tuning.add(Dense(256, activation='relu', kernel_regularizer=l2(0.01))) # 较小的全连接层
model_fine_tuning.add(Dropout(0.3) ) # 再次使用Dropout层# 输出层
model_fine_tuning.add(Dense(2, activation='softmax')) # 对于二分类问题,使用softmax# 查看模型架构
model_fine_tuning.summary()
输出:
Layer (type) | Output Shape | Param # |
---|---|---|
vgg16 (Functional) | (None, 8, 8, 512) | 14,714,688 |
flatten (Flatten) | (None, 32768) | 0 |
dense (Dense) | (None, 512) | 16,777,728 |
dropout (Dropout) | (None, 512) | 0 |
dense_1 (Dense) | (None, 256) | 131,328 |
dropout_1 (Dropout) | (None, 256) | 0 |
dense_2 (Dense) | (None, 2) | 514 |
这里实现了一个基于预训练VGG16模型的迁移学习框架,用于图像分类任务。首先,加载了预训练的VGG16卷积基(不包括全连接层),并通过设置include_top=False
来只使用卷积部分,从而利用其在ImageNet数据集上学到的特征。接着,冻结VGG16的卷积层,即通过将trainable
属性设为False
,使得这些层在训练过程中不进行更新。接下来,创建了一个新的Sequential
模型,并将VGG16的卷积基添加进去,随后使用Flatten
层将卷积特征图展平,为全连接层准备输入。为了增加模型的表达能力,添加了两个全连接层,每个层都应用了ReLU
激活函数,并使用L2
正则化来防止过拟合。为了进一步减少过拟合,模型还在每个全连接层后添加了Dropout
层,丢弃30%的神经元。最后,输出层是一个具有两个神经元的全连接层,采用softmax
激活函数,用于处理二分类问题。model_fine_tuning.summary()
方法输出模型架构,帮助查看各层的结构和参数。通过这种方式,模型能够利用VGG16的预训练卷积基进行特征提取,并通过新添加的全连接层进行分类。
- 第五步,编译并训练模型
# 编译模型
model_fine_tuning.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])datagen = ImageDataGenerator(rotation_range=40,width_shift_range=0.2,height_shift_range=0.2,shear_range=0.2,zoom_range=0.2,horizontal_flip=True,fill_mode='nearest',preprocessing_function=preprocess_input) # 使用VGG16的预处理函数# 对原始图像进行增强,并进行训练
history = model_fine_tuning.fit(datagen.flow(x_train, y_train, batch_size=32),epochs=20,validation_data=(x_test, y_test),callbacks=[ModelCheckpoint('best_model.keras', save_best_only=True),EarlyStopping(patience=5)])
这里主要完成了对已经构建的模型(model_fine_tuning)
的编译与训练过程。
- 首先,使用
compile()
方法对模型进行编译,指定损失函数为binary_crossentropy
,适用于二分类问题,同时选择Adam优化器,这是一种自适应学习率的优化算法,能够有效提升训练性能。在编译时,还通过metrics=['accuracy']
设置了准确率作为评估指标。 - 接着,创建了一个
ImageDataGenerator
对象用于数据增强,它包含多种图像变换方式,如旋转、平移、剪切、缩放、水平翻转等,这些操作可以增加数据多样性,减少过拟合,提升模型的泛化能力。 - 此外,
preprocessing_function=preprocess_input
使用了VGG16
预训练模型的标准预处理函数,确保输入图像的像素范围符合VGG16的训练要求。 - 随后,通过
fit()
方法开始训练模型,训练数据通过datagen.flow()
进行增强和批量生成,训练将在20个周期(epochs
)内进行。在训练过程中,还设置了两个回调函数:ModelCheckpoint
,用于保存最好的模型权重文件(best_model.keras
),并且只保存验证集上表现最好的模型; EarlyStopping
,用于在验证集准确率不再提升时提前停止训练,patience=5
表示如果5个周期内没有改进,则停止训练。这样,通过数据增强和回调函数的配合,能够有效提高训练的效果和模型的稳定性。
到这里,整个部分就基本完成了。
- 绘制损失和准确率图像
import matplotlib.pyplot as plt# 获取训练过程中的损失和准确率数据
history_dict = history.history
loss = history_dict['loss']
accuracy = history_dict['accuracy']
val_loss = history_dict['val_loss']
val_accuracy = history_dict['val_accuracy']# 绘制损失图
plt.figure(figsize=(12, 6))# 损失图
plt.subplot(1, 2, 1)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Loss over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()# 准确率图
plt.subplot(1, 2, 2)
plt.plot(accuracy, label='Training Accuracy')
plt.plot(val_accuracy, label='Validation Accuracy')
plt.title('Accuracy over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()# 展示图像
plt.tight_layout()
plt.show()
数据文件已经上传,感兴趣的小伙伴可以下载后自己尝试。
相关文章:

【深度学习实战】kaggle 自动驾驶的假场景分类
本次分享我在kaggle中参与竞赛的历程,这个版本是我的第一版,使用的是vgg。欢迎大家进行建议和交流。 概述 判断自动驾驶场景是真是假,训练神经网络或使用任何算法来分类驾驶场景的图像是真实的还是虚假的。 图像采用 RGB 格式并以 JPEG 格式…...

Spring Boot 和微服务:快速入门指南
💖 欢迎来到我的博客! 非常高兴能在这里与您相遇。在这里,您不仅能获得有趣的技术分享,还能感受到轻松愉快的氛围。无论您是编程新手,还是资深开发者,都能在这里找到属于您的知识宝藏,学习和成长…...

qt QPainter setViewport setWindow viewport window
使用qt版本5.15.2 引入viewport和window目的是用于实现QPainter画出来的内容随着窗体伸缩与不伸缩两种情况,以及让QPainter在widget上指定的区域(viewport)进行绘制/渲染(分别对应下方demo1,demo2,demo3)。 setViewpo…...

网络安全面试题汇总(个人经验)
1.谈一下SQL主从备份原理? 答:主将数据变更写入自己的二进制log,从主动去主那里去拉二进制log并写入自己的二进制log,从而自己数据库依据二进制log内容做相应变更。主写从读 2.linux系统中的计划任务crontab配置文件中的五个星星分别代表什么ÿ…...

【网络云SRE运维开发】2025第3周-每日【2025/01/14】小测-【第13章ospf路由协议】理论和实操
文章目录 选择题(10道)理论题(5道)实操题(5道) 【网络云SRE运维开发】2025第3周-每日【2025/01/14】小测-【第12章ospf路由协议】理论和实操 选择题(10道) 在OSPF协议中,…...

FreeType 介绍及 C# 示例
FreeType 是一个开源的字体渲染引擎,用于将字体文件(如 TrueType、OpenType、Type 1 等)转换为位图或矢量图形。它广泛应用于操作系统、图形库、游戏引擎等领域,支持高质量的字体渲染和复杂的文本布局。 FreeType 的核心功能 字体…...

金融项目实战 04|JMeter实现自动化脚本接口测试及持续集成
目录 一、⾃动化测试理论 二、自动化脚本 1、添加断言 1️⃣注册、登录 2️⃣认证、充值、开户、投资 2、可重复执行:清除测试数据脚本按指定顺序执行 1️⃣如何可以做到可重复执⾏? 2️⃣清除测试数据:连接数据库setup线程组 ①明确…...

Linux网络知识——路由表
路由表 1 定义与作用 Linux路由表是一个内核数据结构,用于描述Linux主机与其他网络设备之间的路径,以及如何将数据包从源地址路由到目标地址。路由表的主要作用是指导数据包在网络中的传输路径,确保数据包能够准确、高效地到达目标地址。 …...

浅谈云计算14 | 云存储技术
云存储技术 一、云计算网络存储技术基础1.1 网络存储的基本概念1.2云存储系统结构模型1.1.1 存储层1.1.2 基础管理层1.1.3 应用接口层1.1.4 访问层 1.2 网络存储技术分类 二、云计算网络存储技术特点2.1 超大规模与高可扩展性2.1.1 存储规模优势2.1.2 动态扩展机制 2.2 高可用性…...

AI 编程工具—Cursor进阶使用 阅读开源项目
AI 编程工具—Cursor进阶使用 阅读开源项目 首先我们打开一个最近很火的项目browser-use ,直接从github 上克隆即可 索引整个代码库 这里我们使用@Codebase 这个选项会索引这个代码库,然后我们再选上这个项目的README.md 文件开始提问 @Codebase @README.md 这个项目是用…...

使用 WPF 和 C# 将纹理应用于三角形
此示例展示了如何将纹理应用于三角形,以使场景比覆盖纯色的场景更逼真。以下是为三角形添加纹理的基本步骤。 创建一个MeshGeometry3D对象。像往常一样定义三角形的点和法线。通过向网格的TextureCoordinates集合添加值来设置三角形的纹理坐标。创建一个使用想要显示的纹理的 …...

Elasticsearch搜索引擎(二)
RestClient 基础 前言一、RestAPI1. 初始化 *RestClient*2. 创建索引库3. 删除索引库4. 判断索引库是否存在 二、RestClient操作文档1.新增文档2.查询文档3. 删除文档4. 修改文档5. 批量导入文档 前言 ES官方提供了各种不同语言的客户端用来操作ES,这些客户端的本质…...

unity学习17:unity里的旋转学习,欧拉角,四元数等
目录 1 三维空间里的旋转与欧拉角,四元数 1.1 欧拉角比较符合直观 1.2 四元数 1.3 下面是欧拉角和四元数的一些参考文章 2 关于旋转的这些知识点 2.1 使用euler欧拉角旋转 2.2 使用quaternion四元数,w,x,y,z 2.3 使用quaternion四元数,类 Vector3.zero 这种…...

走出实验室的人形机器人,将复刻ChatGPT之路?
1月7日,在2025年CES电子展现场,黄仁勋不仅展示了他全新的皮衣和采用Blackwell架构的RTX 50系列显卡,更进一步展现了他对于机器人技术领域,特别是人形机器人和通用机器人技术的笃信。黄仁勋认为机器人即将迎来ChatGPT般的突破&…...

如何使用wireshark 解密TLS-SSL报文
目录 前言 原理 操作 前言 现在网站都是https 或者 很多站点都支持 http2。这些站点为了保证数据的安全都通过TLS/SSL 加密过,用wireshark 并不能很好的去解析报文,我们就需要用wireshark去解密这些报文。我主要讲解下mac 在 chrome 怎么配置的&…...

电脑有两张网卡,如何实现同时访问外网和内网?
要是想让一台电脑用两张网卡,既能访问外网又能访问内网,那可以通过设置网络路由还有网卡的 IP 地址来达成。 检查一下网卡的连接 得保证电脑的两张网卡分别连到外网和内网的网络设备上,像路由器或者交换机啥的。 给网卡配上不一样的 IP 地…...

定义:除了Vue内置指令以外的其他 v-开头的指令(需要程序员自行扩展定义)作用:自己定义的指令, 可以封装一些 dom 操作, 扩展
1.自定义指令(directives) 1.用法 定义:除了Vue内置指令以外的其他 v-开头的指令(需要程序员自行扩展定义)作用:自己定义的指令, 可以封装一些 dom 操作, 扩展额外功能 语法: ① 局部注册 ●inserted:被绑…...

SpringBoot错误码国际化
先看测试效果: 文件结构 1.中文和英文的错误消息配置 package com.ldj.mybatisflex.common;import lombok.Getter;/*** User: ldj* Date: 2025/1/12* Time: 17:50* Description: 异常消息枚举*/ Getter public enum ExceptionEnum {//# code命名规则:模…...

LeetCode 3066.超过阈值的最少操作数 II:模拟 - 原地建堆O(1)空间 / 优先队列O(n)空间
【LetMeFly】3066.超过阈值的最少操作数 II:模拟 - 原地建堆O(1)空间 / 优先队列O(n)空间 力扣题目链接:https://leetcode.cn/problems/minimum-operations-to-exceed-threshold-value-ii/ 给你一个下标从 0 开始的整数数组 nums 和一个整数 k 。 一次…...

深度学习中的模块复用原则(定义一次还是多次)
文章目录 1. 模块复用的核心原则(1)模块是否有**可学习参数**(2)模块是否有**内部状态**(3)模块的功能需求是否一致 2. 必须单独定义的模块(1)nn.Linear(全连接层&#x…...

Mac——Cpolar内网穿透实战
摘要 本文介绍了在Mac系统上实现内网穿透的方法,通过打开远程登录、局域网内测试SSH远程连接,以及利用cpolar工具实现公网SSH远程连接MacOS的步骤。包括安装配置homebrew、安装cpolar服务、获取SSH隧道公网地址及测试公网连接等关键环节。 1. MacOS打开…...

安全测评主要标准
大家读完觉得有帮助记得关注和点赞!!! 安全测评的主要标准包括多个国际和国内的标准,这些标准为信息系统和产品的安全评估提供了基础和指导。 一、安全测评的主要标准 1.1、国际标准 可信计算机系统评估准则(TC…...

qBittorent访问webui时提示unauthorized解决方法
现象描述 QNAP使用Container Station运行容器,使用Docker封装qBittorrent时,访问IP:PORT的方式后无法访问到webui,而是提示unauthorized,如图: 原因分析 此时通常是由于设备IP与qBittorrent的ip地址不在同一个网段导致…...

504 Gateway Timeout:网关超时解决方法
一、什么是 504Gateway Timeout? 1. 错误定义 504 Gateway Timeout 是 HTTP 状态码的一种,表示网关或代理服务器在等待上游服务器响应时超时。通俗来说,这是服务器之间“对话失败”导致的。 2. 常见触发场景 Nginx 超时:反向代…...

Vue 实现当前页面刷新的几种方法
以下是 Vue 中实现当前页面刷新的几种方法: 方法一:使用 $router.go(0) 方法 通过Vue Router进行重新导航,可以实现页面的局部刷新,而不丢失全局状态。具体实现方式有两种: 实现代码: <template&g…...

MCP Server开发的入门教程(python和pip)
使用python技术栈开发的简单mcp server 需要安装 MCP server的需要使用python-sdk,python需要 3.10,安装如下 pip install mcpPS: MCP官方使用的是uv包管理工具,我平时使用pip比较多,所以文中以pip为主。因为mcp的一些依赖包版本并不是最新的,所以最好弄一个干净的环境…...

手撕Transformer -- Day7 -- Decoder
手撕Transformer – Day7 – Decoder Transformer 网络结构图 目录 手撕Transformer -- Day7 -- DecoderTransformer 网络结构图Decoder 代码Part1 库函数Part2 实现一个解码器Decoder,作为一个类Part3 测试 参考 Transformer 网络结构 Decoder 代码 Part1 库函数…...

C#异步和多线程,Thread,Task和async/await关键字--12
目录 一.多线程和异步的区别 1.多线程 2.异步编程 多线程和异步的区别 二.Thread,Task和async/await关键字的区别 1.Thread 2.Task 3.async/await 三.Thread,Task和async/await关键字的详细对比 1.Thread和Task的详细对比 2.Task 与 async/await 的配合使用 3. asy…...

使用分割 Mask 和 K-means 聚类获取天空的颜色
引言 在计算机视觉领域,获取天空的颜色是一个常见任务,广泛应用于天气分析、环境感知和图像增强等场景。本篇博客将介绍如何通过已知的天空区域 Mask 提取天空像素,并使用 K-means 聚类分析天空颜色,最终根据颜色占比查表得到主导…...

145.《redis原生超详细使用》
文章目录 什么是redisredis 安装启动redis数据类型redis key操作key 的增key 的查key 的改key 的删key 是否存在key 查看所有key 「设置」过期时间key 「查看」过期时间key 「移除」过期时间key 「查看」数据类型key 「匹配」符合条件的keykey 「移动」到其他数据库 redis数据类…...